Pytorch中的高级选择函数
参考资料:
https://pytorch.org/docs/stable/index.html
深度学习里,很多时候我们只想取输出中的一部分值,此时便用上了Pytorch中的高级索引函数。我们见过最多的可能就是torch.gather这个函数了。这个随笔讲解一下Pytorch中的高级选择函数。
一、torch.index_select
torch.index_select(input, dim, index, *, out=None)-> Tensor """ :param input(Tensor) - the input tensor :param dim(int) - the dimension in which we index :param index(IntTensor or LongTensor) - the 1-D tensor containing the indices to index :output out(Tensor,optional) - the output tensor """
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> indices = torch.tensor([0, 2]) >>> torch.index_select(x, 0, indices) tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> torch.index_select(x, 1, indices) tensor([[ 0.1427, -0.5414], [-0.4664, -0.1228], [-1.1734, 0.7230]])
从这个官方示例可以看出,torch.index是针对某一个维度进行选择的。在示例代码中选择了dim=0(行)。然后使用一个可变长度的indices,选取indices对应的行数。
二、torch.masked_select
这个函数就更为简单粗暴了,从名字就可以看出来它是使用一个蒙版来选择Tensor中的值。也容易想到这个mask tensor需要和input tensor的形状保持一致。但是需要注意的是这个函数的输出是一个一维的Tensor,保存了从原始的Tensor中选择出来的所有值。
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], [-1.2035, 1.2252, 0.5002, 0.6248], [ 0.1307, -2.0608, 0.1244, 2.0139]]) >>> mask = x.ge(0.5) >>> mask tensor([[False, False, False, False], [False, True, True, True], [False, False, False, True]]) >>> torch.masked_select(x, mask) tensor([ 1.2252, 0.5002, 0.6248, 2.0139])
小知识点是这里使用了一个ge函数。该函数会逐元素比较array中的值和给定值的大小。然后返回布尔类型的tensor。总之要想用好masked_select,和各种能判断并生成布尔类型tensor的函数搭配起来才是正道。
三、torch.gather
这个函数有点绕,建议直接去看官方的文档,说明的比较清楚。
https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
我对于这个函数的理解是,首先有一个src_tensor和一个index_tensor,index_tensor和src_tensro有相同的维度数,但是在每一个维度上,应该有d(index_tensor) <= d(src_tensor)。
首先看最简单的情况即index_tensor和src_tensor有完全相同的形状。按照文档上的说明,白话式的解释就是:对于输出矩阵的每一个位置,数据源还是来自src_tensor。只是在指定的那一个维度上,只取响应index的值。拿这个范例来说,[0, 0 ]这个位置,我们的数据源还是src_tensor,只不过数据源变广了,变成了src_tensor[0,:]即第一行中所有列的元素。然后再看index_tensor,这个位置是0。于是这个位置就被赋成饿了src_tensor[0, 0]。
关于index_tensor比src_tensor小的情况。我最开始的理解就是使用了广播机制首先变换到一样的形状上,但是这是错误的,下面做一个和官方示例相似的小实验测试一下。
>>>t = torch.tensor([[1, 2], [3, 4]]) >>>torch.gather(t, 1, torch.tensor([[0, 0]])) tensor([[1, 1]])
我们发现输出的结果其实是和index_tensor的形状一致的。也就是舍弃掉index_tensor没有覆盖的位置。