torch.bmm
- torch.bmm(input, mat2, out_dtype=None, *, out=None) → Tensor
Performs a batch matrix-matrix product of matrices stored in input and mat2.
input and mat2 must be 3-D tensors each containing the same number of matrices.
If input is a (b × n × m) tensor and mat2 is a (b × m × p) tensor, out will be a (b × n × p) tensor, with out_i = input_i @ mat2_i for each batch index i.
This operator supports TensorFloat32.
On certain ROCm devices, when using float16 inputs this module will use different precision for backward.
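As a rough illustration of the batch semantics, the sketch below checks that one torch.bmm call gives the same result as multiplying each pair of matrices separately and stacking the results. The sizes b, n, m, p and the tolerance are arbitrary choices for this example, not values prescribed by the library.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

# Illustrative sizes: b matrices of shape (n, m) times b matrices of shape (m, p).
b, n, m, p = 10, 3, 4, 5
input = torch.randn(b, n, m)
mat2 = torch.randn(b, m, p)

# Batched product in a single call.
res = torch.bmm(input, mat2)

# Reference: multiply each pair of matrices on its own, then stack along dim 0.
expected = torch.stack([input[i] @ mat2[i] for i in range(b)])

shapes_match = tuple(res.shape) == (b, n, p)
values_match = torch.allclose(res, expected, atol=1e-5)
print(shapes_match, values_match)
```

The per-matrix loop is only a reference for comparison; it is not meant to describe how the operator is implemented internally.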
Note
This function does not broadcast. For broadcasting matrix products, see
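torch.matmul().
- Parameters
input (Tensor) – the first batch of matrices to be multiplied
mat2 (Tensor) – the second batch of matrices to be multiplied
out_dtype (dtype, optional) – the dtype of the output tensor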
- Keyword Arguments
out (Tensor, optional) – the output tensor.
Example:
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
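The statement that this function does not broadcast can be seen in a small sketch like the following, which assumes a single 2-D matrix is to be applied to every matrix in a batch. torch.matmul() accepts the mismatched ranks, while torch.bmm rejects them unless the second operand is first given an explicit batch dimension. The variable names and the unsqueeze/expand workaround are illustrative choices.

```python
import torch

torch.manual_seed(0)

input = torch.randn(10, 3, 4)  # batch of 10 matrices
mat2 = torch.randn(4, 5)       # one 2-D matrix, no batch dimension

# torch.matmul broadcasts the 2-D matrix across the batch.
broadcast_res = torch.matmul(input, mat2)

# torch.bmm requires both arguments to be 3-D, so the same call is rejected.
try:
    torch.bmm(input, mat2)
    bmm_failed = False
except RuntimeError:
    bmm_failed = True

# To use bmm anyway, add a batch dimension and expand it to the batch size.
expanded = mat2.unsqueeze(0).expand(input.size(0), -1, -1)
bmm_res = torch.bmm(input, expanded)

print(bmm_failed, tuple(broadcast_res.shape), tuple(bmm_res.shape))
```

When the operands already share the same batch size, torch.bmm and torch.matmul() should agree; the difference only shows up when one side would need broadcasting.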