torch.Tensor.scatter_
Tensor.scatter_(dim, index, src, reduce=None) → Tensor

Writes all values from the tensor `src` into `self` at the indices specified in the `index` tensor. For each value in `src`, its output index is specified by its index in `src` for `dimension != dim` and by the corresponding value in `index` for `dimension = dim`.

For a 3-D tensor, `self` is updated as:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
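To make the update rule concrete, here is a minimal pure-Python sketch that applies it to nested lists instead of tensors. The names `scatter_reference`, `_shape`, `_get` and `_set` are hypothetical helpers, not part of PyTorch; the sketch assumes rectangular nested lists and in-range indices. It walks over the positions of `index` rather than `src`, so any extra elements of `src` are simply ignored, as in the second call of the Example section.

```python
from itertools import product


def _shape(t):
    """Return the shape of a rectangular nested list, e.g. [[1, 2, 3]] -> [1, 3]."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return shape


def _get(t, idx):
    """Read the element of nested list `t` at multi-index `idx`."""
    for i in idx:
        t = t[i]
    return t


def _set(t, idx, value):
    """Write `value` into nested list `t` at multi-index `idx`."""
    for i in idx[:-1]:
        t = t[i]
    t[idx[-1]] = value


def scatter_reference(self, dim, index, src):
    """Hypothetical model of self.scatter_(dim, index, src) on nested lists.

    Mutates and returns `self`. For every position `pos` of `index`, the target
    position equals `pos` except along `dim`, where it is index[pos].
    """
    for pos in product(*(range(n) for n in _shape(index))):
        target = list(pos)
        target[dim] = _get(index, pos)  # replace coordinate `dim` by the index value
        _set(self, target, _get(src, pos))
    return self


# Reproduce the first two calls from the Example section.
src = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
out0 = scatter_reference([[0] * 5 for _ in range(3)], 0, [[0, 1, 2, 0]], src)
out1 = scatter_reference([[0] * 5 for _ in range(3)], 1, [[0, 1, 2], [0, 1, 4]], src)
print(out0)  # [[1, 0, 0, 4, 0], [0, 2, 0, 0, 0], [0, 0, 3, 0, 0]]
print(out1)  # [[1, 2, 3, 0, 0], [6, 7, 0, 0, 8], [0, 0, 0, 0, 0]]
```

The same function handles the 3-D case unchanged; only the number of coordinates in `pos` grows.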
This is the reverse of the operation described in `gather()`.

`self`, `index` and `src` (if it is a Tensor) should all have the same number of dimensions. It is also required that `index.size(d) <= src.size(d)` for all dimensions `d`, and that `index.size(d) <= self.size(d)` for all dimensions `d != dim`. Note that `index` and `src` do not broadcast.

Moreover, as for `gather()`, the values of `index` must be between `0` and `self.size(dim) - 1` inclusive.

Warning
When indices are not unique, the behavior is non-deterministic (one of the values from `src` will be picked arbitrarily) and the gradient will be incorrect (it will be propagated to all locations in the source that correspond to the same index)!

Note
The backward pass is implemented only for `src.shape == index.shape`.

The method additionally accepts an optional `reduce` argument that specifies a reduction operation to apply when writing the values of `src` into `self` at the indices given by `index`. For each value in `src`, the reduction is applied at the position in `self` given by its index in `src` for `dimension != dim` and by the corresponding value in `index` for `dimension = dim`.

Given a 3-D tensor and reduction using the multiplication operation, `self` is updated as:

    self[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2
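The same nested-list model extends to the `reduce` argument: instead of overwriting the target, it combines the old value with the incoming one. The sketch below is hypothetical (`scatter_reduce_reference` is not a PyTorch function); it assumes `src` is either a nested list or a Python scalar, as in the `1.23` calls of the Example section, and supports only the two reductions this page mentions, `'multiply'` and `'add'`.

```python
import operator
from itertools import product

# The two reductions named on this page, mapped to the binary operation each applies.
_REDUCTIONS = {"multiply": operator.mul, "add": operator.add}


def _shape(t):
    """Return the shape of a rectangular nested list."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return shape


def _get(t, idx):
    """Read the element of nested list `t` at multi-index `idx`."""
    for i in idx:
        t = t[i]
    return t


def scatter_reduce_reference(self, dim, index, src, reduce):
    """Hypothetical model of self.scatter_(dim, index, src, reduce=...) on nested lists.

    `src` may be a nested list or a Python scalar (used at every position, as in
    scatter_(1, index, 1.23, reduce='multiply')). Mutates and returns `self`.
    """
    op = _REDUCTIONS[reduce]
    for pos in product(*(range(n) for n in _shape(index))):
        target = list(pos)
        target[dim] = _get(index, pos)  # same target rule as without `reduce`
        value = _get(src, pos) if isinstance(src, list) else src
        parent = _get(self, target[:-1])  # innermost list holding the target
        parent[target[-1]] = op(parent[target[-1]], value)  # combine instead of overwrite
    return self


# The two `reduce` calls from the Example section, on 2 x 4 lists filled with 2.0.
mul = scatter_reduce_reference([[2.0] * 4 for _ in range(2)], 1, [[2], [3]], 1.23, "multiply")
add = scatter_reduce_reference([[2.0] * 4 for _ in range(2)], 1, [[2], [3]], 1.23, "add")
print(mul)
print(add)
```

Because the sketch combines rather than overwrites, repeated indices with `reduce='add'` accumulate: `scatter_reduce_reference([[0, 0]], 1, [[0, 0]], [[1, 2]], "add")` gives `[[3, 0]]`. That is consistent with the page's statement that reducing with addition is the same as using `scatter_add_()`, though this is only a model, not PyTorch's implementation.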
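Reducing with the addition operation is the same as using `scatter_add_()`.

Parameters

- dim (int): the axis along which to index
- index (LongTensor): the indices of elements to scatter
- src (Tensor or float): the source element(s) to scatter
- reduce (str, optional): reduction operation to apply, either `'add'` or `'multiply'`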
Example:

    >>> src = torch.arange(1, 11).reshape((2, 5))
    >>> src
    tensor([[ 1,  2,  3,  4,  5],
            [ 6,  7,  8,  9, 10]])
    >>> index = torch.tensor([[0, 1, 2, 0]])
    >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
    tensor([[1, 0, 0, 4, 0],
            [0, 2, 0, 0, 0],
            [0, 0, 3, 0, 0]])
    >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
    >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
    tensor([[1, 2, 3, 0, 0],
            [6, 7, 0, 0, 8],
            [0, 0, 0, 0, 0]])

    >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
    ...            1.23, reduce='multiply')
    tensor([[2.0000, 2.0000, 2.4600, 2.0000],
            [2.0000, 2.0000, 2.0000, 2.4600]])
    >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
    ...            1.23, reduce='add')
    tensor([[2.0000, 2.0000, 3.2300, 2.0000],
            [2.0000, 2.0000, 2.0000, 3.2300]])
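The requirements stated earlier on this page (same number of dimensions, `index.size(d) <= src.size(d)` for every `d`, `index.size(d) <= self.size(d)` for `d != dim`, and index values in `[0, self.size(dim) - 1]`) can be checked up front. The helper below, `check_scatter_args`, is a hypothetical sketch that works on plain shape tuples and a nested-list `index`; it does not call PyTorch and does not reproduce PyTorch's own error messages. It also reports target positions that more than one element of `index` maps to, which is the situation the warning about non-unique indices describes.

```python
from collections import Counter
from itertools import product


def _shape(t):
    """Return the shape of a rectangular nested list, e.g. [[0, 1]] -> (1, 2)."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return tuple(shape)


def _get(t, idx):
    """Read the element of nested list `t` at multi-index `idx`."""
    for i in idx:
        t = t[i]
    return t


def check_scatter_args(self_shape, dim, index, src_shape=None):
    """Hypothetical pre-flight check for self.scatter_(dim, index, src).

    `self_shape` and `src_shape` are shape tuples (src_shape=None for a scalar src);
    `index` is a nested list of ints. Returns (errors, duplicates), where
    `duplicates` lists target positions in `self` written by more than one index.
    """
    index_shape = _shape(index)
    errors = []
    shapes = [self_shape, index_shape] + ([src_shape] if src_shape is not None else [])
    if len({len(s) for s in shapes}) != 1:
        return ["self, index and src must have the same number of dimensions"], []
    for d, n in enumerate(index_shape):
        if src_shape is not None and n > src_shape[d]:
            errors.append(f"index.size({d}) > src.size({d})")
        if d != dim and n > self_shape[d]:
            errors.append(f"index.size({d}) > self.size({d})")
    if errors:
        return errors, []  # skip the element scan when shapes are already invalid

    hits = Counter()
    for pos in product(*(range(n) for n in index_shape)):
        value = _get(index, pos)
        if not 0 <= value <= self_shape[dim] - 1:
            errors.append(f"index value {value} at {pos} is out of range")
            continue
        target = list(pos)
        target[dim] = value
        hits[tuple(target)] += 1  # count how often each position of `self` is written
    duplicates = sorted(t for t, count in hits.items() if count > 1)
    return errors, duplicates


# First call from the Example section: valid, no repeated targets.
print(check_scatter_args((3, 5), 0, [[0, 1, 2, 0]], (2, 5)))  # ([], [])
# Two index entries pointing at self[0][0]: the non-deterministic case.
print(check_scatter_args((3, 5), 1, [[0, 0]], (2, 5)))  # ([], [(0, 0)])
```

With real tensors, one could pass `tuple(t.shape)` for the shapes and `index.tolist()` for `index`.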