The PyTorch gather function
Reposted from: https://www.zhihu.com/question/562282138/answer/2947708508?utm_id=0
Official documentation:
https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather
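The definition of torch.gather() is very concise:
Along the specified dim, fetch the values of the original tensor at the positions given by index. From this core definition it is easy to see that the basic idea of gather() is to pick values out of the full data by index, much like indexing into a list: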
lst = [1, 2, 3, 4, 5]
value = lst[2] # value = 3
value = lst[2:4] # value = [3, 4]
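The examples above take either a single value or a contiguous, ordered run of values.
With the batched tensors common in deep learning, we may need to select several values in arbitrary order. gather() is a good tool for this: it pulls out the data at any specified, possibly unordered, indices of a batched tensor. Its purpose is therefore:
To fetch data at specified indices from a batched tensor conveniently, where the indices are fully customizable and may be unordered.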
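As a minimal sketch of that idea, the list example can be repeated with a 1-D tensor. The names data and picked are illustrative and not from the original post; the index deliberately contains unordered and repeated positions.

```python
import torch

data = torch.tensor([1, 2, 3, 4, 5])
# Unordered and repeated positions; gather expects an int64 index tensor,
# which is the default dtype for integer literals.
index = torch.tensor([4, 0, 2, 2])
picked = torch.gather(data, 0, index)
print(picked)  # tensor([5, 1, 3, 3])
```

For a 1-D tensor this gives the same result as data[index]; the difference only shows up with two or more dimensions, which the experiments below cover.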
Experiments
ex0: a row-vector index, replacing the row index (dim=0):
import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]) #[1, 3]
tensor_1 = tensor_0.gather(0, index)
print("====>> tensor0")
print(tensor_0)
print("====>> tensor1")
print(tensor_1)
Output:
====>> tensor0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
====>> tensor1
tensor([[9, 7, 5]])
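Process: each position (0, j) of index has its row coordinate replaced by index[0][j], so the output reads tensor_0[2][0] = 9, tensor_0[1][1] = 7 and tensor_0[0][2] = 5.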
ex1: a row-vector index, replacing the column index (dim=1):
import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]) #[1, 3]
tensor_2 = tensor_0.gather(1, index)
print("====>> tensor0")
print(tensor_0)
print("====>> tensor2")
print(tensor_2)
Output:
====>> tensor0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
====>> tensor2
tensor([[5, 4, 3]])
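To make the difference between dim=0 and dim=1 concrete, here is a rough loop-based sketch of what gather does for a 2-D input. gather_2d_reference is a hypothetical helper written for illustration only; it is not a PyTorch API and is far slower than the real function.

```python
import torch

def gather_2d_reference(inp, dim, index):
    """Hypothetical helper (not a PyTorch API): 2-D gather with explicit loops."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            k = int(index[i, j])
            # Replace only the coordinate on `dim` with the value stored in index.
            out[i, j] = inp[k, j] if dim == 0 else inp[i, k]
    return out

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])
ref_dim0 = gather_2d_reference(tensor_0, 0, index)
ref_dim1 = gather_2d_reference(tensor_0, 1, index)
print(ref_dim0)  # tensor([[9, 7, 5]])
print(ref_dim1)  # tensor([[5, 4, 3]])
```

Both results match the outputs of ex0 and ex1: the same index produces different values only because a different coordinate is replaced.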
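ex2: a column-vector index, replacing the column index (dim=1):
import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]).t() #[3, 1]
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
Output: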
tensor([[5],
[7],
[9]])
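ex3: a 2-D matrix index, replacing the column index (dim=1):
import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[0, 2],
                      [1, 2]]) #[2, 2]
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
Output: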
tensor([[3, 5],
[7, 8]])
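One way to cross-check ex3 is to build the same lookup with ordinary advanced indexing. This is only a sketch for the 2-D, dim=1 case; the names rows, via_indexing and via_gather are illustrative and not from the original post.

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[0, 2],
                      [1, 2]])
# The row coordinate is the element's own row in index;
# shape [2, 1] broadcasts against index of shape [2, 2].
rows = torch.arange(index.size(0)).unsqueeze(1)
via_indexing = tensor_0[rows, index]
via_gather = tensor_0.gather(1, index)
print(via_indexing)
print(torch.equal(via_indexing, via_gather))  # True
```

Note that index has only 2 rows while tensor_0 has 3: on the dimensions other than dim, index may be smaller than the input, and only the covered rows are read.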
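Key points
To summarize how torch.gather() is used:
The output has the same shape as the input index.
How the index into input is built: for each element of index, take that element's own position and replace only the coordinate on dim with the element's value; the result is a position in input.
Read input at that position to obtain the corresponding output value.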
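A common pattern that follows directly from these points is picking one value per row, for example the score of each sample's target class. The tensors scores and targets below are made-up illustration data, not from the original post.

```python
import torch

scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2],
                       [0.2, 0.2, 0.6]])    # [3, 3]
targets = torch.tensor([1, 0, 2])           # [3], hypothetical class labels
index = targets.unsqueeze(1)                # [3, 1], one column to pick per row
picked_col = scores.gather(1, index)        # [3, 1], same shape as index
picked = picked_col.squeeze(1)              # [3]
print(picked_col.shape == index.shape)      # True
print(picked)
```

Because the output shape always follows index, the unsqueeze before the call and the squeeze after it are what turn a per-row label vector into a per-row lookup.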
Another application example, from the MAE code:
The two argsort calls used there can be demonstrated as follows:
import torch
noise = torch.rand(3, 5)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
print(noise)
print(ids_shuffle)
print(ids_restore)
Output:
tensor([[0.8787, 0.3496, 0.4642, 0.1852, 0.2965],
[0.0701, 0.1533, 0.1716, 0.1579, 0.5323],
[0.0827, 0.5038, 0.4169, 0.1121, 0.9830]])
tensor([[3, 4, 1, 2, 0],
[0, 1, 3, 2, 4],
[0, 3, 2, 1, 4]])
tensor([[4, 2, 3, 0, 1],
[0, 1, 3, 2, 4],
[0, 3, 2, 1, 4]])
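The point of the second argsort is that ids_restore is the inverse permutation of ids_shuffle. A small sketch, assuming a simple integer tensor x introduced here for illustration, shows that gathering with ids_shuffle and then with ids_restore returns the original order.

```python
import torch

x = torch.arange(15).view(3, 5)                     # illustrative data, [3, 5]
noise = torch.rand(3, 5)
ids_shuffle = torch.argsort(noise, dim=1)           # permutation per row
ids_restore = torch.argsort(ids_shuffle, dim=1)     # inverse permutation per row

shuffled = torch.gather(x, 1, ids_shuffle)          # apply the shuffle
restored = torch.gather(shuffled, 1, ids_restore)   # undo it
print(torch.equal(restored, x))  # True
```

This holds for any noise values, because each row of ids_shuffle is a permutation of 0..4 and argsort of a permutation yields its inverse.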
Usage of gather in MAE
import torch
D = 8
x = torch.randint(0, 20, (3, 5, D)) #[3, 5, 8]
noise = torch.randint(0, 20, (3, 5)) #[3, 5]
ids_shuffle = torch.argsort(noise, dim=1) #[3, 5]
ids_restore = torch.argsort(ids_shuffle, dim=1) #[3, 5]
len_keep = 2
ids_keep = ids_shuffle[:, :len_keep] #[3, 2]
index = ids_keep.unsqueeze(-1).repeat(1, 1, D) #[3, 2, 8]
x_masked = torch.gather(x, dim=1, index=index) #[3, 2, 8]
print("====>>> x")
print(x)
print("====>>> noise")
print(noise)
print("====>>> ids_shuffle")
print(ids_shuffle)
print("====>>> ids_keep.unsqueeze(-1)")
print(ids_keep.unsqueeze(-1))
print("====>>> ids_keep")
print(ids_keep)
print("====>>> index")
print(index)
print("====>>> x_masked")
print(x_masked)
Output:
====>>> x
tensor([[[13, 6, 7, 15, 1, 9, 7, 17],
[15, 15, 11, 15, 17, 4, 6, 15],
[10, 18, 5, 6, 18, 10, 19, 2],
[11, 19, 19, 11, 10, 11, 7, 11],
[18, 15, 17, 5, 7, 5, 9, 5]],
[[ 4, 12, 5, 7, 12, 15, 14, 6],
[15, 12, 13, 14, 8, 5, 15, 11],
[12, 17, 12, 11, 2, 9, 8, 1],
[18, 9, 6, 12, 19, 17, 10, 3],
[11, 4, 9, 18, 1, 17, 0, 10]],
[[18, 5, 11, 18, 19, 6, 0, 19],
[19, 15, 12, 9, 18, 3, 18, 1],
[15, 3, 17, 15, 3, 16, 0, 6],
[ 1, 4, 12, 10, 4, 10, 10, 4],
[18, 13, 3, 16, 1, 2, 15, 17]]])
====>>> noise
tensor([[ 8, 16, 16, 4, 17],
[ 0, 13, 4, 19, 17],
[14, 17, 1, 9, 4]])
====>>> ids_shuffle
tensor([[3, 0, 1, 2, 4],
[0, 2, 1, 4, 3],
[2, 4, 3, 0, 1]])
====>>> ids_keep.unsqueeze(-1)
tensor([[[3],
[0]],
[[0],
[2]],
[[2],
[4]]])
====>>> ids_keep
tensor([[3, 0],
[0, 2],
[2, 4]])
====>>> index
tensor([[[3, 3, 3, 3, 3, 3, 3, 3],
[0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 0, 0, 0],
[2, 2, 2, 2, 2, 2, 2, 2]],
[[2, 2, 2, 2, 2, 2, 2, 2],
[4, 4, 4, 4, 4, 4, 4, 4]]])
====>>> x_masked
tensor([[[11, 19, 19, 11, 10, 11, 7, 11],
[13, 6, 7, 15, 1, 9, 7, 17]],
[[ 4, 12, 5, 7, 12, 15, 14, 6],
[12, 17, 12, 11, 2, 9, 8, 1]],
[[15, 3, 17, 15, 3, 16, 0, 6],
[18, 13, 3, 16, 1, 2, 15, 17]]])
Analysis:
x = torch.randint(0, 20, (3, 5, D)) #[3, 5, 8]
ids_keep = ids_shuffle[:, :len_keep] #[3, 2]
index = ids_keep.unsqueeze(-1).repeat(1, 1, D) #[3, 2, 8]
x_masked = torch.gather(x, dim=1, index=index) #[3, 2, 8]
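To pick the corresponding features from x (shape [3, 5, 8]), ids_keep keeps only the first 2 entries of ids_shuffle, i.e. the positions of the 2 smallest noise values in each sample, so ids_keep has shape [3, 2]. The lookup works as follows:
The positions of the elements of index (shape [3, 2, 8]) are: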
(0,0,0),(0,0,1),(0,0,2),(0,0,3),(0,0,4),(0,0,5),(0,0,6),(0,0,7)
(0,1,0),(0,1,1),(0,1,2),(0,1,3),(0,1,4),(0,1,5),(0,1,6),(0,1,7)
(1,0,0),(1,0,1),(1,0,2),(1,0,3),(1,0,4),(1,0,5),(1,0,6),(1,0,7)
(1,1,0),(1,1,1),(1,1,2),(1,1,3),(1,1,4),(1,1,5),(1,1,6),(1,1,7)
(2,0,0),(2,0,1),(2,0,2),(2,0,3),(2,0,4),(2,0,5),(2,0,6),(2,0,7)
(2,1,0),(2,1,1),(2,1,2),(2,1,3),(2,1,4),(2,1,5),(2,1,6),(2,1,7)
====>>> index
tensor([[[3, 3, 3, 3, 3, 3, 3, 3],
[0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 0, 0, 0],
[2, 2, 2, 2, 2, 2, 2, 2]],
[[2, 2, 2, 2, 2, 2, 2, 2],
[4, 4, 4, 4, 4, 4, 4, 4]]])
So the dim=1 coordinate of each position above is replaced by the corresponding value in index:
(0,0,0),(0,0,1),(0,0,2),(0,0,3),(0,0,4),(0,0,5),(0,0,6),(0,0,7) --->>>> (0,3,0),(0,3,1),(0,3,2),(0,3,3),(0,3,4),(0,3,5),(0,3,6),(0,3,7)
(0,1,0),(0,1,1),(0,1,2),(0,1,3),(0,1,4),(0,1,5),(0,1,6),(0,1,7) --->>>>(0,0,0),(0,0,1),(0,0,2),(0,0,3),(0,0,4),(0,0,5),(0,0,6),(0,0,7)
The remaining 4 rows follow the same pattern.
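The coordinate replacement described above can be checked with explicit loops. The sketch below is for illustration only and makes two assumptions that differ from the original snippet: it uses torch.rand for noise to avoid ties, and it builds index with expand, which gives a broadcast view with the same values as repeat without copying. The names expected and via_indexing are illustrative.

```python
import torch

D = 8
x = torch.randint(0, 20, (3, 5, D))                  # [3, 5, 8]
noise = torch.rand(3, 5)                             # [3, 5]
ids_keep = torch.argsort(noise, dim=1)[:, :2]        # [3, 2]
# expand returns a view with the same values that repeat would copy.
index = ids_keep.unsqueeze(-1).expand(-1, -1, D)     # [3, 2, 8]
x_masked = torch.gather(x, dim=1, index=index)       # [3, 2, 8]

# Position (b, j, d) in index becomes position (b, index[b, j, d], d) in x.
expected = torch.empty_like(x_masked)
for b in range(index.size(0)):
    for j in range(index.size(1)):
        for d in range(index.size(2)):
            expected[b, j, d] = x[b, int(index[b, j, d]), d]
print(torch.equal(expected, x_masked))  # True

# The same selection written with advanced indexing over the batch dimension.
via_indexing = x[torch.arange(3).unsqueeze(1), ids_keep]
print(torch.equal(via_indexing, x_masked))  # True
```

Since every entry along the last dimension of index holds the same value, each selected row of x is copied whole, which is why the advanced-indexing form gives an identical result.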