torch.max

First version

torch.max(input) → Tensor

Returns the maximum value of all elements in the input tensor.

>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763,  0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
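This form ignores the tensor's shape and reduces over every element. As a rough mental model, here is a pure-Python sketch (not PyTorch's implementation). Nested lists stand in for a tensor, and the hypothetical helpers flatten and max_all play the roles of input.flatten() and torch.max(input). PyTorch returns a 0-dim tensor such as tensor(0.7445). The sketch returns a plain float and does not model NaN handling.

```python
def flatten(x):
    """Yield every scalar in a nested list in row-major order (like tensor.flatten())."""
    if isinstance(x, list):
        for item in x:
            yield from flatten(item)
    else:
        yield x


def max_all(x):
    """Hypothetical stand-in for torch.max(input): the maximum over all elements."""
    return max(flatten(x))


# Values copied from the printed 1x3 tensor `a` above.
a = [[0.6763, 0.7445, -2.2369]]
print(max_all(a))  # 0.7445
```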

Second version

torch.max(input, dim, keepdim=False, *, out=None)
  • Returns a namedtuple (values, indices), where
    • values is the maximum value of each row of the input tensor in the given dimension dim.
    • indices is the index location of each maximum value found (argmax).
  • If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.
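The keepdim rule only concerns shapes, so it can be sketched without any tensors. The function below is a hypothetical helper (not part of PyTorch) that computes the shape of values and indices for a given input shape. It assumes negative dims count from the end, as they do elsewhere in PyTorch, and it ignores the 0-dim input edge case.

```python
def max_output_shape(shape, dim, keepdim=False):
    """Hypothetical helper: shape of values/indices from torch.max(input, dim, keepdim)."""
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for {ndim}-d input")
    dim %= ndim  # normalize negative dims, e.g. -1 -> ndim - 1
    if keepdim:
        # The reduced dimension is kept, with size 1.
        return tuple(1 if i == dim else s for i, s in enumerate(shape))
    # The reduced dimension is squeezed away: one fewer dimension than the input.
    return tuple(s for i, s in enumerate(shape) if i != dim)


print(max_output_shape((4, 4), 1))                # (4,)
print(max_output_shape((4, 4), 1, keepdim=True))  # (4, 1)
print(max_output_shape((2, 3, 5), -1))            # (2, 3)
```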

If there are multiple maximal values in a reduced row, then the index of the first maximal value is returned.
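The tie-breaking rule for a single reduced row can be illustrated in pure Python. row_max is a hypothetical name, and the sketch only mirrors the documented behavior, not PyTorch's kernel. Because list.index returns the position of the first match, ties go to the lowest index.

```python
def row_max(row):
    """Return (value, index) of the maximum, preferring the first occurrence on ties."""
    value = max(row)
    return value, row.index(value)  # list.index finds the first matching position


print(row_max([2.0, 7.0, 7.0, 1.0]))  # (7.0, 1)
```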

Parameters

  • input (Tensor) – the input tensor.

  • dim (int) – the dimension to reduce.

  • keepdim (bool) – whether the output tensor has dim retained or not. Default: False.

>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
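To check the example by hand, the sketch below re-implements the dim form in pure Python for a 2-D nested list, using the values printed above. MaxResult and max_2d are hypothetical names that stand in for torch.return_types.max and torch.max(input, dim, keepdim). The real function returns tensors (indices as int64) rather than lists, and it works for any number of dimensions.

```python
from collections import namedtuple

# Hypothetical stand-in for torch.return_types.max: a (values, indices) pair.
MaxResult = namedtuple("MaxResult", ["values", "indices"])


def max_2d(rows, dim, keepdim=False):
    """Pure-Python sketch of torch.max(input, dim, keepdim) for a 2-D nested list."""
    if dim not in (0, 1, -1, -2):
        raise IndexError(f"dim {dim} out of range for 2-d input")
    dim %= 2  # -1 -> 1, -2 -> 0
    # Reducing dim 1 scans each row; reducing dim 0 scans each column.
    lines = rows if dim == 1 else [list(col) for col in zip(*rows)]
    values, indices = [], []
    for line in lines:
        best = max(line)
        values.append(best)
        indices.append(line.index(best))  # the first maximal value wins ties
    if keepdim:
        if dim == 1:  # result shape (n_rows, 1)
            values = [[v] for v in values]
            indices = [[i] for i in indices]
        else:  # result shape (1, n_cols)
            values, indices = [values], [indices]
    return MaxResult(values, indices)


# Values copied from the printed 4x4 tensor `a` above.
a = [[-1.2360, -0.2942, -0.1222,  0.8475],
     [ 1.1949, -1.1127, -2.2379, -0.6702],
     [ 1.5717, -0.9207,  0.1297, -1.8768],
     [-0.6172,  1.0036, -0.6060, -0.2432]]

print(max_2d(a, 1))
# MaxResult(values=[0.8475, 1.1949, 1.5717, 1.0036], indices=[3, 0, 0, 1])
```

Reducing dim=0 scans columns instead, so max_2d(a, 0) gives values [1.5717, 1.0036, 0.1297, 0.8475] with indices [2, 3, 2, 0]. With keepdim=True, the row-wise result keeps a size-1 dimension, e.g. values [[0.8475], [1.1949], [1.5717], [1.0036]].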

  

   

posted on 2022-09-20 21:00  朴素贝叶斯