.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/mvdr_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_mvdr_tutorial.py:


Speech Enhancement with MVDR Beamforming
========================================

**Author**: `Zhaoheng Ni `__

.. GENERATED FROM PYTHON SOURCE LINES 11-31

1. Overview
-----------

This is a tutorial on applying Minimum Variance Distortionless
Response (MVDR) beamforming to estimate enhanced speech with
TorchAudio.

Steps:

- Generate an ideal ratio mask (IRM) by dividing the clean (or noise)
  power spectrogram by the sum of the clean and noise power spectrograms.
- Estimate power spectral density (PSD) matrices using
  :py:func:`torchaudio.transforms.PSD`.
- Estimate enhanced speech using the MVDR modules
  (:py:func:`torchaudio.transforms.SoudenMVDR` and
  :py:func:`torchaudio.transforms.RTFMVDR`).
- Benchmark the two methods
  (:py:func:`torchaudio.functional.rtf_evd` and
  :py:func:`torchaudio.functional.rtf_power`) for computing the
  relative transfer function (RTF) matrix with respect to the
  reference microphone.

.. GENERATED FROM PYTHON SOURCE LINES 31-43

.. code-block:: default

    import torch
    import torchaudio
    import torchaudio.functional as F

    print(torch.__version__)
    print(torchaudio.__version__)

    import matplotlib.pyplot as plt
    from IPython.display import Audio

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2.10.0.dev20251013+cu126
    2.8.0a0+1d65bbe

.. GENERATED FROM PYTHON SOURCE LINES 44-47

2. Preparation
--------------

.. GENERATED FROM PYTHON SOURCE LINES 49-51

2.1. Import the packages
~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 51-54

.. code-block:: default

    from torchaudio.utils import _download_asset

.. GENERATED FROM PYTHON SOURCE LINES 55-75

2.2. Download audio data
~~~~~~~~~~~~~~~~~~~~~~~~

The multi-channel audio example is selected from the
`ConferencingSpeech `__ dataset.

The original filename is

    ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``

which was generated with:

- ``SSB07200001.wav`` from `AISHELL-3 `__ (Apache License v.2.0)
- ``noise-sound-bible-0038.wav`` from `MUSAN `__ (Attribution 4.0 International, CC BY 4.0)

.. GENERATED FROM PYTHON SOURCE LINES 75-81

.. code-block:: default

    SAMPLE_RATE = 16000
    SAMPLE_CLEAN = _download_asset("tutorial-assets/mvdr/clean_speech.wav")
    SAMPLE_NOISE = _download_asset("tutorial-assets/mvdr/noise.wav")

.. GENERATED FROM PYTHON SOURCE LINES 82-85

2.3. Helper functions
~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 85-154

.. code-block:: default

    def plot_spectrogram(stft, title="Spectrogram"):
        magnitude = stft.abs()
        spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
        figure, axis = plt.subplots(1, 1)
        img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
        axis.set_title(title)
        plt.colorbar(img, ax=axis)


    def plot_mask(mask, title="Mask"):
        mask = mask.numpy()
        figure, axis = plt.subplots(1, 1)
        img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
        axis.set_title(title)
        plt.colorbar(img, ax=axis)


    def si_snr(estimate, reference, epsilon=1e-8):
        estimate = estimate - estimate.mean()
        reference = reference - reference.mean()
        reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
        mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
        scale = mix_pow / (reference_pow + epsilon)

        reference = scale * reference
        error = estimate - reference

        reference_pow = reference.pow(2)
        error_pow = error.pow(2)

        reference_pow = reference_pow.mean(axis=1)
        error_pow = error_pow.mean(axis=1)

        si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
        return si_snr.item()


    def generate_mixture(waveform_clean, waveform_noise, target_snr):
        power_clean_signal = waveform_clean.pow(2).mean()
        power_noise_signal = waveform_noise.pow(2).mean()
        current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
        waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
        return waveform_clean + waveform_noise


    # If you have mir_eval installed, you can use it to evaluate the separation
    # quality of the estimated sources. You can also evaluate the intelligibility
    # of the speech with the Short-Time Objective Intelligibility (STOI) metric
    # available in the `pystoi` package, or the Perceptual Evaluation of Speech
    # Quality (PESQ) metric available in the `pesq` package.
    def evaluate(estimate, reference):
        from pesq import pesq
        from pystoi import stoi
        import mir_eval

        si_snr_score = si_snr(estimate, reference)
        (
            sdr,
            _,
            _,
            _,
        ) = mir_eval.separation.bss_eval_sources(reference.numpy(), estimate.numpy(), False)
        pesq_mix = pesq(SAMPLE_RATE, estimate[0].numpy(), reference[0].numpy(), "wb")
        stoi_mix = stoi(reference[0].numpy(), estimate[0].numpy(), SAMPLE_RATE, extended=False)
        print(f"SDR score: {sdr[0]}")
        print(f"Si-SNR score: {si_snr_score}")
        print(f"PESQ score: {pesq_mix}")
        print(f"STOI score: {stoi_mix}")

.. GENERATED FROM PYTHON SOURCE LINES 155-158

3. Generate Ideal Ratio Masks (IRMs)
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 161-164

3.1. Load audio data
~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 164-173

.. code-block:: default

    waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)
    waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
    assert sr == sr2 == SAMPLE_RATE

    # The mixture waveform is a combination of clean and noise waveforms with a desired SNR.
    target_snr = 3
    waveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)

.. GENERATED FROM PYTHON SOURCE LINES 174-177

Note: To improve computational robustness, it is recommended to represent
the waveforms as double-precision floating point (``torch.float64`` or
``torch.double``) values.

.. GENERATED FROM PYTHON SOURCE LINES 177-183

.. code-block:: default

    waveform_mix = waveform_mix.to(torch.double)
    waveform_clean = waveform_clean.to(torch.double)
    waveform_noise = waveform_noise.to(torch.double)

.. GENERATED FROM PYTHON SOURCE LINES 184-187

3.2. Compute STFT coefficients
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 187-202

.. code-block:: default

    N_FFT = 1024
    N_HOP = 256

    stft = torchaudio.transforms.Spectrogram(
        n_fft=N_FFT,
        hop_length=N_HOP,
        power=None,
    )
    istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

    stft_mix = stft(waveform_mix)
    stft_clean = stft(waveform_clean)
    stft_noise = stft(waveform_noise)

.. GENERATED FROM PYTHON SOURCE LINES 203-206

3.2.1. Visualize mixture speech
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 206-212

.. code-block:: default

    plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
    Audio(waveform_mix[0], rate=SAMPLE_RATE)

.. image-sg:: /tutorials/images/sphx_glr_mvdr_tutorial_001.png
   :alt: Spectrogram of Mixture Speech (dB)
   :srcset: /tutorials/images/sphx_glr_mvdr_tutorial_001.png
   :class: sphx-glr-single-img
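The mixture above is produced by the ``generate_mixture`` helper, which rescales
the noise so that the mixture reaches the requested SNR before the two waveforms
are summed. A minimal self-contained check of that scaling rule, using synthetic
tensors instead of the tutorial's audio (all names here are illustrative):

```python
import torch

def snr_db(signal, noise):
    # SNR in dB computed from the average power of each tensor
    return 10 * torch.log10(signal.pow(2).mean() / noise.pow(2).mean())

torch.manual_seed(0)
clean = torch.randn(2, 16000)        # stand-in for a 2-channel clean waveform
noise = 0.5 * torch.randn(2, 16000)  # stand-in noise

target_snr = 3.0
current_snr = snr_db(clean, noise)
# Scaling the noise amplitude by 10 ** (-(target - current) / 20) changes the
# noise power by (current - target) dB, so the new SNR is exactly the target.
scaled_noise = noise * 10 ** (-(target_snr - current_snr) / 20)
mixture = clean + scaled_noise

print(round(snr_db(clean, scaled_noise).item(), 3))  # 3.0
```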


.. GENERATED FROM PYTHON SOURCE LINES 213-216

3.2.2. Visualize clean speech
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 216-221

.. code-block:: default

    plot_spectrogram(stft_clean[0], "Spectrogram of Clean Speech (dB)")
    Audio(waveform_clean[0], rate=SAMPLE_RATE)

.. image-sg:: /tutorials/images/sphx_glr_mvdr_tutorial_002.png
   :alt: Spectrogram of Clean Speech (dB)
   :srcset: /tutorials/images/sphx_glr_mvdr_tutorial_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 222-225

3.2.3. Visualize noise
^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 225-230

.. code-block:: default

    plot_spectrogram(stft_noise[0], "Spectrogram of Noise (dB)")
    Audio(waveform_noise[0], rate=SAMPLE_RATE)

.. image-sg:: /tutorials/images/sphx_glr_mvdr_tutorial_003.png
   :alt: Spectrogram of Noise (dB)
   :srcset: /tutorials/images/sphx_glr_mvdr_tutorial_003.png
   :class: sphx-glr-single-img
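Section 3.4 below derives the IRMs from the clean and noise power spectrograms.
Their defining property (each mask lies in [0, 1], and the speech and noise
masks sum to one) can be verified on synthetic STFT-like tensors; this sketch
mirrors the logic of the upcoming ``get_irms`` function rather than calling it:

```python
import torch

torch.manual_seed(0)
stft_clean = torch.randn(513, 100, dtype=torch.cfloat)  # stand-in STFT bins
stft_noise = torch.randn(513, 100, dtype=torch.cfloat)

power_clean = stft_clean.abs() ** 2
power_noise = stft_noise.abs() ** 2
irm_speech = power_clean / (power_clean + power_noise)
irm_noise = power_noise / (power_clean + power_noise)

# The two masks are complementary: they lie in [0, 1] and sum to one.
print(torch.allclose(irm_speech + irm_noise, torch.ones_like(irm_speech)))  # True
print(bool((irm_speech >= 0).all() and (irm_speech <= 1).all()))  # True
```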


.. GENERATED FROM PYTHON SOURCE LINES 231-239

3.3. Define the reference microphone
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We choose the first microphone in the array as the reference channel for
demonstration. The choice of reference channel may depend on the design of
the microphone array. You can also apply an end-to-end neural network that
estimates both the reference channel and the PSD matrices, then obtains the
enhanced STFT coefficients through the MVDR module.

.. GENERATED FROM PYTHON SOURCE LINES 239-243

.. code-block:: default

    REFERENCE_CHANNEL = 0

.. GENERATED FROM PYTHON SOURCE LINES 244-247

3.4. Compute IRMs
~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 247-260

.. code-block:: default

    def get_irms(stft_clean, stft_noise):
        mag_clean = stft_clean.abs() ** 2
        mag_noise = stft_noise.abs() ** 2
        irm_speech = mag_clean / (mag_clean + mag_noise)
        irm_noise = mag_noise / (mag_clean + mag_noise)
        return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]


    irm_speech, irm_noise = get_irms(stft_clean, stft_noise)

.. GENERATED FROM PYTHON SOURCE LINES 261-264

3.4.1. Visualize IRM of target speech
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 264-268

.. code-block:: default

    plot_mask(irm_speech, "IRM of the Target Speech")

.. image-sg:: /tutorials/images/sphx_glr_mvdr_tutorial_004.png
   :alt: IRM of the Target Speech
   :srcset: /tutorials/images/sphx_glr_mvdr_tutorial_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 269-272

3.4.2. Visualize IRM of noise
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 272-275

.. code-block:: default

    plot_mask(irm_noise, "IRM of the Noise")

.. image-sg:: /tutorials/images/sphx_glr_mvdr_tutorial_005.png
   :alt: IRM of the Noise
   :srcset: /tutorials/images/sphx_glr_mvdr_tutorial_005.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 276-284

4. Compute PSD matrices
-----------------------

:py:func:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix
given the multi-channel complex-valued STFT coefficients of the mixture
speech and the time-frequency mask.

The shape of the PSD matrix is `(..., freq, channel, channel)`.

.. GENERATED FROM PYTHON SOURCE LINES 284-291

.. code-block:: default

    psd_transform = torchaudio.transforms.PSD()

    psd_speech = psd_transform(stft_mix, irm_speech)
    psd_noise = psd_transform(stft_mix, irm_noise)

.. GENERATED FROM PYTHON SOURCE LINES 292-295

5. Beamforming using SoudenMVDR
-------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 298-308

5.1. Apply beamforming
~~~~~~~~~~~~~~~~~~~~~~

:py:func:`torchaudio.transforms.SoudenMVDR` takes the multi-channel
complex-valued STFT coefficients of the mixture speech, the PSD matrices of
the target speech and noise, and the reference channel as inputs.

The output is the single-channel complex-valued STFT coefficients of the
enhanced speech. We can then obtain the enhanced waveform by passing this
output to the :py:func:`torchaudio.transforms.InverseSpectrogram` module.

.. GENERATED FROM PYTHON SOURCE LINES 308-314

.. code-block:: default

    mvdr_transform = torchaudio.transforms.SoudenMVDR()
    stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
    waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])

.. GENERATED FROM PYTHON SOURCE LINES 315-318

5.2. Result for SoudenMVDR
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 318-324

.. code-block:: default

    plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
    waveform_souden = waveform_souden.reshape(1, -1)
    Audio(waveform_souden, rate=SAMPLE_RATE)

.. image-sg:: /tutorials/images/sphx_glr_mvdr_tutorial_006.png
   :alt: Enhanced Spectrogram by SoudenMVDR (dB)
   :srcset: /tutorials/images/sphx_glr_mvdr_tutorial_006.png
   :class: sphx-glr-single-img
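For intuition about the beamformer applied above: the Souden et al. solution
forms, per frequency bin, the weights ``w = (Phi_nn^-1 Phi_ss) u / trace(Phi_nn^-1 Phi_ss)``,
where ``u`` is a one-hot vector selecting the reference channel. The sketch
below illustrates that formula on synthetic PSD matrices; it is not
TorchAudio's implementation, and all names are illustrative:

```python
import torch

def souden_mvdr_weights(psd_s, psd_n, ref_channel=0, eps=1e-10):
    # psd_s, psd_n: (freq, channel, channel) complex Hermitian PSD matrices
    numerator = torch.linalg.solve(psd_n, psd_s)          # Phi_nn^-1 Phi_ss
    trace = numerator.diagonal(dim1=-2, dim2=-1).sum(-1)  # trace per frequency bin
    # Selecting column `ref_channel` is the multiplication by the one-hot u.
    weights = numerator[..., ref_channel] / (trace[..., None] + eps)
    return weights  # (freq, channel)

# Synthetic Hermitian positive-definite PSDs: 513 bins, 2 channels
torch.manual_seed(0)
a = torch.randn(513, 2, 2, dtype=torch.cfloat)
psd_n = a @ a.mH + 1e-3 * torch.eye(2)  # loading keeps the noise PSD invertible
b = torch.randn(513, 2, 2, dtype=torch.cfloat)
psd_s = b @ b.mH

w = souden_mvdr_weights(psd_s, psd_n)
print(w.shape)  # torch.Size([513, 2])
```

Applying such weights to a multi-channel STFT tensor of shape
`(freq, channel, time)` would amount to
``torch.einsum("fc,fct->ft", w.conj(), specgram)``, i.e. a per-bin inner
product between the conjugated weights and the channels.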


.. GENERATED FROM PYTHON SOURCE LINES 325-328

6. Beamforming using RTFMVDR
----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 331-343

6.1. Compute RTF
~~~~~~~~~~~~~~~~

TorchAudio offers two methods for computing the RTF matrix of the
target speech:

- :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
  decomposition to the PSD matrix of the target speech to get the RTF matrix.

- :py:func:`torchaudio.functional.rtf_power`, which applies the power
  iteration method. You can specify the number of iterations with the
  argument ``n_iter``.

.. GENERATED FROM PYTHON SOURCE LINES 343-348

.. code-block:: default

    rtf_evd = F.rtf_evd(psd_speech)
    rtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)

.. GENERATED FROM PYTHON SOURCE LINES 349-359

6.2. Apply beamforming
~~~~~~~~~~~~~~~~~~~~~~

:py:func:`torchaudio.transforms.RTFMVDR` takes the multi-channel
complex-valued STFT coefficients of the mixture speech, the RTF matrix of
the target speech, the PSD matrix of the noise, and the reference channel
as inputs.

The output is the single-channel complex-valued STFT coefficients of the
enhanced speech. We can then obtain the enhanced waveform by passing this
output to the :py:func:`torchaudio.transforms.InverseSpectrogram` module.

.. GENERATED FROM PYTHON SOURCE LINES 359-371

.. code-block:: default

    mvdr_transform = torchaudio.transforms.RTFMVDR()

    # compute the enhanced speech based on F.rtf_evd
    stft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
    waveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])

    # compute the enhanced speech based on F.rtf_power
    stft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
    waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])

.. GENERATED FROM PYTHON SOURCE LINES 372-375

6.3. Result for RTFMVDR with `rtf_evd`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 375-381

.. code-block:: default

    plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
    waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
    Audio(waveform_rtf_evd, rate=SAMPLE_RATE)

.. image-sg:: /tutorials/images/sphx_glr_mvdr_tutorial_007.png
   :alt: Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)
   :srcset: /tutorials/images/sphx_glr_mvdr_tutorial_007.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 382-385

6.4. Result for RTFMVDR with `rtf_power`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 385-389

.. code-block:: default

    plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
    waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
    Audio(waveform_rtf_power, rate=SAMPLE_RATE)

.. image-sg:: /tutorials/images/sphx_glr_mvdr_tutorial_008.png
   :alt: Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)
   :srcset: /tutorials/images/sphx_glr_mvdr_tutorial_008.png
   :class: sphx-glr-single-img
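As a rough illustration of the two RTF estimators used in this section, the
sketch below compares an EVD-based estimate with a plain power iteration on a
synthetic rank-1 speech PSD. It is a simplification (``rtf_power`` also takes
the noise PSD into account) and not TorchAudio's implementation; in this
idealized rank-1 case both methods recover the steering vector scaled so the
reference entry equals 1:

```python
import torch

def rtf_via_evd(psd_s, ref_channel=0):
    # Principal eigenvector of the Hermitian speech PSD matrix,
    # normalized so the reference channel's entry equals 1.
    _, eigvecs = torch.linalg.eigh(psd_s)  # eigenvalues in ascending order
    principal = eigvecs[..., -1]           # column for the largest eigenvalue
    return principal / principal[..., ref_channel : ref_channel + 1]

def rtf_via_power_iteration(psd_s, n_iter=3, ref_channel=0):
    # Approximate the same eigenvector by repeatedly applying the PSD matrix.
    vec = psd_s[..., ref_channel]          # start from the reference column
    for _ in range(n_iter):
        vec = (psd_s @ vec.unsqueeze(-1)).squeeze(-1)
        vec = vec / torch.linalg.vector_norm(vec, dim=-1, keepdim=True)
    return vec / vec[..., ref_channel : ref_channel + 1]

torch.manual_seed(0)
steer = torch.randn(257, 4, dtype=torch.cfloat)           # hypothetical steering vectors
psd_s = steer.unsqueeze(-1) @ steer.conj().unsqueeze(-2)  # rank-1 PSD per frequency bin

rtf_a = rtf_via_evd(psd_s)
rtf_b = rtf_via_power_iteration(psd_s)
expected = steer / steer[..., :1]
print(torch.allclose(rtf_a, expected, rtol=1e-2, atol=1e-3))  # True
print(torch.allclose(rtf_b, expected, rtol=1e-2, atol=1e-3))  # True
```

For a full-rank PSD the power iteration only converges toward the principal
eigenvector, which is why ``rtf_power`` exposes ``n_iter``.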


.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 0 minutes 1.653 seconds)


.. _sphx_glr_download_tutorials_mvdr_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: mvdr_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: mvdr_tutorial.ipynb `

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_