Pytorch中的高级选择函数

  参考资料:

  https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter02_prerequisite/2.2_tensor?id=_222-%e6%93%8d%e4%bd%9c

  https://pytorch.org/docs/stable/index.html

  深度学习里,很多时候我们只想取输出中的一部分值,此时便用上了Pytorch中的高级索引函数。我们见过最多的可能就是torch.gather这个函数了。这个随笔讲解一下Pytorch中的高级选择函数。

  一、torch.index_select

torch.index_select(input, dim, index, *, out=None)-> Tensor
"""
:param input(Tensor) - the input tensor
:param dim(int) - the dimension in which we index
:param index(IntTensor or LongTensor) - the 1-D tensor containing the indices to index
:output out(Tensor,optional) - the output tensor
"""

  

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

  从这个官方示例可以看出,torch.index是针对某一个维度进行选择的。在示例代码中选择了dim=0(行)。然后使用一个可变长度的indices,选取indices对应的行数。

  二、torch.masked_select

  这个函数就更为简单粗暴了,从名字就可以看出来它是使用一个蒙版来选择Tensor中的值。也容易想到这个mask tensor需要和input tensor的形状保持一致。但是需要注意的是这个函数的输出是一个一维的Tensor,保存了从原始的Tensor中选择出来的所有值。

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
        [-1.2035,  1.2252,  0.5002,  0.6248],
        [ 0.1307, -2.0608,  0.1244,  2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
        [False, True, True, True],
        [False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252,  0.5002,  0.6248,  2.0139])

  小知识点是这里使用了一个ge函数。该函数会逐元素比较array中的值和给定值的大小。然后返回布尔类型的tensor。总之要想用好masked_select,和各种能判断并生成布尔类型tensor的函数搭配起来才是正道。

  三、torch.gather

  这个函数有点绕,建议直接去看官方的文档,说明的比较清楚。

  https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather

  

 

   我对于这个函数的理解是,首先有一个src_tensor和一个index_tensor,index_tensor和src_tensro有相同的维度数,但是在每一个维度上,应该有d(index_tensor) <= d(src_tensor)。

  首先看最简单的情况即index_tensor和src_tensor有完全相同的形状。按照文档上的说明,白话式的解释就是:对于输出矩阵的每一个位置,数据源还是来自src_tensor。只是在指定的那一个维度上,只取响应index的值。拿这个范例来说,[0, 0 ]这个位置,我们的数据源还是src_tensor,只不过数据源变广了,变成了src_tensor[0,:]即第一行中所有列的元素。然后再看index_tensor,这个位置是0。于是这个位置就被赋成饿了src_tensor[0, 0]。

  关于index_tensor比src_tensor小的情况。我最开始的理解就是使用了广播机制首先变换到一样的形状上,但是这是错误的,下面做一个和官方示例相似的小实验测试一下。

>>>t = torch.tensor([[1, 2], [3, 4]])
>>>torch.gather(t, 1, torch.tensor([[0, 0]]))
tensor([[1, 1]])

  我们发现输出的结果其实是和index_tensor的形状一致的。也就是舍弃掉index_tensor没有覆盖的位置。

posted @ 2021-10-22 14:53  思念殇千寻  阅读(298)  评论(0编辑  收藏  举报