pytorch gather函数

转载于:https://www.zhihu.com/question/562282138/answer/2947708508?utm_id=0
官方文档链接:
https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather

torch.gather()的定义非常简洁:
在指定dim上,从原tensor中获取指定index的数据。看到这个核心定义,我们很容易想到gather()的基本想法就是从完整数据中按索引取值,比如下面从列表中按索引取值:

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子。
对于深度学习常用的批量tensor数据,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的工具,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下:
方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的。

实验

ex0 输入行向量index,并替换行索引(dim=0):

import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]) #[1, 3]

tensor_1 = tensor_0.gather(0, index)
print("====>> tensor0")
print(tensor_0)
print("====>> tensor1")
print(tensor_1)

#输出如下:
====>> tensor0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
====>> tensor1
tensor([[9, 7, 5]])

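过程:index的shape为[1, 3],其3个元素的位置坐标依次是(0,0),(0,1),(0,2)。因为dim=0,只把每个坐标的第0维替换为index中对应的值2,1,0,得到(2,0),(1,1),(0,2),即取出tensor_0[2][0]=9,tensor_0[1][1]=7,tensor_0[0][2]=5。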

ex1 输入行向量index,并替换列索引(dim=1)

import torch
tensor_0 = torch.arange(3, 12).view(3, 3) #[3, 3]
index = torch.tensor([[2, 1, 0]]) #[1, 3]
tensor_2 = tensor_0.gather(1, index)

print("====>> tensor0")
print(tensor_0)
print("====>> tensor2")
print(tensor_2)

输出:
====>> tensor0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
====>> tensor2
tensor([[5, 4, 3]])

ex2 输入列向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]]).t() #[3, 1]
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

tensor([[5],
        [7],
        [9]])

ex3 输入二维矩阵index,并替换列索引(dim=1)

index = torch.tensor([[0, 2], 
                      [1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

tensor([[3, 5],
        [7, 8]])

要点

归纳出torch.gather()的使用要点

输出value的shape等于输入index的shape

索引input时,其索引构成过程:对于index中的每个元素,取它在index中的位置坐标,只把dim这一维上的坐标替换为该元素的值,其余维度的坐标保持不变,就得到了在input中取值的坐标。例如对二维tensor取dim=1时,output[i][j] = input[i][index[i][j]]。

用得到的input的索引,对input进行索引得到输出value

其他应用示例, 在mae的代码中,

https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_mae.py#L123

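上面链接的代码中连续使用了两次argsort:第一次对noise做argsort,得到按noise升序排列的索引ids_shuffle;第二次对ids_shuffle再做argsort,得到把打乱后的顺序还原回原始顺序的索引ids_restore。示例如下: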

import torch

noise = torch.rand(3, 5)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)

print(noise)
print(ids_shuffle)
print(ids_restore)
#输出如下:
tensor([[0.8787, 0.3496, 0.4642, 0.1852, 0.2965],
        [0.0701, 0.1533, 0.1716, 0.1579, 0.5323],
        [0.0827, 0.5038, 0.4169, 0.1121, 0.9830]])
tensor([[3, 4, 1, 2, 0],
        [0, 1, 3, 2, 4],
        [0, 3, 2, 1, 4]])
tensor([[4, 2, 3, 0, 1],
        [0, 1, 3, 2, 4],
        [0, 3, 2, 1, 4]])

gather在mae中的用法

import torch

D = 8
x = torch.randint(0, 20, (3, 5, D)) #[3, 5, 8]

noise = torch.randint(0, 20, (3, 5)) #[3, 5]
ids_shuffle = torch.argsort(noise, dim=1) #[3, 5]
ids_restore = torch.argsort(ids_shuffle, dim=1) #[3, 5]

len_keep = 2
ids_keep = ids_shuffle[:, :len_keep] #[3, 2]

index = ids_keep.unsqueeze(-1).repeat(1, 1, D) #[3, 2, 8]
x_masked = torch.gather(x, dim=1, index=index) #[3, 2, 8]

print("====>>> x")
print(x)

print("====>>> noise")
print(noise)

print("====>>> ids_shuffle")
print(ids_shuffle)

print("====>>> ids_keep.unsqueeze(-1)")
print(ids_keep.unsqueeze(-1))

print("====>>> ids_keep")
print(ids_keep)

print("====>>> index")
print(index)

print("====>>> x_masked")
print(x_masked)

输出如下:

====>>> x
tensor([[[13,  6,  7, 15,  1,  9,  7, 17],
         [15, 15, 11, 15, 17,  4,  6, 15],
         [10, 18,  5,  6, 18, 10, 19,  2],
         [11, 19, 19, 11, 10, 11,  7, 11],
         [18, 15, 17,  5,  7,  5,  9,  5]],

        [[ 4, 12,  5,  7, 12, 15, 14,  6],
         [15, 12, 13, 14,  8,  5, 15, 11],
         [12, 17, 12, 11,  2,  9,  8,  1],
         [18,  9,  6, 12, 19, 17, 10,  3],
         [11,  4,  9, 18,  1, 17,  0, 10]],

        [[18,  5, 11, 18, 19,  6,  0, 19],
         [19, 15, 12,  9, 18,  3, 18,  1],
         [15,  3, 17, 15,  3, 16,  0,  6],
         [ 1,  4, 12, 10,  4, 10, 10,  4],
         [18, 13,  3, 16,  1,  2, 15, 17]]])
====>>> noise
tensor([[ 8, 16, 16,  4, 17],
        [ 0, 13,  4, 19, 17],
        [14, 17,  1,  9,  4]])
====>>> ids_shuffle
tensor([[3, 0, 1, 2, 4],
        [0, 2, 1, 4, 3],
        [2, 4, 3, 0, 1]])
====>>> ids_keep.unsqueeze(-1)
tensor([[[3],
         [0]],

        [[0],
         [2]],

        [[2],
         [4]]])
====>>> ids_keep
tensor([[3, 0],
        [0, 2],
        [2, 4]])
====>>> index
tensor([[[3, 3, 3, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2, 2, 2, 2]],

        [[2, 2, 2, 2, 2, 2, 2, 2],
         [4, 4, 4, 4, 4, 4, 4, 4]]])
====>>> x_masked
tensor([[[11, 19, 19, 11, 10, 11,  7, 11],
         [13,  6,  7, 15,  1,  9,  7, 17]],

        [[ 4, 12,  5,  7, 12, 15, 14,  6],
         [12, 17, 12, 11,  2,  9,  8,  1]],

        [[15,  3, 17, 15,  3, 16,  0,  6],
         [18, 13,  3, 16,  1,  2, 15, 17]]])

解析:
x = torch.randint(0, 20, (3, 5, D)) #[3, 5, 8]
ids_keep = ids_shuffle[:, :len_keep] #[3, 2]
index = ids_keep.unsqueeze(-1).repeat(1, 1, D) #[3, 2, 8]
x_masked = torch.gather(x, dim=1, index=index) #[3, 2, 8]
为了从x(shape为[3, 5, 8])中取出对应的特征,ids_keep只保留每行noise中最小的2个值所在的位置(len_keep=2),即ids_keep(shape为[3, 2])存放的是noise最小的2个位置的索引。取值过程:
index的shape为[3, 2, 8],其每个元素的位置坐标如下:
(0,0,0),(0,0,1),(0,0,2),(0,0,3),(0,0,4),(0,0,5),(0,0,6),(0,0,7)
(0,1,0),(0,1,1),(0,1,2),(0,1,3),(0,1,4),(0,1,5),(0,1,6),(0,1,7)
(1,0,0),(1,0,1),(1,0,2),(1,0,3),(1,0,4),(1,0,5),(1,0,6),(1,0,7)
(1,1,0),(1,1,1),(1,1,2),(1,1,3),(1,1,4),(1,1,5),(1,1,6),(1,1,7)
(2,0,0),(2,0,1),(2,0,2),(2,0,3),(2,0,4),(2,0,5),(2,0,6),(2,0,7)
(2,1,0),(2,1,1),(2,1,2),(2,1,3),(2,1,4),(2,1,5),(2,1,6),(2,1,7)

====>>> index
tensor([[[3, 3, 3, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2, 2, 2, 2]],

        [[2, 2, 2, 2, 2, 2, 2, 2],
         [4, 4, 4, 4, 4, 4, 4, 4]]])

所以把上面每个坐标的dim=1这一维替换为index中对应位置的值:
(0,0,0),(0,0,1),(0,0,2),(0,0,3),(0,0,4),(0,0,5),(0,0,6),(0,0,7) --->>>> (0,3,0),(0,3,1),(0,3,2),(0,3,3),(0,3,4),(0,3,5),(0,3,6),(0,3,7)
(0,1,0),(0,1,1),(0,1,2),(0,1,3),(0,1,4),(0,1,5),(0,1,6),(0,1,7) --->>>>(0,0,0),(0,0,1),(0,0,2),(0,0,3),(0,0,4),(0,0,5),(0,0,6),(0,0,7)
其余4行以此类推。替换后的坐标就是在x中取值的位置,例如x_masked[0][0]取的是x[0][3],x_masked[0][1]取的是x[0][0]。

posted @ 2024-02-04 21:20  无左无右  阅读(321)  评论(0编辑  收藏  举报