torch.Tensor.scatter_reduce_
Tensor.scatter_reduce_(dim, index, src, reduce, *, include_self=True) → Tensor
Reduces all values from the src tensor to the indices specified in the index tensor in the self tensor, using the reduction selected by the reduce argument ("sum", "prod", "mean", "amax", "amin"). Each value in src is reduced into the position of self given by its own index in src for every dimension other than dim, and by the corresponding value in index for dimension dim. If include_self=True, the values already in self are included in the reduction. Positions of self that no element of index refers to are left unchanged in either case.

self, index and src should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.

For a 3-D tensor with reduce="sum" and include_self=True, the output is given as:

    self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
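The indexing rule above can be made concrete with a small pure-Python reference sketch. It assumes nested Python lists stand in for tensors and implements only reduce="sum"; the helper names (scatter_sum, _shape, _get, _set) and the dest parameter (which plays the role of self) are hypothetical and not part of PyTorch. It also checks the shape requirements stated above.

```python
from itertools import product
from typing import Any, List, Tuple


def _shape(t: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list (follows the first element at each level)."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return tuple(shape)


def _get(t: Any, pos: Tuple[int, ...]) -> Any:
    """Read the element at multi-index pos."""
    for p in pos:
        t = t[p]
    return t


def _set(t: Any, pos: Tuple[int, ...], value: Any) -> None:
    """Write value at multi-index pos, in place."""
    for p in pos[:-1]:
        t = t[p]
    t[pos[-1]] = value


def scatter_sum(dest: List, dim: int, index: List, src: List,
                include_self: bool = True) -> List:
    """Hypothetical reference for dest.scatter_reduce_(dim, index, src, "sum")."""
    d_shape, i_shape, s_shape = _shape(dest), _shape(index), _shape(src)
    if not (len(d_shape) == len(i_shape) == len(s_shape)):
        raise ValueError("self, index and src must have the same number of dimensions")
    for d, n in enumerate(i_shape):
        # index.size(d) <= src.size(d) for every d, and <= self.size(d) for d != dim.
        if n > s_shape[d] or (d != dim and n > d_shape[d]):
            raise ValueError(f"index.size({d}) = {n} is too large")

    def target(pos: Tuple[int, ...]) -> Tuple[int, ...]:
        # Same coordinates as in src, except along dim, where index supplies them.
        t = list(pos)
        t[dim] = _get(index, pos)
        return tuple(t)

    # Only positions covered by index are used; any extra part of src is ignored.
    positions = list(product(*(range(n) for n in i_shape)))
    if not include_self:
        # Slots that receive at least one src value start from 0, the sum identity.
        for pos in positions:
            _set(dest, target(pos), 0)
    for pos in positions:
        tgt = target(pos)
        _set(dest, tgt, _get(dest, tgt) + _get(src, pos))
    return dest  # mutated in place, mirroring the trailing-underscore method


src = [[1, 2, 3], [4, 5, 6]]
print(scatter_sum([[0, 0, 0], [0, 0, 0]], 0, [[0, 1, 0], [1, 0, 1]], src))
# [[1, 5, 3], [4, 2, 6]]
print(scatter_sum([[0, 0, 0], [0, 0, 0]], 1, [[0, 0, 2], [1, 1, 1]], src))
# [[3, 0, 3], [0, 15, 0]]
```

Each loop iteration is one instance of the += line for the chosen dim: with dim=1, for example, both src[0][0] and src[0][1] land in dest[0][0] because index[0][0] == index[0][1] == 0.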
Note
This operation may behave nondeterministically when given tensors on a CUDA device. See Reproducibility for more information.
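One plausible source of such run-to-run differences (an assumption made here for illustration, not something this note states) is that when several src values map to the same index, a parallel device may combine them in a different order on each run, and floating-point addition is not associative. The pure-Python snippet below only illustrates that last point; it does not use CUDA or PyTorch, and sum_in_order is a hypothetical helper.

```python
from itertools import permutations

# Three src values that all scatter to the same slot of self.
values = [0.1, 0.2, 0.3]


def sum_in_order(order):
    """Accumulate left to right, as one hypothetical execution order might."""
    total = 0.0
    for v in order:
        total += v
    return total


# Every ordering is a valid schedule, but the rounded results are not all equal.
results = {sum_in_order(p) for p in permutations(values)}
print(sorted(results))
```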
Note
The backward pass is implemented only for src.shape == index.shape.
Warning
This function is in beta and may change in the near future.
Parameters
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to scatter and reduce
src (Tensor) – the source elements to scatter and reduce
reduce (str) – the reduction operation to apply for non-unique indices ("sum", "prod", "mean", "amax", "amin")
include_self (bool) – whether elements from the self tensor are included in the reduction
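Example (these calls use the out-of-place Tensor.scatter_reduce, which returns a new tensor and leaves input unchanged; scatter_reduce_ performs the same reduction in place):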
>>> src = torch.tensor([1., 2., 3., 4., 5., 6.])
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
>>> input = torch.tensor([1., 2., 3., 4.])
>>> input.scatter_reduce(0, index, src, reduce="sum")
tensor([5., 14., 8., 4.])
>>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False)
tensor([4., 12., 5., 4.])
>>> input2 = torch.tensor([5., 4., 3., 2.])
>>> input2.scatter_reduce(0, index, src, reduce="amax")
tensor([5., 6., 5., 2.])
>>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False)
tensor([3., 6., 5., 2.])
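The examples above can be reproduced, and the remaining reduce modes illustrated, with a pure-Python reference sketch for the 1-D, dim=0 case. Like input.scatter_reduce in the examples, it is out-of-place; the name scatter_reduce_1d and its dest parameter are hypothetical. For "mean" it assumes, as we understand PyTorch's behavior, that with include_self=True the original self value is averaged as one extra element; the sketch handles floats only.

```python
import math
from typing import Dict, List, Sequence


def scatter_reduce_1d(dest: Sequence[float], index: Sequence[int],
                      src: Sequence[float], reduce: str,
                      include_self: bool = True) -> List[float]:
    """Hypothetical out-of-place reference for 1-D scatter_reduce with dim=0."""
    out = list(dest)
    # Group the src values by the slot they are scattered to (non-unique indices).
    groups: Dict[int, List[float]] = {}
    for i, slot in enumerate(index):
        groups.setdefault(slot, []).append(src[i])
    for slot, vals in groups.items():
        if include_self:
            # The value already in self joins the group as one more element.
            vals = [dest[slot]] + vals
        if reduce == "sum":
            out[slot] = sum(vals)
        elif reduce == "prod":
            out[slot] = math.prod(vals)
        elif reduce == "mean":
            out[slot] = sum(vals) / len(vals)
        elif reduce == "amax":
            out[slot] = max(vals)
        elif reduce == "amin":
            out[slot] = min(vals)
        else:
            raise ValueError(f"unknown reduce: {reduce!r}")
    # Slots that no index refers to keep their original value.
    return out


print(scatter_reduce_1d([1., 2., 3., 4.], [0, 1, 0, 1, 2, 1],
                        [1., 2., 3., 4., 5., 6.], "mean", include_self=False))
# [2.0, 4.0, 5.0, 4.0]
```

Grouping the values per slot first makes the include_self rule explicit: the original value either joins the group or it does not, and untouched slots (index 3 above) are never modified.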