torch.nonzero

torch.nonzero(input, *, out=None, 
as_tuple=False) → LongTensor or tuple of LongTensors
  • 输入是一维张量,返回一个包含输入 input 中非零元素索引的张量,输出张量中的每行包含 input 中非零元素的索引,输出是二维张量torch.size(z,1), z 是输入张量 input 中所有非零元素的个数。
  • 输入是n维张量,如果输入 input 有 n 维,则输出的索引张量的size为torch.size(z,n) , 这里 z 是输入张量 input 中所有非零元素的个数。

无论输入是几维,输出张量都是两维,每行代表输入张量中非零元素的索引位置(在所有维度上面的位置)。 返回 input中非零元素的索引下标,n维input 中的元素的索引有n 个维度的索引下标

>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]))
tensor([[ 0],
        [ 1],
        [ 2],
        [ 4]])
>>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]))
tensor([[ 0,  0],
        [ 1,  1],
        [ 2,  2],
        [ 3,  3]])
>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
(tensor([0, 1, 2, 4]),)
>>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]), as_tuple=True)
(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
>>> torch.nonzero(torch.tensor(5), as_tuple=True)
(tensor([0]),)

   

posted on 2022-09-20 21:17  朴素贝叶斯  阅读(375)  评论(0编辑  收藏  举报

导航