torch.argmax
- torch.argmax(input) → LongTensor
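Returns the index of the maximum value of all elements in the input tensor. The input is treated as if it were flattened, so for a multi-dimensional tensor the result is a single index into the flattened tensor.

This is the index counterpart of torch.max(input), which returns only the maximum value. See the torch.max() documentation for the exact semantics of this method.

Note
If there are multiple maximal values, the index of the first maximal value is returned.

Parameters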
- input (Tensor) – the input tensor. 
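To make the flattened-index and tie-breaking rules concrete, here is a plain-Python sketch that needs no PyTorch install. The helpers flatten and argmax_all are hypothetical names introduced only for illustration; they mimic the documented behaviour for a 2-D input given as nested lists and are not how PyTorch implements torch.argmax (NaN handling, for one, is not modelled).

```python
# Plain-Python sketch of the semantics of torch.argmax(input) with no dim.
# Hypothetical helper names; this is NOT PyTorch's implementation.

def flatten(rows):
    """Flatten a nested list in row-major (C) order, like input.flatten()."""
    return [x for row in rows for x in row]


def argmax_all(rows):
    """Return the flat index of the first maximal element."""
    flat = flatten(rows)
    best = 0
    for i, x in enumerate(flat):
        if x > flat[best]:  # strict '>' keeps the first maximum on ties
            best = i
    return best


a = [[1.0, 5.0, 2.0],
     [5.0, 0.0, 3.0]]
idx = argmax_all(a)                # 1: the 5.0 at flat index 1 wins over the one at flat index 3
row, col = divmod(idx, len(a[0]))  # recover (row, column) from the flat index
print(idx, (row, col))             # 1 (0, 1)
```

With a real 2-D tensor t, divmod(torch.argmax(t).item(), t.shape[1]) recovers the same (row, column) pair from the flat index.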
Example:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  0.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a)
tensor(0)

- torch.argmax(input, dim, keepdim=False) → LongTensor
Returns the indices of the maximum values of a tensor across a dimension.

This is the second value returned by torch.max(input, dim). See its documentation for the exact semantics of this method.

Parameters
- input (Tensor) – the input tensor. 
- dim (int) – the dimension to reduce.
- keepdim (bool) – whether the output tensor has dim retained or not. Default: False.
Example:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  0.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a, dim=1)
tensor([0, 2, 0, 1])
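The per-dimension overload can be sketched the same way. The function argmax_2d below is a hypothetical, plain-Python helper that handles only 2-D inputs given as nested lists; it reduces across dim, keeps the first maximum on ties, and emulates keepdim=True by leaving a size-1 dimension in place. It illustrates the semantics only and is not PyTorch's implementation.

```python
# Plain-Python sketch of torch.argmax(input, dim, keepdim=False) for a 2-D input
# given as a list of rows. Hypothetical helper; NOT PyTorch's implementation.

def argmax_2d(rows, dim, keepdim=False):
    def first_max(values):
        # index of the first maximal value (ties resolve to the lowest index)
        best = 0
        for i, x in enumerate(values):
            if x > values[best]:
                best = i
        return best

    if dim in (1, -1):
        # reduce across columns: one index per row
        out = [first_max(row) for row in rows]
        return [[i] for i in out] if keepdim else out  # shape (n, 1) vs (n,)
    if dim in (0, -2):
        # reduce across rows: one index per column
        out = [first_max(list(col)) for col in zip(*rows)]
        return [out] if keepdim else out               # shape (1, m) vs (m,)
    raise ValueError("dim must be 0, 1, -2 or -1 for a 2-D input")


a = [[ 1.3398,  0.2663, -0.2686,  0.2450],
     [-0.7401, -0.8805, -0.3402, -1.1936],
     [ 0.4907, -1.3948, -1.0691, -0.3132],
     [-1.6092,  0.5419, -0.2993,  0.3195]]
print(argmax_2d(a, dim=1))                # [0, 2, 0, 1], matching the example above
print(argmax_2d(a, dim=1, keepdim=True))  # [[0], [2], [0], [1]]
print(argmax_2d(a, dim=0))                # [0, 3, 0, 3]
```

Reducing with dim=1 yields one index per row, while dim=0 yields one index per column; keepdim=True changes only the shape of the result, not the indices themselves.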