.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/asr_inference_with_cuda_ctc_decoder_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_tutorials_asr_inference_with_cuda_ctc_decoder_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_asr_inference_with_cuda_ctc_decoder_tutorial.py:


ASR Inference with CUDA CTC Decoder
====================================

**Author**: `Yuekai Zhang `__

.. warning::

   Starting with version 2.8, we are refactoring TorchAudio to transition it
   into a maintenance phase. As a result:

   - The APIs described in this tutorial are deprecated in 2.8 and will be
     removed in 2.9.
   - The decoding and encoding capabilities of PyTorch for both audio and
     video are being consolidated into TorchCodec.

   Please see https://github.com/pytorch/audio/issues/3902 for more information.

This tutorial shows how to perform speech recognition inference using a
CUDA-based CTC beam search decoder. We demonstrate this on a pretrained
`Zipformer `__ model from the `Next-gen Kaldi `__ project.

.. GENERATED FROM PYTHON SOURCE LINES 26-45

Overview
--------

Beam search decoding works by iteratively expanding text hypotheses (beams)
with next possible characters, and maintaining only the hypotheses with the
highest scores at each time step. The underlying implementation uses CUDA to
accelerate the whole decoding process.

A mathematical formula for the decoder can be found in the `paper `__, and
a more detailed algorithm can be found in this `blog `__.

Running ASR inference using a CUDA CTC Beam Search decoder requires the
following components:

-  Acoustic Model: model predicting modeling units (BPE in this tutorial)
   from acoustic features
-  BPE Model: the byte-pair encoding (BPE) tokenizer file

.. GENERATED FROM PYTHON SOURCE LINES 48-54

Acoustic Model and Set Up
-------------------------

First, we import the necessary utilities and fetch the data that we are
working with.

.. GENERATED FROM PYTHON SOURCE LINES 54-61

.. code-block:: default

    import torch
    import torchaudio

    print(torch.__version__)
    print(torchaudio.__version__)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2.8.0+cu126
    2.8.0

.. GENERATED FROM PYTHON SOURCE LINES 63-72

.. code-block:: default

    import time
    from pathlib import Path

    import IPython
    import sentencepiece as spm

    from torchaudio.models.decoder import cuda_ctc_decoder
    from torchaudio.utils import download_asset

.. GENERATED FROM PYTHON SOURCE LINES 73-78

We use the pretrained `Zipformer `__ model that is trained on the
`LibriSpeech dataset `__. The model is jointly trained with CTC and
Transducer loss functions. In this tutorial, we only use the CTC head of
the model.

.. GENERATED FROM PYTHON SOURCE LINES 79-94

.. code-block:: default

    def download_asset_external(url, key):
        path = Path(torch.hub.get_dir()) / "torchaudio" / Path(key)
        if not path.exists():
            path.parent.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(url, path)
        return str(path)


    url_prefix = "https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01"
    model_link = f"{url_prefix}/resolve/main/exp/cpu_jit.pt"
    model_path = download_asset_external(model_link, "cuda_ctc_decoder/cpu_jit.pt")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      0%|          | 0.00/269M [00:00<?, ?B/s]

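Next, we fetch the speech sample that we will transcribe and load it as a
waveform. A minimal version of this step is shown below; the asset path
``tutorial-assets/ctc-decoding/1688-142285-0007.wav`` is assumed from the
companion CPU CTC decoder tutorial, which uses the same LibriSpeech sample.

.. code-block:: default

    # Download the speech sample (assumed asset path, see the note above) and
    # load it as the waveform tensor used in the rest of the tutorial.
    speech_file = download_asset("tutorial-assets/ctc-decoding/1688-142285-0007.wav")
    waveform, sample_rate = torchaudio.load(speech_file)
    assert sample_rate == 16000
    IPython.display.Audio(speech_file)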
.. GENERATED FROM PYTHON SOURCE LINES 105-111

The transcript corresponding to this audio file is

.. code-block::

   i really was very much afraid of showing him how much shocked i was at some parts of what he said

.. GENERATED FROM PYTHON SOURCE LINES 114-119

Files and Data for Decoder
--------------------------

Next, we load the tokens from the BPE model, which serves as the tokenizer
for decoding.

.. GENERATED FROM PYTHON SOURCE LINES 122-144

Tokens
~~~~~~

The tokens are the possible symbols that the acoustic model can predict,
including the blank symbol in CTC. In this tutorial, the token set includes
500 BPE tokens. It can either be passed in as a file, where each line
consists of the tokens corresponding to the same index, or as a list of
tokens, each mapping to a unique index.

.. code-block::

   # tokens
   <blk>
   <sos/eos>
   <unk>
   S
   _THE
   _A
   T
   _AND
   ...

.. GENERATED FROM PYTHON SOURCE LINES 144-152

.. code-block:: default

    bpe_link = f"{url_prefix}/resolve/main/data/lang_bpe_500/bpe.model"
    bpe_path = download_asset_external(bpe_link, "cuda_ctc_decoder/bpe.model")

    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load(bpe_path)
    tokens = [bpe_model.id_to_piece(id) for id in range(bpe_model.get_piece_size())]
    print(tokens)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      0%|          | 0.00/239k [00:00<?, ?B/s]

    ['<blk>', '<sos/eos>', '<unk>', 'S', '▁THE', '▁A', 'T', '▁AND', 'ED', '▁OF', '▁TO', 'E', 'D', 'N', 'ING', '▁IN', 'Y', 'M', 'C', '▁I', 'A', 'P', '▁HE', 'R', 'O', 'L', 'RE', 'I', 'U', 'ER', '▁IT', 'LY', '▁THAT', '▁WAS', '▁', '▁S', 'AR', '▁BE', 'F', '▁C', 'IN', 'B', '▁FOR', 'OR', 'LE', "'", '▁HIS', '▁YOU', 'AL', '▁RE', 'V', '▁B', 'G', 'RI', '▁E', '▁WITH', '▁T', '▁AS', 'LL', '▁P', '▁HER', 'ST', '▁HAD', '▁SO', '▁F', 'W', 'CE', '▁IS', 'ND', '▁NOT', 'TH', '▁BUT', 'EN', '▁SHE', '▁ON', 'VE', 'ON', 'SE', '▁DE', 'UR', '▁G', 'CH', 'K', 'TER', '▁AT', 'IT', '▁ME', 'RO', 'NE', 'RA', 'ES', 'IL', 'NG', 'IC', '▁NO', '▁HIM', 'ENT', 'IR', '▁WE', 'H', '▁DO', '▁ALL', '▁HAVE', 'LO', '▁BY', '▁MY', '▁MO', '▁THIS', 'LA', '▁ST', '▁WHICH', '▁CON', '▁THEY', 'CK', 'TE', '▁SAID', '▁FROM', '▁GO', '▁WHO', '▁TH', '▁OR', '▁D', '▁W', 'VER', 'LI', '▁SE', '▁ONE', '▁CA', '▁AN', '▁LA', '▁WERE', 'EL', '▁HA', '▁MAN', '▁FA', '▁EX', 'AD', '▁SU', 'RY', '▁MI', 'AT', '▁BO', '▁WHEN', 'AN', 'THER', 'PP', 'ATION', '▁FI', '▁WOULD', '▁PRO', 'OW', 'ET', '▁O', '▁THERE', '▁HO', 'ION', '▁WHAT', '▁FE', '▁PA', 'US', 'MENT', '▁MA', 'UT', '▁OUT', '▁THEIR', '▁IF', '▁LI', '▁K', '▁WILL', '▁ARE', 'ID', '▁RO', 'DE', 'TION', '▁WA', 'PE', '▁UP', '▁SP', '▁PO', 'IGHT', '▁UN', 'RU', '▁LO', 'AS', 'OL', '▁LE', '▁BEEN', '▁SH', '▁RA', '▁SEE', 'KE', 'UL', 'TED', '▁SA', 'UN', 'UND', 'ANT', '▁NE', 'IS', '▁THEM', 'CI', 'GE', '▁COULD', '▁DIS', 'OM', 'ISH', 'HE', 'EST', '▁SOME', 'ENCE', 'ITY', 'IVE', '▁US', '▁MORE', '▁EN', 'ARD', 'ATE', '▁YOUR', '▁INTO', '▁KNOW', '▁CO', 'ANCE', '▁TIME', '▁WI', '▁YE', 'AGE', '▁NOW', 'TI', 'FF', 'ABLE', '▁VERY', '▁LIKE', 'AM', 'HI', 'Z', '▁OTHER', '▁THAN', '▁LITTLE', '▁DID', '▁LOOK', 'TY', 'ERS', '▁CAN', '▁CHA', '▁AR', 'X', 'FUL', 'UGH', '▁BA', '▁DAY', '▁ABOUT', 'TEN', 'IM', '▁ANY', '▁PRE', '▁OVER', 'IES', 'NESS', 'ME', 'BLE', '▁M', 'ROW', '▁HAS', '▁GREAT', '▁VI', 'TA', '▁AFTER', 'PER', '▁AGAIN', 'HO', 'SH', '▁UPON', '▁DI', '▁HAND', '▁COM', 'IST', 'TURE', '▁STA', '▁THEN', '▁SHOULD', '▁GA', 'OUS', 'OUR', '▁WELL', '▁ONLY', 'MAN', '▁GOOD', '▁TWO', '▁MAR', '▁SAY', '▁HU', 'TING', '▁OUR', 'RESS', '▁DOWN', 'IOUS', '▁BEFORE', '▁DA', '▁NA', 'QUI', '▁MADE', '▁EVERY', '▁OLD', '▁EVEN', 'IG', '▁COME', '▁GRA', '▁RI', '▁LONG', 'OT', 'SIDE', 'WARD', '▁FO', '▁WHERE', 'MO', 'LESS', '▁SC', '▁MUST', '▁NEVER', '▁HOW', '▁CAME',
    '▁SUCH', '▁RU', '▁TAKE', '▁WO', '▁CAR', 'UM', 'AK', '▁THINK', '▁MUCH', '▁MISTER', '▁MAY', '▁JO', '▁WAY', '▁COMP', '▁THOUGHT', '▁STO', '▁MEN', '▁BACK', '▁DON', 'J', '▁LET', '▁TRA', '▁FIRST', '▁JUST', '▁VA', '▁OWN', '▁PLA', '▁MAKE', 'ATED', '▁HIMSELF', '▁WENT', '▁PI', 'GG', 'RING', '▁DU', '▁MIGHT', '▁PART', '▁GIVE', '▁IMP', '▁BU', '▁PER', '▁PLACE', '▁HOUSE', '▁THROUGH', 'IAN', '▁SW', '▁UNDER', 'QUE', '▁AWAY', '▁LOVE', 'QUA', '▁LIFE', '▁GET', '▁WITHOUT', '▁PASS', '▁TURN', 'IGN', '▁HEAD', '▁MOST', '▁THOSE', '▁SHALL', '▁EYES', '▁COL', '▁STILL', '▁NIGHT', '▁NOTHING', 'ITION', 'HA', '▁TELL', '▁WORK', '▁LAST', '▁NEW', '▁FACE', '▁HI', '▁WORD', '▁FOUND', '▁COUNT', '▁OB', '▁WHILE', '▁SHA', '▁MEAN', '▁SAW', '▁PEOPLE', '▁FRIEND', '▁THREE', '▁ROOM', '▁SAME', '▁THOUGH', '▁RIGHT', '▁CHILD', '▁FATHER', '▁ANOTHER', '▁HEART', '▁WANT', '▁TOOK', 'OOK', '▁LIGHT', '▁MISSUS', '▁OPEN', '▁JU', '▁ASKED', 'PORT', '▁LEFT', '▁JA', '▁WORLD', '▁HOME', '▁WHY', '▁ALWAYS', '▁ANSWER', '▁SEEMED', '▁SOMETHING', '▁GIRL', '▁BECAUSE', '▁NAME', '▁TOLD', '▁NI', '▁HIGH', 'IZE', '▁WOMAN', '▁FOLLOW', '▁RETURN', '▁KNEW', '▁EACH', '▁KIND', '▁JE', '▁ACT', '▁LU', '▁CERTAIN', '▁YEARS', '▁QUITE', '▁APPEAR', '▁BETTER', '▁HALF', '▁PRESENT', '▁PRINCE', 'SHIP', '▁ALSO', '▁BEGAN', '▁HAVING', '▁ENOUGH', '▁PERSON', '▁LADY', '▁WHITE', '▁COURSE', '▁VOICE', '▁SPEAK', '▁POWER', '▁MORNING', '▁BETWEEN', '▁AMONG', '▁KEEP', '▁WALK', '▁MATTER', '▁TEA', '▁BELIEVE', '▁SMALL', '▁TALK', '▁FELT', '▁HORSE', '▁MYSELF', '▁SIX', '▁HOWEVER', '▁FULL', '▁HERSELF', '▁POINT', '▁STOOD', '▁HUNDRED', '▁ALMOST', '▁SINCE', '▁LARGE', '▁LEAVE', '▁PERHAPS', '▁DARK', '▁SUDDEN', '▁REPLIED', '▁ANYTHING', '▁WONDER', '▁UNTIL', 'Q']

.. GENERATED FROM PYTHON SOURCE LINES 153-159

Construct CUDA Decoder
----------------------

In this tutorial, we will construct a CUDA beam search decoder. The decoder
can be constructed using the factory function
:py:func:`~torchaudio.models.decoder.cuda_ctc_decoder`.

.. GENERATED FROM PYTHON SOURCE LINES 159-161

.. code-block:: default

    cuda_decoder = cuda_ctc_decoder(tokens, nbest=10, beam_size=10, blank_skip_threshold=0.95)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /pytorch/audio/examples/tutorials/asr_inference_with_cuda_ctc_decoder_tutorial.py:160: UserWarning: torchaudio.models.decoder._cuda_ctc_decoder.cuda_ctc_decoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
      cuda_decoder = cuda_ctc_decoder(tokens, nbest=10, beam_size=10, blank_skip_threshold=0.95)
    /pytorch/audio/src/torchaudio/models/decoder/_cuda_ctc_decoder.py:187: UserWarning: torchaudio.models.decoder._cuda_ctc_decoder.CUCTCDecoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
      return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)

.. GENERATED FROM PYTHON SOURCE LINES 162-176

Run Inference
-------------

Now that we have the data, acoustic model, and decoder, we can perform
inference. The output of the beam search decoder is of type
:py:class:`~torchaudio.models.decoder.CUCTCHypothesis`, consisting of the
predicted token IDs, words (symbols corresponding to the token IDs), and
hypothesis scores.
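As a short, illustrative sketch (not part of the generated script), the
fields of a hypothesis can be accessed as follows once ``results`` has been
produced by the decoding call below:

.. code-block:: python

    # results[0] holds the n-best hypotheses for the first (and, here, only)
    # utterance in the batch; each hypothesis exposes tokens, words, and score.
    best_hypothesis = results[0][0]
    token_ids = best_hypothesis.tokens  # predicted BPE token IDs
    best_score = best_hypothesis.score  # score of this hypothesis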
Recall the transcript corresponding to the waveform is

.. code-block::

   i really was very much afraid of showing him how much shocked i was at some parts of what he said

.. GENERATED FROM PYTHON SOURCE LINES 176-197

.. code-block:: default

    actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
    actual_transcript = actual_transcript.split()

    device = torch.device("cuda", 0)

    # Load the TorchScript acoustic model and move it to the GPU.
    acoustic_model = torch.jit.load(model_path)
    acoustic_model.to(device)
    acoustic_model.eval()

    waveform = waveform.to(device)

    # Extract 80-dimensional log-mel filterbank features, the input expected
    # by the Zipformer encoder.
    feat = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=80, snip_edges=False)
    feat = feat.unsqueeze(0)
    feat_lens = torch.tensor(feat.size(1), device=device).unsqueeze(0)

    # Run the encoder and the CTC head, then normalize to log-probabilities.
    encoder_out, encoder_out_lens = acoustic_model.encoder(feat, feat_lens)
    nnet_output = acoustic_model.ctc_output(encoder_out)
    log_prob = torch.nn.functional.log_softmax(nnet_output, -1)

    print(f"The shape of log_prob: {log_prob.shape}, the shape of encoder_out_lens: {encoder_out_lens.shape}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The shape of log_prob: torch.Size([1, 175, 500]), the shape of encoder_out_lens: torch.Size([1])

.. GENERATED FROM PYTHON SOURCE LINES 198-200

The CUDA CTC decoder gives the following result.

.. GENERATED FROM PYTHON SOURCE LINES 200-210

.. code-block:: default

    results = cuda_decoder(log_prob, encoder_out_lens.to(torch.int32))

    beam_search_transcript = bpe_model.decode(results[0][0].tokens).lower()
    beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_transcript.split()) / len(
        actual_transcript
    )

    print(f"Transcript: {beam_search_transcript}")
    print(f"WER: {beam_search_wer}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Transcript: i really was very much afraid of showing him how much shocked i was at some parts of what he said
    WER: 0.0

.. GENERATED FROM PYTHON SOURCE LINES 211-219

Beam Search Decoder Parameters
------------------------------

In this section, we go a little more in depth on some of the different
parameters and tradeoffs. For the full list of customizable parameters,
please refer to the
:py:func:`documentation <torchaudio.models.decoder.cuda_ctc_decoder>`.

.. GENERATED FROM PYTHON SOURCE LINES 222-225

Helper Function
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 225-236

.. code-block:: default

    def print_decoded(cuda_decoder, bpe_model, log_prob, encoder_out_lens, param, param_value):
        # Time only the decoding call, then report the 1-best transcript,
        # its score, and the decoding time.
        start_time = time.monotonic()
        results = cuda_decoder(log_prob, encoder_out_lens.to(torch.int32))
        decode_time = time.monotonic() - start_time
        transcript = bpe_model.decode(results[0][0].tokens).lower()
        score = results[0][0].score
        print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")

.. GENERATED FROM PYTHON SOURCE LINES 237-244

nbest
~~~~~

This parameter indicates the number of best hypotheses to return. For
instance, by setting ``nbest=10`` when constructing the beam search decoder
earlier, we can now access the hypotheses with the top 10 scores.

.. GENERATED FROM PYTHON SOURCE LINES 244-251

.. code-block:: default

    for i in range(10):
        transcript = bpe_model.decode(results[0][i].tokens).lower()
        score = results[0][i].score
        print(f"{transcript} (score: {score})")

.. rst-class:: sphx-glr-script-out
 .. code-block:: none

    i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: -0.20293806493282318)
    i really was very much afraid of showing him how much shocked i was at some part of what he said (score: -1.740274429321289)
    i really was very much afraid of sheowing him how much shocked i was at some parts of what he said (score: -6.679695129394531)
    i reallyly very much afraid of showing him how much shocked i was at some parts of what he said (score: -7.597385883331299)
    i really was very much afraid of sheowing him how much shocked i was at some part of what he said (score: -8.223531723022461)
    i really was very much afraid of shwing him how much shocked i was at some parts of what he said (score: -8.439902305603027)
    i really was very much afraid of showing him how much shocked i was in some parts of what he said (score: -8.782021522521973)
    i really was very much afraid of showing him how much shocked i was at some parts of what said (score: -8.884115219116211)
    i really was very much afraid of showing him how much shocked i was at some partes of what he said (score: -8.99947452545166)
    i really was very much afraid of showing him how much shocked i was at some parts of what he say (score: -9.13825798034668)

.. GENERATED FROM PYTHON SOURCE LINES 252-265

beam size
~~~~~~~~~

The ``beam_size`` parameter determines the maximum number of best hypotheses
to hold after each decoding step. Using larger beam sizes allows exploring a
wider range of possible hypotheses, which can produce hypotheses with higher
scores, but it does not provide additional gains beyond a certain point. We
recommend setting ``beam_size=10`` for the CUDA beam search decoder.

In the example below, we see an improvement in decoding quality as we
increase the beam size from 1 to 3, but notice how a beam size of 3 already
provides the same output as a beam size of 10.

.. GENERATED FROM PYTHON SOURCE LINES 265-278

.. code-block:: default

    beam_sizes = [1, 2, 3, 10]

    for beam_size in beam_sizes:
        beam_search_decoder = cuda_ctc_decoder(
            tokens,
            nbest=1,
            beam_size=beam_size,
            blank_skip_threshold=0.95,
        )
        print_decoded(beam_search_decoder, bpe_model, log_prob, encoder_out_lens, "beam size", beam_size)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /pytorch/audio/examples/tutorials/asr_inference_with_cuda_ctc_decoder_tutorial.py:269: UserWarning: torchaudio.models.decoder._cuda_ctc_decoder.cuda_ctc_decoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
      beam_search_decoder = cuda_ctc_decoder(
    /pytorch/audio/src/torchaudio/models/decoder/_cuda_ctc_decoder.py:187: UserWarning: torchaudio.models.decoder._cuda_ctc_decoder.CUCTCDecoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
      return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
    beam size 1  : i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: -1.35; 0.0009 secs)
    beam size 2  : i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: -0.21; 0.0009 secs)
    beam size 3  : i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: -0.20; 0.0009 secs)
    beam size 10 : i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: -0.20; 0.0010 secs)

.. GENERATED FROM PYTHON SOURCE LINES 279-289

blank skip threshold
~~~~~~~~~~~~~~~~~~~~

The ``blank_skip_threshold`` parameter is used to prune frames that have a
large blank probability. Pruning these frames with a good
``blank_skip_threshold`` can speed up the decoding process significantly
without any drop in accuracy. By the rules of CTC, we keep at least one
blank frame between two non-blank frames to avoid mistakenly merging two
consecutive identical symbols. We recommend setting
``blank_skip_threshold=0.95`` for the CUDA beam search decoder.

.. GENERATED FROM PYTHON SOURCE LINES 289-303

.. code-block:: default

    blank_skip_probs = [0.25, 0.95, 1.0]

    for blank_skip_prob in blank_skip_probs:
        beam_search_decoder = cuda_ctc_decoder(
            tokens,
            nbest=10,
            beam_size=10,
            blank_skip_threshold=blank_skip_prob,
        )
        print_decoded(beam_search_decoder, bpe_model, log_prob, encoder_out_lens, "blank_skip_threshold", blank_skip_prob)

    del cuda_decoder

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /pytorch/audio/examples/tutorials/asr_inference_with_cuda_ctc_decoder_tutorial.py:293: UserWarning: torchaudio.models.decoder._cuda_ctc_decoder.cuda_ctc_decoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
      beam_search_decoder = cuda_ctc_decoder(
    /pytorch/audio/src/torchaudio/models/decoder/_cuda_ctc_decoder.py:187: UserWarning: torchaudio.models.decoder._cuda_ctc_decoder.CUCTCDecoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release.
      return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
    blank_skip_threshold 0.25: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: -0.01; 0.0009 secs)
    blank_skip_threshold 0.95: i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: -0.20; 0.0010 secs)
    blank_skip_threshold 1.0: i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: -0.21; 0.0044 secs)

.. GENERATED FROM PYTHON SOURCE LINES 304-322

Benchmark with flashlight CPU decoder
-------------------------------------

We benchmark the throughput and accuracy of the CUDA decoder against the
flashlight CPU decoder using the LibriSpeech test_other set. To reproduce
the benchmark results below, you may refer `here `__. The table also reports
the n-best oracle WER, i.e., the lowest WER among the ``nbest`` hypotheses
for each utterance; a minimal sketch of computing this metric for the single
utterance decoded above follows.
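This sketch is illustrative only (it is not taken from the benchmark
scripts) and reuses the ``results``, ``actual_transcript``, and
``bpe_model`` objects from the sections above; the benchmark averages this
quantity over the whole test set.

.. code-block:: python

    # Oracle WER: the lowest WER achievable if, for each utterance, we could
    # always pick the hypothesis from the n-best list that is closest to the
    # reference transcript.
    oracle_wer = min(
        torchaudio.functional.edit_distance(actual_transcript, bpe_model.decode(hyp.tokens).lower().split())
        / len(actual_transcript)
        for hyp in results[0]
    )
    print(f"Oracle WER over {len(results[0])} hypotheses: {oracle_wer}")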
+--------------+------------------------------------------+---------+-----------------------+-----------------------------+
| Decoder      | Setting                                  | WER (%) | N-Best Oracle WER (%) | Decoder Cost Time (seconds) |
+==============+==========================================+=========+=======================+=============================+
| CUDA decoder | blank_skip_threshold 0.95                | 5.81    | 4.11                  | 2.57                        |
+--------------+------------------------------------------+---------+-----------------------+-----------------------------+
| CUDA decoder | blank_skip_threshold 1.0 (no frame-skip) | 5.81    | 4.09                  | 6.24                        |
+--------------+------------------------------------------+---------+-----------------------+-----------------------------+
| CPU decoder  | beam_size_token 10                       | 5.86    | 4.30                  | 28.61                       |
+--------------+------------------------------------------+---------+-----------------------+-----------------------------+
| CPU decoder  | beam_size_token 500                      | 5.86    | 4.30                  | 791.80                      |
+--------------+------------------------------------------+---------+-----------------------+-----------------------------+

From the above table, the CUDA decoder gives a slight improvement in WER and
a significant increase in throughput.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 2.002 seconds)


.. _sphx_glr_download_tutorials_asr_inference_with_cuda_ctc_decoder_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: asr_inference_with_cuda_ctc_decoder_tutorial.py <asr_inference_with_cuda_ctc_decoder_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: asr_inference_with_cuda_ctc_decoder_tutorial.ipynb <asr_inference_with_cuda_ctc_decoder_tutorial.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_