.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/forced_alignment_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_forced_alignment_tutorial.py:

Forced Alignment with Wav2Vec2
==============================

**Author**: `Moto Hira `__

.. warning::

   Starting with version 2.8, we are refactoring TorchAudio to transition it
   into a maintenance phase. As a result:

   - The APIs described in this tutorial are deprecated in 2.8 and will be
     removed in 2.9.
   - The decoding and encoding capabilities of PyTorch for both audio and
     video are being consolidated into TorchCodec.

   Please see https://github.com/pytorch/audio/issues/3902 for more
   information.

This tutorial shows how to align a transcript to speech with ``torchaudio``,
using the CTC segmentation algorithm described in `CTC-Segmentation of Large
Corpora for German End-to-end Speech Recognition `__.

.. note::

   This tutorial was originally written to illustrate a use case for the
   Wav2Vec2 pretrained model.

   TorchAudio now has a set of APIs designed for forced alignment. The `CTC
   forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__
   illustrates the usage of :py:func:`torchaudio.functional.forced_align`,
   which is the core API.

   If you are looking to align your corpus, we recommend using
   :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, which combines
   :py:func:`~torchaudio.functional.forced_align` and other support functions
   with a pre-trained model specifically trained for forced alignment. Please
   refer to the `Forced alignment for multilingual data `__ tutorial, which
   illustrates its usage.

.. GENERATED FROM PYTHON SOURCE LINES 42-54

.. code-block:: default

    import torch
    import torchaudio

    print(torch.__version__)
    print(torchaudio.__version__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2.8.0+cu126
    2.8.0
    cuda

.. GENERATED FROM PYTHON SOURCE LINES 55-68

Overview
--------

The process of alignment looks like the following:

1. Estimate the frame-wise label probability from the audio waveform.
2. Generate the trellis matrix, which represents the probability of each
   label being aligned at each time step.
3. Find the most likely path through the trellis matrix.

In this example, we use ``torchaudio``'s ``Wav2Vec2`` model for acoustic
feature extraction.

.. GENERATED FROM PYTHON SOURCE LINES 71-76

Preparation
-----------

First we import the necessary packages and fetch the data that we will work
on.

.. GENERATED FROM PYTHON SOURCE LINES 76-87

.. code-block:: default

    from dataclasses import dataclass

    import IPython
    import matplotlib.pyplot as plt

    torch.random.manual_seed(0)

    SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /pytorch/audio/examples/tutorials/forced_alignment_tutorial.py:84: UserWarning: torchaudio.utils.download.download_asset has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
      SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
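As a quick check (an illustrative addition, not part of the original
tutorial), we can inspect the downloaded clip's metadata. ``torchaudio.info``
belongs to the same deprecated I/O surface as the rest of this tutorial, and
the bundle used below expects 16 kHz mono input.

.. code-block:: default

    # Inspect the clip before feature extraction. The sample rate must match
    # the model's expected rate (bundle.sample_rate == 16000 for the ASR
    # bundle used in the next section).
    info = torchaudio.info(SPEECH_FILE)
    print(info.sample_rate, info.num_channels, info.num_frames)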
.. GENERATED FROM PYTHON SOURCE LINES 88-104

Generate frame-wise label probability
-------------------------------------

The first step is to generate the label class probability for each audio
frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use
:py:data:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.

``torchaudio`` provides easy access to pretrained models with associated
labels.

.. note::

   In the subsequent sections, we will compute the probability in the log
   domain to avoid numerical instability. For this purpose, we normalize the
   ``emission`` with :py:func:`torch.log_softmax`.

.. GENERATED FROM PYTHON SOURCE LINES 104-117

.. code-block:: default

    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model().to(device)
    labels = bundle.get_labels()
    with torch.inference_mode():
        waveform, _ = torchaudio.load(SPEECH_FILE)
        emissions, _ = model(waveform.to(device))
        emissions = torch.log_softmax(emissions, dim=-1)

    emission = emissions[0].cpu().detach()

    print(labels)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /pytorch/audio/src/torchaudio/_backend/utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.
      warnings.warn(
    /pytorch/audio/src/torchaudio/_backend/ffmpeg.py:88: UserWarning: torio.io._streaming_media_decoder.StreamingMediaDecoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. The decoding and encoding capabilities of PyTorch for both audio and video are being consolidated into TorchCodec. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
      s = torchaudio.io.StreamReader(src, format, None, buffer_size)
    ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')

.. GENERATED FROM PYTHON SOURCE LINES 118-120

Visualization
~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 120-134

.. code-block:: default

    def plot():
        fig, ax = plt.subplots()
        img = ax.imshow(emission.T)
        ax.set_title("Frame-wise class probability")
        ax.set_xlabel("Time")
        ax.set_ylabel("Labels")
        fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
        fig.tight_layout()


    plot()

.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_tutorial_001.png
   :alt: Frame-wise class probability
   :srcset: /tutorials/images/sphx_glr_forced_alignment_tutorial_001.png
   :class: sphx-glr-single-img
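As a sanity check (an illustrative addition, not part of the original
tutorial), we can greedily decode the emission: taking the arg-max label per
frame and collapsing repeats and blanks shows the transcript the acoustic
model predicts on its own, which should resemble the transcript we are about
to align.

.. code-block:: default

    import itertools

    # Greedy CTC decoding: arg-max per frame, collapse consecutive repeats,
    # then drop the blank token '-'.
    indices = torch.argmax(emission, dim=-1).tolist()
    collapsed = [i for i, _ in itertools.groupby(indices)]
    print("".join(labels[i] for i in collapsed if labels[i] != "-"))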
.. GENERATED FROM PYTHON SOURCE LINES 135-171

Generate alignment probability (trellis)
----------------------------------------

From the emission matrix, we next generate the trellis, which represents the
probability of each transcript label occurring at each time frame.

The trellis is a 2D matrix with a time axis and a label axis. The label axis
represents the transcript that we are aligning. In the following, we use
:math:`t` to denote the index in the time axis and :math:`j` to denote the
index in the label axis. :math:`c_j` represents the label at label index
:math:`j`.

To generate the probability of time step :math:`t+1`, we look at the trellis
from time step :math:`t` and the emission at time step :math:`t+1`. There are
two paths to reach time step :math:`t+1` with label :math:`c_{j+1}`. The
first is the case where the label was :math:`c_{j+1}` at :math:`t` and there
was no label change from :math:`t` to :math:`t+1`. The other case is where
the label was :math:`c_j` at :math:`t` and it transitioned to the next label
:math:`c_{j+1}` at :math:`t+1`.

The following diagram illustrates this transition.

.. image:: https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png

Since we are looking for the most likely transitions, we take the more likely
path for the value of :math:`k_{(t+1, j+1)}`, that is

:math:`k_{(t+1, j+1)} = \max\bigl( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, \text{repeat}) \bigr)`

where :math:`k` represents the trellis matrix, and :math:`p(t, c_j)`
represents the probability of label :math:`c_j` at time step :math:`t`.
:math:`\text{repeat}` represents the blank token from the CTC formulation.
(For details of the CTC algorithm, please refer to *Sequence Modeling with
CTC* [`distill.pub `__].)

.. GENERATED FROM PYTHON SOURCE LINES 171-202

.. code-block:: default

    # We enclose the transcript with space tokens, which represent SOS and EOS.
    transcript = "|I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"
    dictionary = {c: i for i, c in enumerate(labels)}

    tokens = [dictionary[c] for c in transcript]
    print(list(zip(transcript, tokens)))


    def get_trellis(emission, tokens, blank_id=0):
        num_frame = emission.size(0)
        num_tokens = len(tokens)

        trellis = torch.zeros((num_frame, num_tokens))
        trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
        trellis[0, 1:] = -float("inf")
        trellis[-num_tokens + 1 :, 0] = float("inf")

        for t in range(num_frame - 1):
            trellis[t + 1, 1:] = torch.maximum(
                # Score for staying at the same token
                trellis[t, 1:] + emission[t, blank_id],
                # Score for changing to the next token
                trellis[t, :-1] + emission[t, tokens[1:]],
            )
        return trellis


    trellis = get_trellis(emission, tokens)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [('|', 1), ('I', 7), ('|', 1), ('H', 8), ('A', 4), ('D', 11), ('|', 1), ('T', 3), ('H', 8), ('A', 4), ('T', 3), ('|', 1), ('C', 16), ('U', 13), ('R', 10), ('I', 7), ('O', 5), ('S', 9), ('I', 7), ('T', 3), ('Y', 19), ('|', 1), ('B', 21), ('E', 2), ('S', 9), ('I', 7), ('D', 11), ('E', 2), ('|', 1), ('M', 14), ('E', 2), ('|', 1), ('A', 4), ('T', 3), ('|', 1), ('T', 3), ('H', 8), ('I', 7), ('S', 9), ('|', 1), ('M', 14), ('O', 5), ('M', 14), ('E', 2), ('N', 6), ('T', 3), ('|', 1)]

.. GENERATED FROM PYTHON SOURCE LINES 203-205

Visualization
~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 205-218

.. code-block:: default

    def plot():
        fig, ax = plt.subplots()
        img = ax.imshow(trellis.T, origin="lower")
        ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
        ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
        fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
        fig.tight_layout()


    plot()

.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_tutorial_002.png
   :alt: forced alignment tutorial
   :srcset: /tutorials/images/sphx_glr_forced_alignment_tutorial_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 219-222

In the above visualization, we can see that there is a trace of high
probability crossing the matrix diagonally.
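The trellis construction plus the backtracking implemented in the next
section is what :py:func:`torchaudio.functional.forced_align` does
internally. As an illustrative cross-check (an addition to the original
tutorial, assuming a torchaudio release where this deprecated API is still
available), the same alignment can be computed in one call:

.. code-block:: default

    # Cross-check with the built-in CTC forced-alignment API.
    # forced_align expects batched inputs: log-probs of shape (1, T, C) and
    # targets of shape (1, L); it returns per-frame token indices and scores.
    aligned_tokens, alignment_scores = torchaudio.functional.forced_align(
        emission.unsqueeze(0),
        torch.tensor([tokens], dtype=torch.int32),
        blank=0,
    )
    print(aligned_tokens[0, :10], alignment_scores[0, :10])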
.. GENERATED FROM PYTHON SOURCE LINES 225-244

Find the most likely path (backtracking)
----------------------------------------

Once the trellis is generated, we will traverse it following the elements
with high probability.

We start from the last label index at the last time step, and traverse back
in time, picking stay (:math:`c_j \rightarrow c_j`) or transition
(:math:`c_j \rightarrow c_{j+1}`), whichever has the higher post-transition
probability: :math:`k_{(t, j+1)} p(t+1, \text{repeat})` for stay, or
:math:`k_{(t, j)} p(t+1, c_{j+1})` for transition. The traversal is complete
once the label index reaches the beginning of the transcript.

The trellis matrix is used for path finding, but for the final probability of
each segment, we take the frame-wise probability from the emission matrix.

.. GENERATED FROM PYTHON SOURCE LINES 244-294

.. code-block:: default

    @dataclass
    class Point:
        token_index: int
        time_index: int
        score: float


    def backtrack(trellis, emission, tokens, blank_id=0):
        t, j = trellis.size(0) - 1, trellis.size(1) - 1

        path = [Point(j, t, emission[t, blank_id].exp().item())]
        while j > 0:
            # Should not happen but just in case
            assert t > 0

            # 1. Figure out if the current position was stay or change
            # Frame-wise score of stay vs change
            p_stay = emission[t - 1, blank_id]
            p_change = emission[t - 1, tokens[j]]

            # Context-aware score for stay vs change
            stayed = trellis[t - 1, j] + p_stay
            changed = trellis[t - 1, j - 1] + p_change

            # Update position
            t -= 1
            if changed > stayed:
                j -= 1

            # Store the path with frame-wise probability.
            prob = (p_change if changed > stayed else p_stay).exp().item()
            path.append(Point(j, t, prob))

        # Now j == 0, which means it reached the SOS.
        # Fill up the rest for the sake of visualization
        while t > 0:
            prob = emission[t - 1, blank_id].exp().item()
            path.append(Point(j, t - 1, prob))
            t -= 1

        return path[::-1]


    path = backtrack(trellis, emission, tokens)
    for p in path:
        print(p)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Point(token_index=0, time_index=0, score=0.9999996423721313)
    Point(token_index=0, time_index=1, score=0.9999996423721313)
    Point(token_index=0, time_index=2, score=0.9999996423721313)
    Point(token_index=0, time_index=3, score=0.9999996423721313)
    Point(token_index=0, time_index=4, score=0.9999996423721313)
    Point(token_index=0, time_index=5, score=0.9999996423721313)
    Point(token_index=0, time_index=6, score=0.9999996423721313)
    Point(token_index=0, time_index=7, score=0.9999996423721313)
    Point(token_index=0, time_index=8, score=0.9999998807907104)
    Point(token_index=0, time_index=9, score=0.9999996423721313)
    Point(token_index=0, time_index=10, score=0.9999996423721313)
    Point(token_index=0, time_index=11, score=0.9999998807907104)
    Point(token_index=0, time_index=12, score=0.9999996423721313)
    Point(token_index=0, time_index=13, score=0.9999996423721313)
    Point(token_index=0, time_index=14, score=0.9999996423721313)
    Point(token_index=0, time_index=15, score=0.9999996423721313)
    Point(token_index=0, time_index=16, score=0.9999996423721313)
    Point(token_index=0, time_index=17, score=0.9999996423721313)
    Point(token_index=0, time_index=18, score=0.9999998807907104)
    Point(token_index=0, time_index=19, score=0.9999996423721313)
    Point(token_index=0, time_index=20, score=0.9999996423721313)
    Point(token_index=0, time_index=21, score=0.9999996423721313)
    Point(token_index=0, time_index=22, score=0.9999996423721313)
    Point(token_index=0, time_index=23, score=0.9999997615814209)
    Point(token_index=0, time_index=24, score=0.9999998807907104)
    Point(token_index=0, time_index=25, score=0.9999998807907104)
    Point(token_index=0, time_index=26, score=0.9999998807907104)
    Point(token_index=0, time_index=27, score=0.9999998807907104)
    Point(token_index=0, time_index=28, score=0.9999985694885254)
    Point(token_index=0, time_index=29, score=0.9999943971633911)
    Point(token_index=0, time_index=30, score=0.9999842643737793)
    Point(token_index=1, time_index=31, score=0.9846633672714233)
    Point(token_index=1, time_index=32, score=0.9999706745147705)
    Point(token_index=1, time_index=33, score=0.15361329913139343)
    Point(token_index=1, time_index=34, score=0.9999172687530518)
    Point(token_index=2, time_index=35, score=0.6085338592529297)
    Point(token_index=2, time_index=36, score=0.9997722506523132)
    Point(token_index=3, time_index=37, score=0.999714195728302)
    Point(token_index=3, time_index=38, score=0.9999358654022217)
    Point(token_index=4, time_index=39, score=0.9861765503883362)
    Point(token_index=4, time_index=40, score=0.9241003394126892)
    Point(token_index=5, time_index=41, score=0.9259175062179565)
    Point(token_index=5, time_index=42, score=0.015625856816768646)
    Point(token_index=5, time_index=43, score=0.9998376369476318)
    Point(token_index=6, time_index=44, score=0.9988446235656738)
    Point(token_index=7, time_index=45, score=0.10159187018871307)
    Point(token_index=7, time_index=46, score=0.9999426603317261)
    Point(token_index=8, time_index=47, score=0.9999943971633911)
    Point(token_index=8, time_index=48, score=0.9979596138000488)
    Point(token_index=9, time_index=49, score=0.036005258560180664)
    Point(token_index=9, time_index=50, score=0.06174848601222038)
    Point(token_index=9, time_index=51, score=4.3377182009862736e-05)
    Point(token_index=10, time_index=52, score=0.9999799728393555)
    Point(token_index=11, time_index=53, score=0.9967069029808044)
    Point(token_index=11, time_index=54, score=0.9999257326126099)
    Point(token_index=11, time_index=55, score=0.9999982118606567)
    Point(token_index=12, time_index=56, score=0.999064028263092)
    Point(token_index=12, time_index=57, score=0.9999996423721313)
    Point(token_index=12, time_index=58, score=0.9999996423721313)
    Point(token_index=12, time_index=59, score=0.8450244069099426)
    Point(token_index=12, time_index=60, score=0.9999996423721313)
    Point(token_index=13, time_index=61, score=0.9996004700660706)
    Point(token_index=13, time_index=62, score=0.999998927116394)
    Point(token_index=14, time_index=63, score=0.003534560324624181)
    Point(token_index=14, time_index=64, score=1.0)
    Point(token_index=14, time_index=65, score=1.0)
    Point(token_index=14, time_index=66, score=0.9999915361404419)
    Point(token_index=15, time_index=67, score=0.9971501231193542)
    Point(token_index=15, time_index=68, score=0.9999990463256836)
    Point(token_index=15, time_index=69, score=0.9999992847442627)
    Point(token_index=15, time_index=70, score=0.9999997615814209)
    Point(token_index=15, time_index=71, score=0.9999998807907104)
    Point(token_index=15, time_index=72, score=0.9999880790710449)
    Point(token_index=15, time_index=73, score=0.011419214308261871)
    Point(token_index=15, time_index=74, score=0.9999977350234985)
    Point(token_index=16, time_index=75, score=0.9996128678321838)
    Point(token_index=16, time_index=76, score=0.999998927116394)
    Point(token_index=16, time_index=77, score=0.9729052782058716)
    Point(token_index=16, time_index=78, score=0.999998927116394)
    Point(token_index=17, time_index=79, score=0.9949370622634888)
    Point(token_index=17, time_index=80, score=0.999998927116394)
    Point(token_index=17, time_index=81, score=0.9999125003814697)
    Point(token_index=17, time_index=82, score=0.9999774694442749)
    Point(token_index=18, time_index=83, score=0.6575068831443787)
    Point(token_index=18, time_index=84, score=0.9984293580055237)
    Point(token_index=18, time_index=85, score=0.9999876022338867)
    Point(token_index=19, time_index=86, score=0.9993751645088196)
    Point(token_index=19, time_index=87, score=0.9999988079071045)
    Point(token_index=19, time_index=88, score=0.10444644093513489)
    Point(token_index=19, time_index=89, score=0.9999969005584717)
    Point(token_index=20, time_index=90, score=0.3971569836139679)
    Point(token_index=20, time_index=91, score=0.9999932050704956)
    Point(token_index=21, time_index=92, score=1.696418848950998e-06)
    Point(token_index=21, time_index=93, score=0.9860828518867493)
    Point(token_index=21, time_index=94, score=0.9999960660934448)
    Point(token_index=22, time_index=95, score=0.9992727637290955)
    Point(token_index=22, time_index=96, score=0.9993411898612976)
    Point(token_index=22, time_index=97, score=0.9999983310699463)
    Point(token_index=23, time_index=98, score=0.9999971389770508)
    Point(token_index=23, time_index=99, score=0.9999998807907104)
    Point(token_index=23, time_index=100, score=0.9999995231628418)
    Point(token_index=23, time_index=101, score=0.9999732971191406)
    Point(token_index=24, time_index=102, score=0.9983185529708862)
    Point(token_index=24, time_index=103, score=0.9999991655349731)
    Point(token_index=24, time_index=104, score=0.9999996423721313)
    Point(token_index=24, time_index=105, score=0.9999998807907104)
    Point(token_index=24, time_index=106, score=1.0)
    Point(token_index=24, time_index=107, score=0.999862551689148)
    Point(token_index=24, time_index=108, score=0.9999980926513672)
    Point(token_index=25, time_index=109, score=0.9988523721694946)
    Point(token_index=25, time_index=110, score=0.9999798536300659)
    Point(token_index=26, time_index=111, score=0.8575892448425293)
    Point(token_index=26, time_index=112, score=0.9999847412109375)
    Point(token_index=27, time_index=113, score=0.9870195388793945)
    Point(token_index=27, time_index=114, score=1.9013934434042312e-05)
    Point(token_index=27, time_index=115, score=0.9999796152114868)
    Point(token_index=28, time_index=116, score=0.9998248219490051)
    Point(token_index=28, time_index=117, score=0.9999990463256836)
    Point(token_index=29, time_index=118, score=0.9999732971191406)
    Point(token_index=29, time_index=119, score=0.0008990301284939051)
    Point(token_index=29, time_index=120, score=0.9993460774421692)
    Point(token_index=30, time_index=121, score=0.9975366592407227)
    Point(token_index=30, time_index=122, score=0.000303806911688298)
    Point(token_index=30, time_index=123, score=0.9999344348907471)
    Point(token_index=31, time_index=124, score=6.088280770200072e-06)
    Point(token_index=31, time_index=125, score=0.9833387136459351)
    Point(token_index=32, time_index=126, score=0.9974588751792908)
    Point(token_index=33, time_index=127, score=0.0008261289913207293)
    Point(token_index=33, time_index=128, score=0.9965143799781799)
    Point(token_index=34, time_index=129, score=0.017432427033782005)
    Point(token_index=34, time_index=130, score=0.9989168643951416)
    Point(token_index=35, time_index=131, score=0.9999697208404541)
    Point(token_index=36, time_index=132, score=0.9999842643737793)
    Point(token_index=36, time_index=133, score=0.9997639060020447)
    Point(token_index=37, time_index=134, score=0.5108126401901245)
    Point(token_index=37, time_index=135, score=0.9998301267623901)
    Point(token_index=38, time_index=136, score=0.08521389961242676)
    Point(token_index=38, time_index=137, score=0.004068504087626934)
    Point(token_index=38, time_index=138, score=0.9999815225601196)
    Point(token_index=39, time_index=139, score=0.011991607956588268)
    Point(token_index=39, time_index=140, score=0.9999980926513672)
    Point(token_index=39, time_index=141, score=0.0005900734104216099)
    Point(token_index=39, time_index=142, score=0.9999085664749146)
    Point(token_index=40, time_index=143, score=0.9999960660934448)
    Point(token_index=40, time_index=144, score=0.9999980926513672)
    Point(token_index=40, time_index=145, score=0.9999916553497314)
    Point(token_index=41, time_index=146, score=0.9971139430999756)
    Point(token_index=41, time_index=147, score=0.9981796741485596)
    Point(token_index=41, time_index=148, score=0.9999310970306396)
    Point(token_index=42, time_index=149, score=0.9879123568534851)
    Point(token_index=42, time_index=150, score=0.9997639060020447)
    Point(token_index=42, time_index=151, score=0.9999536275863647)
    Point(token_index=43, time_index=152, score=0.9999715089797974)
    Point(token_index=44, time_index=153, score=0.32006949186325073)
    Point(token_index=44, time_index=154, score=0.9997830986976624)
    Point(token_index=45, time_index=155, score=0.01606338657438755)
    Point(token_index=45, time_index=156, score=0.999901294708252)
    Point(token_index=46, time_index=157, score=0.46633121371269226)
    Point(token_index=46, time_index=158, score=0.9999994039535522)
    Point(token_index=46, time_index=159, score=0.9999996423721313)
    Point(token_index=46, time_index=160, score=0.9999995231628418)
    Point(token_index=46, time_index=161, score=0.9999996423721313)
    Point(token_index=46, time_index=162, score=0.9999996423721313)
    Point(token_index=46, time_index=163, score=0.9999996423721313)
    Point(token_index=46, time_index=164, score=0.9999995231628418)
    Point(token_index=46, time_index=165, score=0.9999995231628418)
    Point(token_index=46, time_index=166, score=0.9999996423721313)
    Point(token_index=46, time_index=167, score=0.9999996423721313)
    Point(token_index=46, time_index=168, score=0.9999995231628418)

.. GENERATED FROM PYTHON SOURCE LINES 295-297

Visualization
~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 297-311

.. code-block:: default

    def plot_trellis_with_path(trellis, path):
        # To plot trellis with path, we take advantage of 'nan' value
        trellis_with_path = trellis.clone()
        for p in path:
            trellis_with_path[p.time_index, p.token_index] = float("nan")
        plt.imshow(trellis_with_path.T, origin="lower")
        plt.title("The path found by backtracking")
        plt.tight_layout()


    plot_trellis_with_path(trellis, path)

.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_tutorial_003.png
   :alt: The path found by backtracking
   :srcset: /tutorials/images/sphx_glr_forced_alignment_tutorial_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 312-313

Looking good.
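As a small sanity check (an illustrative addition, not part of the original
tutorial), we can verify that the recovered path is monotonic: the token
index never decreases, and it advances by at most one per frame.

.. code-block:: default

    # The path moves forward in time at every step, and up the label axis by
    # either 0 (stay) or 1 (transition).
    assert all(0 <= q.token_index - p.token_index <= 1 for p, q in zip(path, path[1:]))
    print(f"{len(path)} path points over {trellis.size(0)} frames")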
.. GENERATED FROM PYTHON SOURCE LINES 315-323

Segment the path
----------------

Now this path contains repetitions of the same labels, so let's merge them to
make it close to the original transcript.

When merging the multiple path points, we simply take the average probability
for the merged segments.

.. GENERATED FROM PYTHON SOURCE LINES 323-365

.. code-block:: default

    # Merge the labels
    @dataclass
    class Segment:
        label: str
        start: int
        end: int
        score: float

        def __repr__(self):
            return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

        @property
        def length(self):
            return self.end - self.start


    def merge_repeats(path):
        i1, i2 = 0, 0
        segments = []
        while i1 < len(path):
            while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                i2 += 1
            score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
            segments.append(
                Segment(
                    transcript[path[i1].token_index],
                    path[i1].time_index,
                    path[i2 - 1].time_index + 1,
                    score,
                )
            )
            i1 = i2
        return segments


    segments = merge_repeats(path)
    for seg in segments:
        print(seg)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    | (1.00): [ 0, 31)
    I (0.78): [ 31, 35)
    | (0.80): [ 35, 37)
    H (1.00): [ 37, 39)
    A (0.96): [ 39, 41)
    D (0.65): [ 41, 44)
    | (1.00): [ 44, 45)
    T (0.55): [ 45, 47)
    H (1.00): [ 47, 49)
    A (0.03): [ 49, 52)
    T (1.00): [ 52, 53)
    | (1.00): [ 53, 56)
    C (0.97): [ 56, 61)
    U (1.00): [ 61, 63)
    R (0.75): [ 63, 67)
    I (0.88): [ 67, 75)
    O (0.99): [ 75, 79)
    S (1.00): [ 79, 83)
    I (0.89): [ 83, 86)
    T (0.78): [ 86, 90)
    Y (0.70): [ 90, 92)
    | (0.66): [ 92, 95)
    B (1.00): [ 95, 98)
    E (1.00): [ 98, 102)
    S (1.00): [ 102, 109)
    I (1.00): [ 109, 111)
    D (0.93): [ 111, 113)
    E (0.66): [ 113, 116)
    | (1.00): [ 116, 118)
    M (0.67): [ 118, 121)
    E (0.67): [ 121, 124)
    | (0.49): [ 124, 126)
    A (1.00): [ 126, 127)
    T (0.50): [ 127, 129)
    | (0.51): [ 129, 131)
    T (1.00): [ 131, 132)
    H (1.00): [ 132, 134)
    I (0.76): [ 134, 136)
    S (0.36): [ 136, 139)
    | (0.50): [ 139, 143)
    M (1.00): [ 143, 146)
    O (1.00): [ 146, 149)
    M (1.00): [ 149, 152)
    E (1.00): [ 152, 153)
    N (0.66): [ 153, 155)
    T (0.51): [ 155, 157)
    | (0.96): [ 157, 169)

.. GENERATED FROM PYTHON SOURCE LINES 366-368

Visualization
~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 368-413

.. code-block:: default

    def plot_trellis_with_segments(trellis, segments, transcript):
        # To plot trellis with path, we take advantage of 'nan' value
        trellis_with_path = trellis.clone()
        for i, seg in enumerate(segments):
            if seg.label != "|":
                trellis_with_path[seg.start : seg.end, i] = float("nan")

        fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
        ax1.set_title("Path, label and probability for each label")
        ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")

        for i, seg in enumerate(segments):
            if seg.label != "|":
                ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
                ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

        ax2.set_title("Label probability with and without repetition")
        xs, hs, ws = [], [], []
        for seg in segments:
            if seg.label != "|":
                xs.append((seg.end + seg.start) / 2 + 0.4)
                hs.append(seg.score)
                ws.append(seg.end - seg.start)
                ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
        ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")

        xs, hs = [], []
        for p in path:
            label = transcript[p.token_index]
            if label != "|":
                xs.append(p.time_index + 1)
                hs.append(p.score)

        ax2.bar(xs, hs, width=0.5, alpha=0.5)
        ax2.axhline(0, color="black")
        ax2.grid(True, axis="y")
        ax2.set_ylim(-0.1, 1.1)
        fig.tight_layout()


    plot_trellis_with_segments(trellis, segments, transcript)

.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_tutorial_004.png
   :alt: Path, label and probability for each label, Label probability with and without repetition
   :srcset: /tutorials/images/sphx_glr_forced_alignment_tutorial_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 414-415

Looks good.
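The per-segment scores are useful beyond visualization. For example (an
illustrative addition, not part of the original tutorial), we can flag
segments where the aligner was unsure, which often correspond to mumbled or
coarticulated phonemes.

.. code-block:: default

    # Flag character segments with low average frame-wise probability.
    for seg in segments:
        if seg.label != "|" and seg.score < 0.5:
            print("low confidence:", seg)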
.. GENERATED FROM PYTHON SOURCE LINES 417-426

Merge the segments into words
-----------------------------

Now let's merge the words. The Wav2Vec2 model uses ``'|'`` as the word
boundary, so we merge the segments before each occurrence of ``'|'``.

Finally, we slice the original audio into word-level segments and listen to
them to verify that the segmentation is correct.

.. GENERATED FROM PYTHON SOURCE LINES 426-450

.. code-block:: default

    # Merge words
    def merge_words(segments, separator="|"):
        words = []
        i1, i2 = 0, 0
        while i1 < len(segments):
            if i2 >= len(segments) or segments[i2].label == separator:
                if i1 != i2:
                    segs = segments[i1:i2]
                    word = "".join([seg.label for seg in segs])
                    score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                    words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
                i1 = i2 + 1
                i2 = i1
            else:
                i2 += 1
        return words


    word_segments = merge_words(segments)
    for word in word_segments:
        print(word)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    I (0.78): [ 31, 35)
    HAD (0.84): [ 37, 44)
    THAT (0.52): [ 45, 53)
    CURIOSITY (0.89): [ 56, 92)
    BESIDE (0.94): [ 95, 116)
    ME (0.67): [ 118, 124)
    AT (0.66): [ 126, 129)
    THIS (0.70): [ 131, 139)
    MOMENT (0.88): [ 143, 157)

.. GENERATED FROM PYTHON SOURCE LINES 451-453

Visualization
~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 453-499

.. code-block:: default

    def plot_alignments(trellis, segments, word_segments, waveform, sample_rate=bundle.sample_rate):
        trellis_with_path = trellis.clone()
        for i, seg in enumerate(segments):
            if seg.label != "|":
                trellis_with_path[seg.start : seg.end, i] = float("nan")

        fig, [ax1, ax2] = plt.subplots(2, 1)

        ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
        ax1.set_facecolor("lightgray")
        ax1.set_xticks([])
        ax1.set_yticks([])

        for word in word_segments:
            ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

        for i, seg in enumerate(segments):
            if seg.label != "|":
                ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
                ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

        # The original waveform
        ratio = waveform.size(0) / sample_rate / trellis.size(0)
        ax2.specgram(waveform, Fs=sample_rate)
        for word in word_segments:
            x0 = ratio * word.start
            x1 = ratio * word.end
            ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
            ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

        for seg in segments:
            if seg.label != "|":
                ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
        ax2.set_xlabel("time [second]")
        ax2.set_yticks([])
        fig.tight_layout()


    plot_alignments(
        trellis,
        segments,
        word_segments,
        waveform[0],
    )

.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_tutorial_005.png
   :alt: forced alignment tutorial
   :srcset: /tutorials/images/sphx_glr_forced_alignment_tutorial_005.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 500-503

Audio Samples
-------------

.. GENERATED FROM PYTHON SOURCE LINES 503-515

.. code-block:: default

    def display_segment(i):
        ratio = waveform.size(1) / trellis.size(0)
        word = word_segments[i]
        x0 = int(ratio * word.start)
        x1 = int(ratio * word.end)
        print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
        segment = waveform[:, x0:x1]
        return IPython.display.Audio(segment.numpy(), rate=bundle.sample_rate)
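The ``ratio`` above converts trellis frame indices back to waveform sample
indices. As a worked example (an illustrative addition; the exact hop depends
on the model), Wav2Vec2's feature extractor hops roughly 320 samples, i.e.
about 20 ms per frame at 16 kHz, so ``ratio`` should come out close to 320.

.. code-block:: default

    # One trellis row corresponds to one emission frame of the model.
    ratio = waveform.size(1) / trellis.size(0)
    print(f"{ratio:.1f} samples per frame, {ratio / bundle.sample_rate * 1000:.1f} ms per frame")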
.. GENERATED FROM PYTHON SOURCE LINES 517-523

.. code-block:: default

    # Generate the audio for each segment
    print(transcript)
    IPython.display.Audio(SPEECH_FILE)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    |I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|


.. GENERATED FROM PYTHON SOURCE LINES 525-528

.. code-block:: default

    display_segment(0)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    I (0.78): 0.624 - 0.704 sec


.. GENERATED FROM PYTHON SOURCE LINES 530-533

.. code-block:: default

    display_segment(1)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    HAD (0.84): 0.744 - 0.885 sec


.. GENERATED FROM PYTHON SOURCE LINES 535-538

.. code-block:: default

    display_segment(2)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    THAT (0.52): 0.905 - 1.066 sec


.. GENERATED FROM PYTHON SOURCE LINES 540-543

.. code-block:: default

    display_segment(3)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    CURIOSITY (0.89): 1.127 - 1.851 sec


.. GENERATED FROM PYTHON SOURCE LINES 545-548

.. code-block:: default

    display_segment(4)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    BESIDE (0.94): 1.911 - 2.334 sec


.. GENERATED FROM PYTHON SOURCE LINES 550-553

.. code-block:: default

    display_segment(5)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ME (0.67): 2.374 - 2.495 sec


.. GENERATED FROM PYTHON SOURCE LINES 555-558

.. code-block:: default

    display_segment(6)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    AT (0.66): 2.535 - 2.595 sec


.. GENERATED FROM PYTHON SOURCE LINES 560-563

.. code-block:: default

    display_segment(7)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    THIS (0.70): 2.635 - 2.796 sec


.. GENERATED FROM PYTHON SOURCE LINES 565-568

.. code-block:: default

    display_segment(8)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    MOMENT (0.88): 2.877 - 3.159 sec
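Before concluding, a practical extension (illustrative, not part of the
original tutorial): the word-level alignment can be exported for downstream
use, e.g. as a tab-separated file of word, start time, end time, and score.

.. code-block:: default

    # Write one row per aligned word: label, start [sec], end [sec], score.
    sec_per_frame = waveform.size(1) / trellis.size(0) / bundle.sample_rate
    with open("alignment.tsv", "w") as f:
        for word in word_segments:
            f.write(f"{word.label}\t{word.start * sec_per_frame:.3f}\t{word.end * sec_per_frame:.3f}\t{word.score:.3f}\n")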


.. GENERATED FROM PYTHON SOURCE LINES 569-575

Conclusion
----------

In this tutorial, we looked at how to use torchaudio's Wav2Vec2 model to
perform CTC segmentation for forced alignment.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 0 minutes 1.665 seconds)

.. _sphx_glr_download_tutorials_forced_alignment_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: forced_alignment_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: forced_alignment_tutorial.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_