pytorch框架学习之gather()方法的详解
给一个张量
索引 | index0 | index1 | index2 | index3 |
---|---|---|---|---|
index0 | 0 | 1 | 2 | 3 |
index1 | 4 | 5 | 6 | 7 |
index2 | 8 | 9 | 10 | 11 |
index3 | 12 | 13 | 14 | 15 |
torch.gather(dim, index) → Tensor
首先dim表示维度,如果dim=0,且是二维矩阵,则是固定列的顺序为0123~n,而行号需要通过输入的矩阵的值来确定。
这里的index输入的是一个矩阵,gather获得的矩阵形状和index传入的矩阵是一致的。
通过例子来解释说明:
针对dim=0,就是行号变动。
>>> import torch as t
>>> a = t.arange(0,16).view(4,4)
>>> a
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
# 选取反对角线上的元素,注意与上面的不同
>>> index = t.LongTensor([[3,2,1,0]])
>>> a.gather(0,index)
tensor([[12, 9, 6, 3]])
概念如下固定住1维就是列号的顺序递增不动,而行号则由index矩阵相应位置给出,按照index = tensor([[3, 2, 1, 0]])
顺序作用在行上索引依次为3,2,1,0。
a[3][0] = 12 a[2][1] = 9 a[1][2] = 6 a[0][3] = 3
针对dim=1,就是列号变动。
>>> index = t.LongTensor([[3,2,1,0]]).t()
>>> a.gather(1,index)
tensor([[ 3],
[ 6],
[ 9],
[12]])
a[0][3] = 3 a[1][2] = 6 a[2][1] = 9 a[3][0] = 12
同样改变的是列号位置。
当然除了二维的情况,大多数dim还会更高,也可以采用同样地方式去理解,但是首先需要明白,目标矩阵的形式和index所给的矩阵形式是一样的。