`torch.gather`理解

official link

函数定义

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

沿着dim指定的轴聚集tensor的值。返回的是原数据的复制,修改返回值不会修改原tensor。
参数:

  • input: 原tensor
  • dim: 待索引的轴
  • index: 待聚集元素的索引

直观一点说就是:获取tensor中指定dim和指定index的数据,index可以不连续,dim只能指定为单个轴。

图解

可参照知乎 - 图解PyTorch中的torch.gather函数

代码实战

问题:torch.gather和中括号索引有啥区别吗?换言之,有啥功能torch.gather可以实现而中括号索引做不到的吗?
看几个例子(注意下文语言描述中index也是从0开始)

  • 功能1:一维数组顺序索引,x是长度为8的tensor,取出其中2,3,4,5元素
import torch

x = torch.rand(8)

# 中括号索引
x1 = x[2:6]

# torch.gather
idx = torch.arange(2, 6)
x2 = x.gather(dim=0, index=idx)

print(x1.equal(x2))
print(x2)

# >> True
# >> tensor([0.2986, 0.9610, 0.5088, 0.5334])
  • 功能2:一维数组乱序索引,x是长度为8的tensor,取出其中第3,2,1,4元素
import torch

x = torch.rand(8)

# 中括号索引
x1 = x[[3, 2, 1, 4]]

# torch.gather
idx = torch.tensor([3, 2, 1, 4])
x2 = x.gather(dim=0, index=idx)

print(x1.equal(x2))
print(x2)

# >> True
# >> tensor([0.2344, 0.3249, 0.6847, 0.0074])
  • 功能3:二维数组,取出左上角矩阵块 (32, 32) -> (16, 16)
import torch

img = torch.rand(32, 32)

# 中括号索引
x1 = img[:16, :16]

# torch.gather
# torch.gather cannot do that, because it only gather from single axis

  • 功能4:二维数组,取出其中一维数组的前5个元素
import torch

x = torch.rand(10, 8)  # 理解为10个长度为8的一维数组

# 中括号索引
x1 = x[:, :5]

# torch.gather
idx = torch.arange(5)
x2 = x.gather(dim=1, index=idx.repeat(10, 1))

print(x1.equal(x2))
print(x2)

# >> True
# >> tensor([[0.6874, 0.0678, 0.9632, 0.1192, 0.6583],
        [0.4384, 0.2263, 0.7262, 0.1914, 0.5774],
        [0.1143, 0.4723, 0.2176, 0.6535, 0.3592],
        [0.6786, 0.9794, 0.3704, 0.2499, 0.3386],
        [0.2688, 0.0812, 0.1744, 0.7484, 0.4401],
        [0.1044, 0.1304, 0.1224, 0.7055, 0.8579],
        [0.5830, 0.8599, 0.2381, 0.0195, 0.0563],
        [0.9367, 0.5019, 0.7067, 0.4395, 0.5474],
        [0.6782, 0.0398, 0.1375, 0.7691, 0.2615],
        [0.6938, 0.3334, 0.8047, 0.6111, 0.0039]])
  • 功能5:二维数组shape=(3, 6),取出其中每个一维数组2个元素,数组1对应位置0、1,数组2对应位置2、3,数组3对应位置4、5
import torch

x = torch.rand(3, 6)  # 理解为10个长度为8的一维数组

# 中括号索引???

# torch.gather

idx = torch.tensor([[0, 1], [2, 3], [4, 5]])
# idx = torch.arange(6).reshape(2, 3)

x2 = x.gather(dim=1, index=idx)

print(x)
print(idx)
print(x2)

# >> tensor([[0.9728, 0.8356, 0.0183, 0.7821, 0.8426, 0.1422],
# >>         [0.3964, 0.4667, 0.3980, 0.3452, 0.3055, 0.8527],
# >>         [0.2162, 0.5601, 0.4261, 0.1134, 0.0281, 0.4682]])
# >> tensor([[0, 1],
# >>         [2, 3],
# >>         [4, 5]])
# >> tensor([[0.9728, 0.8356],
# >>         [0.3980, 0.3452],
# >>         [0.0281, 0.4682]])

总结

torch.gather仅能在一个维度上索引,对于一批数组,可以检索不同位置的元素;中括号索引可以在多个维度上操作,但对于一批数组,只能获取相同位置的元素。

posted @ 2022-03-23 13:56  Js2Hou  阅读(385)  评论(0编辑  收藏  举报