torch.argmax
torch.argmax(input) → LongTensor
Returns the index of the maximum value of all elements in the input tensor.
This is the second value returned by torch.max(). See its documentation for the exact semantics of this method.
Note
If there are multiple maximal values, then the index of the first maximal value is returned.
- Parameters
input (Tensor) – the input tensor.
Example:
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  0.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a)
tensor(0)
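The single-argument form returns one index into the flattened tensor, which is easy to misread as a row or column number. The sketch below is a minimal illustration, assuming a small hand-written tensor (the values are made up for the example and are not from the documentation) and a CPU tensor. It converts the flat index back to row and column coordinates with plain integer arithmetic and shows the tie-breaking behaviour described in the note.

```python
import torch

# Hypothetical 2x3 tensor for illustration; the maximum 9.0 appears twice.
a = torch.tensor([[1.0, 9.0, 3.0],
                  [4.0, 9.0, 2.0]])

# With no dim argument, argmax indexes the flattened tensor [1, 9, 3, 4, 9, 2].
flat_idx = torch.argmax(a)

# Recover (row, col) from the flat index using the number of columns.
row, col = divmod(flat_idx.item(), a.shape[1])

print(flat_idx)   # tensor(1): the first of the two tied maxima
print(row, col)   # 0 1
```

The `divmod` step assumes a contiguous 2-D layout in row-major order. For tensors with more dimensions, the same idea applies one dimension at a time.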
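torch.argmax(input, dim, keepdim=False) → LongTensor
Returns the indices of the maximum values of a tensor across a dimension.
This is the second value returned by torch.max(). See its documentation for the exact semantics of this method.
- Parameters
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce.
keepdim (bool) – whether the output tensor has dim retained or not.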
Example:
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  0.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a, dim=1)
tensor([0, 2, 0, 1])
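To make the relationship with torch.max() and the effect of keepdim more concrete, here is a minimal sketch. It assumes the tensor from the example above, entered by hand so the result is reproducible. The variable names are chosen for illustration only.

```python
import torch

# Values copied from the example above (typed in by hand, not re-sampled).
a = torch.tensor([[ 1.3398,  0.2663, -0.2686,  0.2450],
                  [-0.7401, -0.8805, -0.3402, -1.1936],
                  [ 0.4907, -1.3948, -1.0691, -0.3132],
                  [-1.6092,  0.5419, -0.2993,  0.3195]])

# Reduce over dim=1: one column index per row.
idx = torch.argmax(a, dim=1)                       # shape (4,)

# keepdim=True keeps the reduced dimension with size 1.
idx_keep = torch.argmax(a, dim=1, keepdim=True)    # shape (4, 1)

# torch.max with a dim returns (values, indices); indices match argmax.
values, indices = torch.max(a, dim=1)

# The keepdim form has the right shape to feed gather and recover the maxima.
row_max = a.gather(1, idx_keep).squeeze(1)

print(idx)                            # tensor([0, 2, 0, 1])
print(idx_keep.shape)                 # torch.Size([4, 1])
print(torch.equal(idx, indices))      # True
print(torch.equal(row_max, values))   # True
```

Using keepdim=True is mostly a convenience here: it avoids an extra unsqueeze before passing the indices to gather.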