Shortcuts

Source code for torchaudio.pipelines._source_separation_pipeline

from dataclasses import dataclass
from functools import partial
from typing import Callable

import torch
import torchaudio

from torchaudio.models import conv_tasnet_base, hdemucs_high


@dataclass
class SourceSeparationBundle:
    """Dataclass that bundles components for performing source separation.

    Example
        >>> import torchaudio
        >>> from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
        >>> import torch
        >>>
        >>> # Build the separation model.
        >>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
        100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
        >>>
        >>> # Instantiate the test set of Libri2Mix dataset.
        >>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
        >>>
        >>> # Apply source separation on mixture audio.
        >>> for i, data in enumerate(dataset):
        >>>     sample_rate, mixture, clean_sources = data
        >>>     # Make sure the shape of input suits the model requirement.
        >>>     mixture = mixture.reshape(1, 1, -1)
        >>>     estimated_sources = model(mixture)
        >>>     score = si_snr_pit(estimated_sources, clean_sources)  # for demonstration
        >>>     print(f"Si-SNR score is : {score}.")
        >>>     break
        Si-SNR score is : 16.24.
    """

    # Asset path handed to ``torchaudio.utils._download_asset`` by ``get_model``.
    _model_path: str
    # Zero-argument callable that builds the (not yet weight-loaded) model.
    _model_factory_func: Callable[[], torch.nn.Module]
    # Value exposed read-only through the ``sample_rate`` property.
    _sample_rate: int

    @property
    def sample_rate(self) -> int:
        """Sample rate of the audio that the model is trained on.

        :type: int
        """
        return self._sample_rate

    def get_model(self) -> torch.nn.Module:
        """Construct the model and load the pretrained weight.

        The model is created by calling ``_model_factory_func``, the asset at
        ``_model_path`` is fetched with ``torchaudio.utils._download_asset``,
        and the loaded state dict is applied to the model.

        Returns:
            torch.nn.Module: The model with the pretrained state dict loaded,
            switched to evaluation mode via ``eval()``.
        """
        model = self._model_factory_func()
        path = torchaudio.utils._download_asset(self._model_path)
        # NOTE(review): ``torch.load`` unpickles the downloaded file; confirm
        # that assets are only ever fetched from a trusted location.
        state_dict = torch.load(path)
        model.load_state_dict(state_dict)
        model.eval()
        return model
# Pre-built bundle instances. Each instance gets its own ``__doc__`` assigned
# right after construction so it carries a per-bundle description.

CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
    _model_path="models/conv_tasnet_base_libri2mix.pt",
    _model_factory_func=partial(conv_tasnet_base, num_sources=2),
    _sample_rate=8000,
)
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained Source Separation pipeline with *ConvTasNet* :cite:`Luo_2019` trained on
    *Libri2Mix dataset* :cite:`cosentino2020librimix`.

    The source separation model is constructed by :func:`~torchaudio.models.conv_tasnet_base`
    and is trained using the training script ``lightning_train.py``
    `here <https://github.com/pytorch/audio/tree/release/0.12/examples/source_separation/>`__
    with default arguments.

    Please refer to :class:`SourceSeparationBundle` for usage instructions.
"""


HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
    _model_path="models/hdemucs_high_trained.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
    _sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained music source separation pipeline with
    *Hybrid Demucs* :cite:`defossez2021hybrid` trained on both training and test sets of
    MUSDB-HQ :cite:`MUSDB18HQ` and an additional 150 extra songs from an internal database
    that was specifically produced for Meta.

    The model is constructed by :func:`~torchaudio.models.hdemucs_high`.

    Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.

    Please refer to :class:`SourceSeparationBundle` for usage instructions.
"""


HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
    _model_path="models/hdemucs_high_musdbhq_only.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
    _sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained music source separation pipeline with
    *Hybrid Demucs* :cite:`defossez2021hybrid` trained on the training set of MUSDB-HQ :cite:`MUSDB18HQ`.

    The model is constructed by :func:`~torchaudio.models.hdemucs_high`.

    Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.

    Please refer to :class:`SourceSeparationBundle` for usage instructions.
"""

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources