ConvEmformer
- class torchaudio.prototype.models.ConvEmformer(input_dim: int, num_heads: int, ffn_dim: int, num_layers: int, segment_length: int, kernel_size: int, dropout: float = 0.0, ffn_activation: str = 'relu', left_context_length: int = 0, right_context_length: int = 0, max_memory_size: int = 0, weight_init_scale_strategy: Optional[str] = 'depthwise', tanh_on_mem: bool = False, negative_inf: float = -100000000.0, conv_activation: str = 'silu')
DEPRECATED
Warning
This class is deprecated as of version 2.8 and will be removed in the 2.9 release. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. Please see https://github.com/pytorch/audio/issues/3902 for more information.
- Implements the convolution-augmented streaming transformer architecture introduced in Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution [Shi et al., 2022].
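To make the segment-based streaming idea concrete, the minimal sketch below computes, for each segment of an utterance, which frame indices form its left context, the segment itself, and its look-ahead right context. It is a simplified illustration that assumes frames are split into consecutive, non-overlapping segments of segment_length frames, and it deliberately ignores the memory elements (max_memory_size) and the convolution modules. The helper context_windows and the Window type are hypothetical, not part of torchaudio.

```python
from typing import List, NamedTuple


class Window(NamedTuple):
    """Hypothetical record of the frame ranges one segment can see (simplified)."""
    left: range
    center: range
    right: range


def context_windows(
    num_frames: int,
    segment_length: int,
    left_context_length: int,
    right_context_length: int,
) -> List[Window]:
    """Hypothetical helper: split num_frames into segments and report each
    segment's left context, its own frames, and its right (look-ahead) context.
    Memory elements and convolution are intentionally left out."""
    if segment_length <= 0:
        raise ValueError("segment_length must be positive")
    windows = []
    for start in range(0, num_frames, segment_length):
        end = min(start + segment_length, num_frames)
        # Up to left_context_length frames immediately before the segment.
        left = range(max(0, start - left_context_length), start)
        # Up to right_context_length frames immediately after the segment.
        right = range(end, min(end + right_context_length, num_frames))
        windows.append(Window(left, range(start, end), right))
    return windows


for w in context_windows(10, 4, 2, 1):
    print(list(w.left), list(w.center), list(w.right))
```

With 10 frames, segments of 4, 2 frames of left context, and 1 frame of right context, the middle segment covers frames 4 to 7, looks back at frames 2 and 3, and looks ahead at frame 8; the last segment has no look-ahead left.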
- Args:
input_dim (int): input dimension, i.e. the feature dimension D of each frame.
num_heads (int): number of attention heads in each ConvEmformer layer.
ffn_dim (int): hidden layer dimension of each ConvEmformer layer's feedforward network.
num_layers (int): number of ConvEmformer layers to instantiate.
segment_length (int): length of each input segment, in frames.
kernel_size (int): size of the kernel used in the convolution modules.
dropout (float, optional): dropout probability. (Default: 0.0)
ffn_activation (str, optional): activation function to use in the feedforward networks. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
left_context_length (int, optional): length of the left context, in frames. (Default: 0)
right_context_length (int, optional): length of the right context, in frames. (Default: 0)
max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling strategy. Must be one of ("depthwise", "constant", None). (Default: "depthwise")
tanh_on_mem (bool, optional): if True, applies tanh to memory elements. (Default: False)
negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
conv_activation (str, optional): activation function to use in the convolution modules. Must be one of ("relu", "gelu", "silu"). (Default: "silu")
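The description of weight_init_scale_strategy does not say what the scaling actually is. The sketch below shows one plausible per-layer gain schedule, and it is an assumption based on our reading of Emformer-style models rather than something this page states: "depthwise" shrinks the gain as 1/sqrt(layer index + 1), "constant" applies 1/sqrt(2) to every layer, and None leaves the default initialization untouched. The function init_gains is hypothetical; check the torchaudio source before relying on these values.

```python
import math
from typing import List, Optional


def init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
    """Hypothetical per-layer initialization gains (assumed schedule, not documented here)."""
    if weight_init_scale_strategy is None:
        # No rescaling: every layer keeps its default initialization.
        return [None] * num_layers
    if weight_init_scale_strategy == "depthwise":
        # Assumed: deeper layers get smaller gains, 1, 1/sqrt(2), 1/sqrt(3), ...
        return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
    if weight_init_scale_strategy == "constant":
        # Assumed: every layer gets the same reduced gain.
        return [1.0 / math.sqrt(2)] * num_layers
    # The page only allows "depthwise", "constant", or None.
    raise ValueError(f"Unsupported weight_init_scale_strategy: {weight_init_scale_strategy!r}")


print(init_gains("depthwise", 4))
```

A gain like this would typically be passed to an initializer such as torch.nn.init.xavier_uniform_(tensor, gain=...), so that deeper layers start with smaller weights.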
- Examples:
>>> import torch
>>> from torchaudio.prototype.models import ConvEmformer
>>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
>>> input = torch.rand(10, 200, 80)
>>> lengths = torch.randint(1, 200, (10,))
>>> output, lengths = conv_emformer(input, lengths)
>>> input = torch.rand(4, 20, 80)
>>> lengths = torch.ones(4) * 20
>>> output, lengths, states = conv_emformer.infer(input, lengths, None)
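The frame counts in the example follow directly from the parameter descriptions of forward and infer: forward takes (B, T + right_context_length, D) and returns (B, T, D), while each infer call takes (B, segment_length + right_context_length, D). The sketch below only encodes that arithmetic; both helper names are hypothetical.

```python
def forward_output_frames(input_frames: int, right_context_length: int) -> int:
    """Hypothetical helper: forward maps T + right_context_length input frames to T output frames."""
    t = input_frames - right_context_length
    if t < 0:
        raise ValueError("input is shorter than the right context")
    return t


def infer_input_frames(segment_length: int, right_context_length: int) -> int:
    """Hypothetical helper: frames expected per infer call."""
    return segment_length + right_context_length


# Values from the example: segment_length=16, right_context_length=4.
print(forward_output_frames(200, 4))  # 196
print(infer_input_frames(16, 4))      # 20
```

This is why the streaming part of the example uses inputs of 20 frames: 16 segment frames plus 4 look-ahead frames.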
Methods
forward
- ConvEmformer.forward(input: Tensor, lengths: Tensor) → Tuple[Tensor, Tensor]
Forward pass for training and non-streaming inference.
B: batch size; T: max number of input frames in batch; D: feature dimension of each frame.
- Parameters
input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T + right_context_length, D).
lengths (torch.Tensor) – tensor with shape (B,) whose i-th element is the number of valid utterance frames for the i-th batch element in input.
- Returns
- Tensor
output frames, with shape (B, T, D).
- Tensor
output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
- Return type
(Tensor, Tensor)
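The lengths argument marks which of the frames in each row are real rather than padding. A minimal pure-Python sketch of turning such lengths into a boolean validity mask, the kind of mask an attention layer uses to ignore padded positions, is shown below. It illustrates the general technique only; ConvEmformer's internal masking is not described on this page, and valid_frame_mask is hypothetical.

```python
from typing import List, Sequence


def valid_frame_mask(lengths: Sequence[int], max_frames: int) -> List[List[bool]]:
    """Hypothetical helper: row i is True for its first lengths[i] frames, False for padding."""
    mask = []
    for n in lengths:
        if not 0 <= n <= max_frames:
            raise ValueError(f"length {n} outside [0, {max_frames}]")
        mask.append([t < n for t in range(max_frames)])
    return mask


print(valid_frame_mask([3, 1], 4))
# [[True, True, True, False], [True, False, False, False]]
```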
infer
- ConvEmformer.infer(input: Tensor, lengths: Tensor, states: Optional[List[List[Tensor]]] = None) → Tuple[Tensor, Tensor, List[List[Tensor]]]
Forward pass for streaming inference.
B: batch size; D: feature dimension of each frame.
- Parameters
input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, segment_length + right_context_length, D).
lengths (torch.Tensor) – tensor with shape (B,) whose i-th element is the number of valid frames for the i-th batch element in input.
states (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing the internal state generated in the preceding invocation of infer. (Default: None)
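Since each infer call consumes segment_length frames plus right_context_length look-ahead frames, a long utterance can be fed as overlapping chunks that advance by segment_length, so one chunk's right context becomes the start of the next segment. The sketch below assumes that usage pattern and zero-pads the final chunk; make_infer_chunks is a hypothetical helper working on plain lists of frames, and the valid length it returns counts frames including the look-ahead, in line with the lengths description above.

```python
from typing import List, Sequence, Tuple

Frame = List[float]


def make_infer_chunks(
    frames: Sequence[Frame],
    segment_length: int,
    right_context_length: int,
) -> List[Tuple[List[Frame], int]]:
    """Hypothetical helper: split frames into (chunk, valid_length) pairs for successive infer calls.

    Each chunk holds segment_length + right_context_length frames, consecutive
    chunks start segment_length frames apart, and short tails are zero-padded.
    """
    if not frames:
        return []
    dim = len(frames[0])
    chunk_size = segment_length + right_context_length
    chunks = []
    for start in range(0, len(frames), segment_length):
        chunk = [list(f) for f in frames[start:start + chunk_size]]
        valid = len(chunk)  # real frames before padding, look-ahead included
        chunk.extend([[0.0] * dim for _ in range(chunk_size - valid)])
        chunks.append((chunk, valid))
    return chunks


frames = [[float(i)] for i in range(10)]
for chunk, valid in make_infer_chunks(frames, segment_length=4, right_context_length=2):
    print([f[0] for f in chunk], valid)
```

With torch, each chunk would be stacked into a (B, segment_length + right_context_length, D) tensor before calling infer.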
- Returns
- Tensor
output frames, with shape (B, segment_length, D).
- Tensor
output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
- List[List[Tensor]]
output states; list of lists of tensors representing the internal state generated in the current invocation of infer.
- Return type
(Tensor, Tensor, List[List[Tensor]])
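The states argument is what makes infer streaming: the first call passes None, and every later call passes back the states returned by the previous one. The sketch below shows that loop generically. The function stream accepts any callable following infer's (input, lengths, states) to (output, lengths, states) convention, so with torchaudio one would pass conv_emformer.infer; the toy_infer stand-in is hypothetical and exists only so the loop runs without torchaudio.

```python
from typing import Any, Callable, Iterable, List, Optional, Tuple

InferFn = Callable[[Any, Any, Optional[Any]], Tuple[Any, Any, Any]]


def stream(infer: InferFn, chunks: Iterable[Tuple[Any, Any]]) -> Tuple[List[Tuple[Any, Any]], Any]:
    """Run infer over successive (input, lengths) chunks, threading states between calls."""
    states = None  # first invocation: no preceding state
    outputs = []
    for chunk, lengths in chunks:
        output, out_lengths, states = infer(chunk, lengths, states)
        outputs.append((output, out_lengths))
    return outputs, states


def toy_infer(chunk, lengths, states):
    """Hypothetical stand-in for conv_emformer.infer; its "state" counts calls."""
    calls = 0 if states is None else states[0][0]
    return [x * 2 for x in chunk], lengths, [[calls + 1]]


outs, final_states = stream(toy_infer, [([1, 2], 2), ([3], 1)])
print(outs, final_states)
```

Keeping the returned states outside the model leaves the module itself stateless, so several independent streams can share one model instance by each keeping their own states.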