`torch.gather`理解
函数定义
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
沿着dim指定的轴聚集tensor的值。返回的是原数据的复制,修改返回值不会修改原tensor。
参数:
- input: 原tensor
- dim: 待索引的轴
- index: 待聚集元素的索引
直观一点说就是:获取tensor中指定dim和指定index的数据,index可以不连续,dim只能指定为单个轴。
图解
可参照知乎 - 图解PyTorch中的torch.gather函数
代码实战
问题:torch.gather
和中括号索引有啥区别吗?换言之,有啥功能torch.gather
可以实现而中括号索引做不到的吗?
看几个例子(注意下文语言描述中index也是从0开始)
- 功能1:一维数组顺序索引,x是长度为8的tensor,取出其中2,3,4,5元素
import torch
x = torch.rand(8)
# 中括号索引
x1 = x[2:6]
# torch.gather
idx = torch.arange(2, 6)
x2 = x.gather(dim=0, index=idx)
print(x1.equal(x2))
print(x2)
# >> True
# >> tensor([0.2986, 0.9610, 0.5088, 0.5334])
- 功能2:一维数组乱序索引,x是长度为8的tensor,取出其中第3,2,1,4元素
import torch
x = torch.rand(8)
# 中括号索引
x1 = x[[3, 2, 1, 4]]
# torch.gather
idx = torch.tensor([3, 2, 1, 4])
x2 = x.gather(dim=0, index=idx)
print(x1.equal(x2))
print(x2)
# >> True
# >> tensor([0.2344, 0.3249, 0.6847, 0.0074])
- 功能3:二维数组,取出左上角矩阵块 (32, 32) -> (16, 16)
import torch
img = torch.rand(32, 32)
# 中括号索引
x1 = img[:16, :16]
# torch.gather
# torch.gather cannot do that, because it only gather from single axis
- 功能4:二维数组,取出其中一维数组的前5个元素
import torch
x = torch.rand(10, 8) # 理解为10个长度为8的一维数组
# 中括号索引
x1 = x[:, :5]
# torch.gather
idx = torch.arange(5)
x2 = x.gather(dim=1, index=idx.repeat(10, 1))
print(x1.equal(x2))
print(x2)
# >> True
# >> tensor([[0.6874, 0.0678, 0.9632, 0.1192, 0.6583],
[0.4384, 0.2263, 0.7262, 0.1914, 0.5774],
[0.1143, 0.4723, 0.2176, 0.6535, 0.3592],
[0.6786, 0.9794, 0.3704, 0.2499, 0.3386],
[0.2688, 0.0812, 0.1744, 0.7484, 0.4401],
[0.1044, 0.1304, 0.1224, 0.7055, 0.8579],
[0.5830, 0.8599, 0.2381, 0.0195, 0.0563],
[0.9367, 0.5019, 0.7067, 0.4395, 0.5474],
[0.6782, 0.0398, 0.1375, 0.7691, 0.2615],
[0.6938, 0.3334, 0.8047, 0.6111, 0.0039]])
- 功能5:二维数组shape=(3, 6),取出其中每个一维数组2个元素,数组1对应位置0、1,数组2对应位置2、3,数组3对应位置4、5
import torch
x = torch.rand(3, 6) # 理解为10个长度为8的一维数组
# 中括号索引???
# torch.gather
idx = torch.tensor([[0, 1], [2, 3], [4, 5]])
# idx = torch.arange(6).reshape(2, 3)
x2 = x.gather(dim=1, index=idx)
print(x)
print(idx)
print(x2)
# >> tensor([[0.9728, 0.8356, 0.0183, 0.7821, 0.8426, 0.1422],
# >> [0.3964, 0.4667, 0.3980, 0.3452, 0.3055, 0.8527],
# >> [0.2162, 0.5601, 0.4261, 0.1134, 0.0281, 0.4682]])
# >> tensor([[0, 1],
# >> [2, 3],
# >> [4, 5]])
# >> tensor([[0.9728, 0.8356],
# >> [0.3980, 0.3452],
# >> [0.0281, 0.4682]])
总结
torch.gather
仅能在一个维度上索引,对于一批数组,可以检索不同位置的元素;中括号索引可以在多个维度上操作,但对于一批数组,只能获取相同位置的元素。