.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/asr_inference_with_ctc_decoder_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_asr_inference_with_ctc_decoder_tutorial.py:


ASR Inference with CTC Decoder
==============================

**Author**: `Caroline Chen `__

This tutorial shows how to perform speech recognition inference using a
CTC beam search decoder with lexicon constraint and KenLM language model
support. We demonstrate this on a pretrained wav2vec 2.0 model trained
using CTC loss.

.. GENERATED FROM PYTHON SOURCE LINES 15-41

Overview
--------

Beam search decoding works by iteratively expanding text hypotheses (beams)
with next possible characters, and keeping only the hypotheses with the
highest scores at each time step. A language model can be incorporated into
the scoring computation, and adding a lexicon constraint restricts the next
possible tokens for the hypotheses so that only words from the lexicon can
be generated.

The underlying implementation is ported from `Flashlight `__'s beam search
decoder. A mathematical formulation of the decoder optimization can be found
in the `Wav2Letter paper `__, and a more detailed algorithm can be found in
this `blog `__.

Running ASR inference using a CTC Beam Search decoder with a KenLM language
model and lexicon constraint requires the following components

-  Acoustic Model: model predicting phonetics from audio waveforms
-  Tokens: the possible predicted tokens from the acoustic model
-  Lexicon: mapping between possible words and their corresponding
   token sequences
-  KenLM: n-gram language model trained with the `KenLM library `__

.. GENERATED FROM PYTHON SOURCE LINES 44-50

Preparation
-----------

First we import the necessary utilities and fetch the data that
we are working with

.. GENERATED FROM PYTHON SOURCE LINES 50-57

.. code-block:: default

    import torch
    import torchaudio

    print(torch.__version__)
    print(torchaudio.__version__)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    1.12.1
    0.12.1

.. GENERATED FROM PYTHON SOURCE LINES 59-68

.. code-block:: default

    import time
    from typing import List

    import IPython
    import matplotlib.pyplot as plt
    from torchaudio.models.decoder import ctc_decoder
    from torchaudio.utils import download_asset

.. GENERATED FROM PYTHON SOURCE LINES 69-79

Acoustic Model and Data
~~~~~~~~~~~~~~~~~~~~~~~

We use the pretrained `Wav2Vec 2.0 `__ Base model that is fine-tuned on 10
min of transcribed audio from the Libri-Light dataset, which can be loaded
in using :py:mod:`torchaudio.pipelines`. For more detail on running Wav2Vec
2.0 speech recognition pipelines in torchaudio, please refer to
`this tutorial `__.

.. GENERATED FROM PYTHON SOURCE LINES 79-84

.. code-block:: default

    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
    acoustic_model = bundle.get_model()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ll10m.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ll10m.pth
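
Next we load the sample audio that the rest of the tutorial decodes. The
original loading cell was lost from this copy; below is a minimal sketch of
that step. The exact asset path is an assumption for illustration, and any
16 kHz speech clip would work in its place.

.. code-block:: default

    # Assumed asset path for the sample clip; substitute your own file if needed.
    speech_file = download_asset("tutorial-assets/ctc-decoding/1688-142285-0007.wav")

    # Play back the clip in a notebook environment.
    IPython.display.Audio(speech_file)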

.. GENERATED FROM PYTHON SOURCE LINES 94-100

The transcript corresponding to this audio file is

.. code-block::

   i really was very much afraid of showing him how much shocked i was at some
   parts of what he said

.. GENERATED FROM PYTHON SOURCE LINES 100-107

.. code-block:: default

    waveform, sample_rate = torchaudio.load(speech_file)

    if sample_rate != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

.. GENERATED FROM PYTHON SOURCE LINES 108-116

Files and Data for Decoder
~~~~~~~~~~~~~~~~~~~~~~~~~~

Next, we load in our token, lexicon, and KenLM data, which are used by the
decoder to predict words from the acoustic model output. Pretrained files
for the LibriSpeech dataset can be downloaded through torchaudio, or the
user can provide their own files.

.. GENERATED FROM PYTHON SOURCE LINES 119-136

Tokens
^^^^^^

The tokens are the possible symbols that the acoustic model can predict,
including the blank and silent symbols. They can either be passed in as a
file, where each line consists of the tokens corresponding to the same
index, or as a list of tokens, each mapping to a unique index.

.. code-block::

   # tokens.txt
   _
   |
   e
   t
   ...

.. GENERATED FROM PYTHON SOURCE LINES 136-141

.. code-block:: default

    tokens = [label.lower() for label in bundle.get_labels()]
    print(tokens)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']

.. GENERATED FROM PYTHON SOURCE LINES 142-159

Lexicon
^^^^^^^

The lexicon is a mapping from words to their corresponding token sequences,
and is used to restrict the search space of the decoder to only words from
the lexicon. The expected format of the lexicon file is a line per word,
with a word followed by its space-separated tokens.

.. code-block::

   # lexicon.txt
   a a |
   able a b l e |
   about a b o u t |
   ...
   ...

.. GENERATED FROM PYTHON SOURCE LINES 162-173

KenLM
^^^^^

This is an n-gram language model trained with the `KenLM library `__.
Either the ``.arpa`` or the binarized ``.bin`` LM can be used, but the
binary format is recommended for faster loading.

The language model used in this tutorial is a 4-gram KenLM trained using
`LibriSpeech `__.

.. GENERATED FROM PYTHON SOURCE LINES 176-185

Downloading Pretrained Files
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Pretrained files for the LibriSpeech dataset can be downloaded using
:py:func:`download_pretrained_files `.

Note: this cell may take a couple of minutes to run, as the language model
can be large

.. GENERATED FROM PYTHON SOURCE LINES 185-193

.. code-block:: default

    from torchaudio.models.decoder import download_pretrained_files

    files = download_pretrained_files("librispeech-4-gram")

    print(files)

Construct Decoders
------------------

Beam Search Decoder
~~~~~~~~~~~~~~~~~~~

The decoder can be constructed using the factory function
:py:func:`ctc_decoder `. In addition to the previously mentioned components,
it also takes in various beam search decoding parameters and token/word
parameters.

This decoder can also be run without a language model by passing in ``None``
into the ``lm`` parameter.

.. GENERATED FROM PYTHON SOURCE LINES 212-227

.. code-block:: default

    LM_WEIGHT = 3.23
    WORD_SCORE = -0.26

    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        nbest=3,
        beam_size=1500,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
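
As noted above, the decoder can also be constructed without a language
model. A minimal sketch, reusing the pretrained lexicon and token files
from before:

.. code-block:: default

    # Sketch: a lexicon-constrained decoder with no language model.
    # Scoring then relies on the acoustic model alone (plus any word
    # or silence scores that are configured).
    lm_free_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=None,
        beam_size=1500,
    )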
.. GENERATED FROM PYTHON SOURCE LINES 228-233

Greedy Decoder
~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 233-259

.. code-block:: default

    class GreedyCTCDecoder(torch.nn.Module):
        def __init__(self, labels, blank=0):
            super().__init__()
            self.labels = labels
            self.blank = blank

        def forward(self, emission: torch.Tensor) -> List[str]:
            """Given a sequence emission over labels, get the best path

            Args:
              emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

            Returns:
              List[str]: The resulting transcript
            """
            indices = torch.argmax(emission, dim=-1)  # [num_seq,]
            indices = torch.unique_consecutive(indices, dim=-1)
            indices = [i for i in indices if i != self.blank]
            joined = "".join([self.labels[i] for i in indices])
            return joined.replace("|", " ").strip().split()


    greedy_decoder = GreedyCTCDecoder(tokens)

.. GENERATED FROM PYTHON SOURCE LINES 260-274

Run Inference
-------------

Now that we have the data, acoustic model, and decoder, we can perform
inference. The output of the beam search decoder is of type
:py:class:`torchaudio.models.decoder.CTCHypothesis`, consisting of the
predicted token IDs, corresponding words, hypothesis score, and timesteps
corresponding to the token IDs. Recall the transcript corresponding to the
waveform is

.. code-block::

   i really was very much afraid of showing him how much shocked i was at some
   parts of what he said

.. GENERATED FROM PYTHON SOURCE LINES 274-281

.. code-block:: default

    actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
    actual_transcript = actual_transcript.split()

    emission, _ = acoustic_model(waveform)

.. GENERATED FROM PYTHON SOURCE LINES 282-284

The greedy decoder gives the following result.

.. GENERATED FROM PYTHON SOURCE LINES 284-293

.. code-block:: default

    greedy_result = greedy_decoder(emission[0])
    greedy_transcript = " ".join(greedy_result)
    greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)

    print(f"Transcript: {greedy_transcript}")
    print(f"WER: {greedy_wer}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Transcript: i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid
    WER: 0.38095238095238093

.. GENERATED FROM PYTHON SOURCE LINES 294-296

Using the beam search decoder:

.. GENERATED FROM PYTHON SOURCE LINES 296-307

.. code-block:: default

    beam_search_result = beam_search_decoder(emission)
    beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
    beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
        actual_transcript
    )

    print(f"Transcript: {beam_search_transcript}")
    print(f"WER: {beam_search_wer}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
    WER: 0.047619047619047616

.. GENERATED FROM PYTHON SOURCE LINES 308-313

We see that the transcript with the lexicon-constrained beam search decoder
produces a more accurate result consisting of real words, while the greedy
decoder can predict incorrectly spelled words like “affrayd” and “shoktd”.
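
Each hypothesis exposes the fields mentioned earlier as attributes, which
can be inspected directly; a brief sketch:

.. code-block:: default

    best_hypothesis = beam_search_result[0][0]  # top hypothesis of the first sample

    print(best_hypothesis.tokens)     # predicted token IDs (Tensor)
    print(best_hypothesis.words)      # list of predicted words
    print(best_hypothesis.score)      # overall hypothesis score
    print(best_hypothesis.timesteps)  # frame index of each token ID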
.. GENERATED FROM PYTHON SOURCE LINES 316-321

Timestep Alignments
-------------------

Recall that one of the components of the resulting Hypotheses is timesteps
corresponding to the token IDs.

.. GENERATED FROM PYTHON SOURCE LINES 321-329

.. code-block:: default

    timesteps = beam_search_result[0][0].timesteps
    predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)

    print(predicted_tokens, len(predicted_tokens))
    print(timesteps, timesteps.shape[0])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ['|', 'i', '|', 'r', 'e', 'a', 'l', 'l', 'y', '|', 'w', 'a', 's', '|', 'v', 'e', 'r', 'y', '|', 'm', 'u', 'c', 'h', '|', 'a', 'f', 'r', 'a', 'i', 'd', '|', 'o', 'f', '|', 's', 'h', 'o', 'w', 'i', 'n', 'g', '|', 'h', 'i', 'm', '|', 'h', 'o', 'w', '|', 'm', 'u', 'c', 'h', '|', 's', 'h', 'o', 'c', 'k', 'e', 'd', '|', 'i', '|', 'w', 'a', 's', '|', 'a', 't', '|', 's', 'o', 'm', 'e', '|', 'p', 'a', 'r', 't', '|', 'o', 'f', '|', 'w', 'h', 'a', 't', '|', 'h', 'e', '|', 's', 'a', 'i', 'd', '|', '|'] 99
    tensor([  0,  31,  33,  36,  39,  41,  42,  44,  46,  48,  49,  52,  54,  58,  64,  66,  69,  73,  74,  76,  80,  82,  84,  86,  88,  94,  97, 107, 111, 112, 116, 134, 136, 138, 140, 142, 146, 148, 151, 153, 155, 157, 159, 161, 162, 166, 170, 176, 177, 178, 179, 182, 184, 186, 187, 191, 193, 198, 201, 202, 203, 205, 207, 212, 213, 216, 222, 224, 230, 250, 251, 254, 256, 261, 262, 264, 267, 270, 276, 277, 281, 284, 288, 289, 292, 295, 297, 299, 300, 303, 305, 307, 310, 311, 324, 325, 329, 331, 353], dtype=torch.int32) 99

.. GENERATED FROM PYTHON SOURCE LINES 330-332

Below, we visualize the token timestep alignments relative to the original
waveform.

.. GENERATED FROM PYTHON SOURCE LINES 332-360

.. code-block:: default

    def plot_alignments(waveform, emission, tokens, timesteps):
        fig, ax = plt.subplots(figsize=(32, 10))
        ax.plot(waveform)

        ratio = waveform.shape[0] / emission.shape[1]
        word_start = 0

        for i in range(len(tokens)):
            if i != 0 and tokens[i - 1] == "|":
                word_start = timesteps[i]
            if tokens[i] != "|":
                plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
            elif i != 0:
                word_end = timesteps[i]
                ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")

        xticks = ax.get_xticks()
        plt.xticks(xticks, xticks / bundle.sample_rate)
        ax.set_xlabel("time (sec)")
        ax.set_xlim(0, waveform.shape[0])


    plot_alignments(waveform[0], emission, predicted_tokens, timesteps)

.. image-sg:: /tutorials/images/sphx_glr_asr_inference_with_ctc_decoder_tutorial_001.png
   :alt: asr inference with ctc decoder tutorial
   :srcset: /tutorials/images/sphx_glr_asr_inference_with_ctc_decoder_tutorial_001.png
   :class: sphx-glr-single-img
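
The same samples-per-frame ratio used in the plot can also convert the frame
indices into seconds. A minimal sketch, assuming ``waveform``, ``emission``,
``predicted_tokens``, and ``timesteps`` from above:

.. code-block:: default

    # Sketch: convert emission-frame indices to seconds using the
    # waveform-samples-per-emission-frame ratio from the plot above.
    ratio = waveform[0].shape[0] / emission.shape[1]

    for token, timestep in list(zip(predicted_tokens, timesteps))[:10]:
        seconds = timestep.item() * ratio / bundle.sample_rate
        print(f"{token}\t{seconds:.3f} sec")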
.. GENERATED FROM PYTHON SOURCE LINES 361-369

Beam Search Decoder Parameters
------------------------------

In this section, we go a little bit more in depth about some different
parameters and tradeoffs. For the full list of customizable parameters,
please refer to the
:py:func:`documentation `.

.. GENERATED FROM PYTHON SOURCE LINES 372-375

Helper Function
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 375-387

.. code-block:: default

    def print_decoded(decoder, emission, param, param_value):
        start_time = time.monotonic()
        result = decoder(emission)
        decode_time = time.monotonic() - start_time

        transcript = " ".join(result[0][0].words).lower().strip()
        score = result[0][0].score
        print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")

.. GENERATED FROM PYTHON SOURCE LINES 388-396

nbest
~~~~~

This parameter indicates the number of best hypotheses to return, a
capability the greedy decoder lacks. For instance, by setting ``nbest=3``
when constructing the beam search decoder earlier, we can now access the
hypotheses with the top 3 scores.

.. GENERATED FROM PYTHON SOURCE LINES 396-403

.. code-block:: default

    for i in range(3):
        transcript = " ".join(beam_search_result[0][i].words).strip()
        score = beam_search_result[0][i].score
        print(f"{transcript} (score: {score})")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.8242175269093)
    i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: 3697.8584784734217)
    i reply was very much afraid of showing him how much shocked i was at some part of what he said (score: 3695.0158622860877)

.. GENERATED FROM PYTHON SOURCE LINES 404-418

beam size
~~~~~~~~~

The ``beam_size`` parameter determines the maximum number of best hypotheses
to hold after each decoding step. Using larger beam sizes allows for
exploring a larger range of possible hypotheses, which can produce
hypotheses with higher scores, but it is computationally more expensive and
does not provide additional gains beyond a certain point.

In the example below, we see improvement in decoding quality as we increase
beam size from 1 to 5 to 50, but notice how using a beam size of 500
provides the same output as beam size 50 while increasing the computation
time.

.. GENERATED FROM PYTHON SOURCE LINES 418-434

.. code-block:: default

    beam_sizes = [1, 5, 50, 500]

    for beam_size in beam_sizes:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_size=beam_size,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam size", beam_size)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    beam size 1  : i you ery much afra of shongut shot i was at some arte what he sad (score: 3144.93; 0.0558 secs)
    beam size 5  : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3688.02; 0.0763 secs)
    beam size 50 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3131 secs)
    beam size 500: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.7243 secs)
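
As a point of reference, a minimal sketch of the degenerate case: with
``beam_size=1`` and no language model, the decoder keeps a single hypothesis
per step, behaving roughly like a lexicon-constrained version of the greedy
decoder above.

.. code-block:: default

    # Sketch: beam_size=1 with no LM keeps one hypothesis per step,
    # i.e. roughly a lexicon-constrained analogue of greedy decoding.
    single_beam_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=None,
        beam_size=1,
    )
    print(" ".join(single_beam_decoder(emission)[0][0].words).strip())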
.. GENERATED FROM PYTHON SOURCE LINES 435-443

beam size token
~~~~~~~~~~~~~~~

The ``beam_size_token`` parameter corresponds to the number of tokens to
consider for expanding each hypothesis at the decoding step. Exploring a
larger number of next possible tokens increases the range of potential
hypotheses at the cost of computation.

.. GENERATED FROM PYTHON SOURCE LINES 443-460

.. code-block:: default

    num_tokens = len(tokens)
    beam_size_tokens = [1, 5, 10, num_tokens]

    for beam_size_token in beam_size_tokens:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_size_token=beam_size_token,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    beam size token 1  : i rely was very much affray of showing him hoch shot i was at some part of what he sed (score: 3584.80; 0.3003 secs)
    beam size token 5  : i rely was very much afraid of showing him how much shocked i was at some part of what he said (score: 3694.83; 0.2502 secs)
    beam size token 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3696.25; 0.2898 secs)
    beam size token 29 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3392 secs)

.. GENERATED FROM PYTHON SOURCE LINES 461-471

beam threshold
~~~~~~~~~~~~~~

The ``beam_threshold`` parameter is used to prune the stored hypotheses set
at each decoding step, removing hypotheses whose scores are more than
``beam_threshold`` away from the highest scoring hypothesis. There is a
balance between choosing smaller thresholds to prune more hypotheses and
reduce the search space, and choosing a large enough threshold such that
plausible hypotheses are not pruned.

.. GENERATED FROM PYTHON SOURCE LINES 471-487

.. code-block:: default

    beam_thresholds = [1, 5, 10, 25]

    for beam_threshold in beam_thresholds:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_threshold=beam_threshold,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    beam threshold 1  : i ila ery much afraid of shongut shot i was at some parts of what he said (score: 3316.20; 0.1470 secs)
    beam threshold 5  : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3682.23; 0.1932 secs)
    beam threshold 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.4083 secs)
    beam threshold 25 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.4754 secs)

.. GENERATED FROM PYTHON SOURCE LINES 488-497

language model weight
~~~~~~~~~~~~~~~~~~~~~

The ``lm_weight`` parameter is the weight to assign to the language model
score, which is accumulated with the acoustic model score to determine the
overall scores. Larger weights encourage the model to predict next words
based on the language model, while smaller weights give more weight to the
acoustic model score instead.

.. GENERATED FROM PYTHON SOURCE LINES 497-512

.. code-block:: default

    lm_weights = [0, LM_WEIGHT, 15]

    for lm_weight in lm_weights:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            lm_weight=lm_weight,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    lm weight 0   : i rely was very much affraid of showing him ho much shoke i was at some parte of what he seid (score: 3834.05; 0.4307 secs)
    lm weight 3.23: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.4588 secs)
    lm weight 15  : was there in his was at some of what he said (score: 2918.99; 0.4375 secs)

.. GENERATED FROM PYTHON SOURCE LINES 513-523

additional parameters
~~~~~~~~~~~~~~~~~~~~~

Additional parameters that can be optimized include the following (a sketch
showing where they plug in follows the list)

-  ``word_score``: score to add when word finishes
-  ``unk_score``: unknown word appearance score to add
-  ``sil_score``: silence appearance score to add
-  ``log_add``: whether to use log add for lexicon Trie smearing
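
A minimal sketch of where these parameters plug into :py:func:`ctc_decoder`;
``unk_score``, ``sil_score``, and ``log_add`` are shown at their library
defaults rather than tuned values.

.. code-block:: default

    customized_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,    # added when a word finishes
        unk_score=float("-inf"),  # effectively disallows unknown words
        sil_score=0,              # no extra score for silence tokens
        log_add=False,            # use max instead of log-add for trie smearing
    )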
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 3 minutes 28.704 seconds)


.. _sphx_glr_download_tutorials_asr_inference_with_ctc_decoder_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: asr_inference_with_ctc_decoder_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: asr_inference_with_ctc_decoder_tutorial.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_