Text-to-Speech with Tacotron2
Author: Yao-Yuan Yang, Moto Hira
Overview
This tutorial shows how to build a text-to-speech pipeline using the pretrained Tacotron2 model in torchaudio.
The text-to-speech pipeline goes as follows:
- Text preprocessing - First, the input text is encoded into a list of symbols. In this tutorial, we use English characters and phonemes as the symbols.
- Spectrogram generation - From the encoded text, a spectrogram is generated. We use the Tacotron2 model for this.
- Time-domain conversion - The last step is converting the spectrogram into a waveform. The component that generates speech from a spectrogram is called a vocoder. In this tutorial, three different vocoders are used: WaveRNN, GriffinLim, and Nvidia’s WaveGlow.
The following figure illustrates the whole process.
[Figure: overview of the text-to-speech pipeline]
All the related components are bundled in torchaudio.pipelines.Tacotron2TTSBundle,
but this tutorial will also cover the process under the hood.
Preparation
First, we install the necessary dependencies. In addition to
torchaudio, DeepPhonemizer is required to perform phoneme-based
encoding.
%%bash
pip3 install deep_phonemizer
import torch
import torchaudio
torch.random.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(torch.__version__)
print(torchaudio.__version__)
print(device)
2.6.0
2.6.0
cuda
import IPython
import matplotlib.pyplot as plt
Text Processing
Character-based encoding
In this section, we go through how character-based encoding works.
Since the pretrained Tacotron2 model expects a specific symbol
table, the same functionality is available in torchaudio. However,
we will first implement the encoding manually to aid understanding.
First, we define the set of symbols
'_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'. Then, we lower-case the
input text and map each character to the index of the corresponding
symbol in the table. Characters that are not in the table are ignored.
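The lookup can be sketched in a few lines of plain Python. The names `look_up` and `text_to_sequence` are illustrative only and are not torchaudio API. When applied to `"Hello world! Text to speech!"`, this sketch produces the index list shown below.

```python
# Symbol table for character-based encoding; a symbol's index is its position.
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)}


def text_to_sequence(text):
    """Lower-case the text and map each known character to its table index.

    Characters that are not in the symbol table are silently dropped.
    """
    return [look_up[c] for c in text.lower() if c in look_up]


text = "Hello world! Text to speech!"
print(text_to_sequence(text))
```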
[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 11, 31, 16, 35, 31, 11, 31, 26, 11, 30, 27, 16, 16, 14, 19, 2]
As mentioned above, the symbol table and indices must match
what the pretrained Tacotron2 model expects. torchaudio provides the same
transform along with the pretrained model. You can
instantiate and use such a transform as follows.
tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15,  2, 11, 31, 16, 35, 31, 11,
         31, 26, 11, 30, 27, 16, 16, 14, 19,  2]])
tensor([28], dtype=torch.int32)
Note: The output of our manual encoding matches the output of torchaudio's text_processor, which means we correctly re-implemented what the library does internally. The processor takes either a single text or a list of texts as input.
When a list of texts is provided, the returned lengths variable
represents the valid length of each processed token sequence in the output
batch.
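To show what the lengths mean for a batch, here is a plain-Python sketch. It encodes several texts and right-pads them to a common length. The helper `encode_batch` is hypothetical. Padding with index 0 (the `'_'` symbol) is an assumption of this sketch and is not a statement about torchaudio's internals. torchaudio returns tensors, whereas this sketch uses lists.

```python
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)}


def text_to_sequence(text):
    # Same character lookup as the manual encoder: unknown characters are dropped.
    return [look_up[c] for c in text.lower() if c in look_up]


def encode_batch(texts, pad_value=0):
    """Hypothetical helper: encode several texts and right-pad them to equal length.

    Returns (padded, lengths), where lengths[i] is the number of valid
    (non-padding) entries in padded[i].
    """
    sequences = [text_to_sequence(t) for t in texts]
    lengths = [len(s) for s in sequences]
    max_len = max(lengths, default=0)
    padded = [s + [pad_value] * (max_len - len(s)) for s in sequences]
    return padded, lengths


padded, lengths = encode_batch(["Hi!", "Hello world!"])
for row, n in zip(padded, lengths):
    print(row, "valid:", n)
```

Only the first `lengths[i]` entries of each row carry real symbols. Everything after that is padding, and downstream models should ignore it.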
The intermediate representation can be retrieved as follows:
['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', ' ', 't', 'e', 'x', 't', ' ', 't', 'o', ' ', 's', 'p', 'e', 'e', 'c', 'h', '!']
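Conceptually, this intermediate representation is each encoded index looked up in the symbol table, up to the valid length. Here is a plain-Python sketch; the helper name `sequence_to_tokens` is hypothetical. torchaudio's text processor exposes its own table through its `tokens` property.

```python
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"


def sequence_to_tokens(indices, length=None):
    """Map encoded indices back to their symbols, ignoring entries past `length`."""
    if length is None:
        length = len(indices)
    return [symbols[i] for i in indices[:length]]


encoded = [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 11, 31, 16, 35, 31, 11,
           31, 26, 11, 30, 27, 16, 16, 14, 19, 2]
print(sequence_to_tokens(encoded, length=28))
```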
Phoneme-based encoding
Phoneme-based encoding is similar to character-based encoding, but it uses a symbol table based on phonemes and a G2P (Grapheme-to-Phoneme) model.
The details of the G2P model are out of the scope of this tutorial; we will just look at what the conversion looks like.
Similar to the case of character-based encoding, the encoding process is
expected to match what the pretrained Tacotron2 model was trained on.
torchaudio has an interface to create the process.
The following code illustrates how to make and use the process. Behind
the scenes, a G2P model is created using the DeepPhonemizer package, and
the pretrained weights published by the author of DeepPhonemizer are
fetched.
/pytorch/audio/ci_env/lib/python3.10/site-packages/dp/model/model.py:306: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load(checkpoint_path, map_location=device)
/pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/modules/transformer.py:379: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
  warnings.warn(
tensor([[54, 20, 65, 69, 11, 92, 44, 65, 38,  2, 11, 81, 40, 64, 79, 81, 11, 81,
         20, 11, 79, 77, 59, 37,  2]])
tensor([25], dtype=torch.int32)
Notice that the encoded values are different from the example of character-based encoding.
The intermediate representation looks like the following.
['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!', ' ', 'T', 'EH', 'K', 'S', 'T', ' ', 'T', 'AH', ' ', 'S', 'P', 'IY', 'CH', '!']
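DeepPhonemizer is a trained model that predicts pronunciations. To show only the shape of the conversion, the sketch below replaces it with a plain dictionary lookup over a hypothetical five-word lexicon. The phoneme symbols come from the output above. This is a stand-in for illustration, not how DeepPhonemizer works. On the example sentence it yields the same 25 tokens.

```python
import re

# Hypothetical five-word lexicon; phoneme symbols copied from the output above.
LEXICON = {
    "hello": ["HH", "AH", "L", "OW"],
    "world": ["W", "ER", "L", "D"],
    "text": ["T", "EH", "K", "S", "T"],
    "to": ["T", "AH"],
    "speech": ["S", "P", "IY", "CH"],
}
# Non-letter symbols that are passed through unchanged as their own tokens.
PASS_THROUGH = set(" !'(),.:;?")


def toy_g2p(text):
    """Dictionary-lookup G2P: words become phonemes, spaces/punctuation are kept."""
    tokens = []
    # Split into runs of ASCII letters and single non-letter characters.
    for piece in re.findall(r"[a-z]+|[^a-z]", text.lower()):
        if re.fullmatch(r"[a-z]+", piece):
            if piece not in LEXICON:
                raise ValueError(f"no pronunciation for {piece!r}")
            tokens.extend(LEXICON[piece])
        elif piece in PASS_THROUGH:
            tokens.append(piece)
        # Any other character is dropped.
    return tokens


print(toy_g2p("Hello world! Text to speech!"))
```

A fixed lexicon fails on any word it does not contain. That limitation is why a learned G2P model is used in practice.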
Spectrogram Generation
Tacotron2 is the model we use to generate a spectrogram from the
encoded text. For details of the model, please refer to the
paper.
It is easy to instantiate a Tacotron2 model with pretrained weights; note, however, that the input to Tacotron2 models needs to be processed by the matching text processor.
torchaudio.pipelines.Tacotron2TTSBundle bundles the matching
models and processors together so that it is easy to create the pipeline.
For the available bundles and their usage, please refer to
Tacotron2TTSBundle.
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
text = "Hello world! Text to speech!"
with torch.inference_mode():
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)
    spec, _, _ = tacotron2.infer(processed, lengths)
_ = plt.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")

Downloading: "https://download.pytorch.org/torchaudio/models/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth
Note that the Tacotron2.infer method performs multinomial sampling;
therefore, generating the spectrogram is random, and repeated calls on the same input produce different spectrograms and even different lengths.
def plot():
    fig, ax = plt.subplots(3, 1)
    for i in range(3):
        with torch.inference_mode():
            spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
        print(spec[0].shape)
        ax[i].imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
plot()

torch.Size([80, 190])
torch.Size([80, 184])
torch.Size([80, 185])
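The varying lengths above (190, 184, and 185 frames) come from this randomness. The toy loop below is not Tacotron2. It emits one frame per step and samples a stop decision after each frame. This is enough to show why sampling changes the output length, and why re-seeding the random generator replays the same lengths. By analogy, calling torch.random.manual_seed before each Tacotron2.infer call should make runs repeatable on a given setup. However, GPU kernels are not guaranteed to be bitwise deterministic.

```python
import random


def toy_decode(rng, stop_prob=0.02, max_frames=1000):
    """Toy autoregressive loop (not Tacotron2): emit one frame per step and
    sample a stop decision after each frame."""
    n_frames = 0
    while n_frames < max_frames:
        n_frames += 1
        if rng.random() < stop_prob:  # sampled stop decision
            break
    return n_frames


def run(seed, n_runs=3):
    """Decode n_runs times from one seeded random stream; return the lengths."""
    rng = random.Random(seed)
    return [toy_decode(rng) for _ in range(n_runs)]


print(run(0))  # three lengths that generally differ from one another
print(run(0))  # same seed, same random stream: identical to the line above
```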
Waveform Generation
Once the spectrogram is generated, the last process is to recover the waveform from the spectrogram using a vocoder.
torchaudio provides vocoders based on GriffinLim and
WaveRNN.
WaveRNN Vocoder
Continuing from the previous section, we can instantiate the matching WaveRNN model from the same bundle.
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device)
text = "Hello world! Text to speech!"
with torch.inference_mode():
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)
    spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
    waveforms, lengths = vocoder(spec, spec_lengths)
Downloading: "https://download.pytorch.org/torchaudio/models/wavernn_10k_epochs_8bits_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/wavernn_10k_epochs_8bits_ljspeech.pth
def plot(waveforms, spec, sample_rate):
    waveforms = waveforms.cpu().detach()
    fig, [ax1, ax2] = plt.subplots(2, 1)
    ax1.plot(waveforms[0])
    ax1.set_xlim(0, waveforms.size(-1))
    ax1.grid(True)
    ax2.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
    return IPython.display.Audio(waveforms[0:1], rate=sample_rate)
plot(waveforms, spec, vocoder.sample_rate)
 
Griffin-Lim Vocoder
Using the Griffin-Lim vocoder is the same as using WaveRNN. You can instantiate
the vocoder object with the get_vocoder() method and pass it the spectrogram.
bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device)
with torch.inference_mode():
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)
    spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
    # Run the Griffin-Lim vocoder inside inference mode as well
    waveforms, lengths = vocoder(spec, spec_lengths)
plot(waveforms, spec, vocoder.sample_rate)
Downloading: "https://download.pytorch.org/torchaudio/models/tacotron2_english_phonemes_1500_epochs_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/tacotron2_english_phonemes_1500_epochs_ljspeech.pth
 
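WaveGlow Vocoder
WaveGlow is a vocoder published by Nvidia. The pretrained weights are
published on Torch Hub. One can instantiate the model using the torch.hub
module. In the code below, the model is created with pretrained=False and the
checkpoint is loaded separately with map_location, so that the weights are
mapped onto the current device.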
# Workaround to load model mapped on GPU
# https://stackoverflow.com/a/61840832
waveglow = torch.hub.load(
    "NVIDIA/DeepLearningExamples:torchhub",
    "nvidia_waveglow",
    model_math="fp32",
    pretrained=False,
)
checkpoint = torch.hub.load_state_dict_from_url(
    "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth",  # noqa: E501
    progress=False,
    map_location=device,
)
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
waveglow.load_state_dict(state_dict)
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow = waveglow.to(device)
waveglow.eval()
with torch.no_grad():
    waveforms = waveglow.infer(spec)
plot(waveforms, spec, 22050)  # WaveGlow was trained on 22050 Hz LJSpeech audio
/pytorch/audio/ci_env/lib/python3.10/site-packages/torch/hub.py:330: UserWarning: You are about to download and run code from an untrusted repository. In a future release, this won't be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, or load(..., trust_repo=True), which will assume that the prompt is to be answered with 'yes'. You can also use load(..., trust_repo='check') which will only prompt for confirmation if the repo is not already trusted. This will eventually be the default behaviour
  warnings.warn(
Downloading: "https://github.com/NVIDIA/DeepLearningExamples/zipball/torchhub" to /root/.cache/torch/hub/torchhub.zip
/root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/Classification/ConvNets/image_classification/models/common.py:13: UserWarning: pytorch_quantization module not found, quantization will not be available
  warnings.warn(
/root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/Classification/ConvNets/image_classification/models/efficientnet.py:17: UserWarning: pytorch_quantization module not found, quantization will not be available
  warnings.warn(
/pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
  WeightNorm.apply(module, name, dim)
Downloading: "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth" to /root/.cache/torch/hub/checkpoints/nvidia_waveglowpyt_fp32_20190306.pth
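In the WaveGlow code, `key.replace("module.", "")` removes the `module.` prefix. PyTorch's `DataParallel` and `DistributedDataParallel` wrappers add this prefix to parameter names when a wrapped model's state dict is saved. However, `str.replace` removes every occurrence, so it would also change a key that contains `module.` elsewhere. The sketch below strips only a leading prefix. The helper `strip_prefix` and the example keys are hypothetical, and `str.removeprefix` requires Python 3.9+.

```python
def strip_prefix(state_dict, prefix="module."):
    """Remove a leading `prefix` from every key; other keys are left unchanged."""
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}


# Hypothetical keys: the first one contains "module." twice.
example = {"module.block.module.weight": 1, "module.block.bias": 2, "head.weight": 3}

print(strip_prefix(example))
# {'block.module.weight': 1, 'block.bias': 2, 'head.weight': 3}
print({k.replace("module.", ""): v for k, v in example.items()})
# {'block.weight': 1, 'block.bias': 2, 'head.weight': 3}
```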
 
Total running time of the script: ( 1 minutes 13.712 seconds)