『PyTorch』第五弹_深入理解Tensor对象_中上:索引
一、普通索引
示例
a = t.Tensor(4,5) print(a) print(a[0:1,:2]) print(a[0,:2]) # 注意和前一种索引出来的值相同,shape不同 print(a[[1,2]]) # 容器索引
3.3845e+15 0.0000e+00 3.3846e+15 0.0000e+00 3.3845e+15 0.0000e+00 3.3845e+15 0.0000e+00 3.3418e+15 0.0000e+00 3.3845e+15 0.0000e+00 3.3846e+15 0.0000e+00 0.0000e+00 0.0000e+00 1.5035e+38 8.5479e-43 1.5134e-43 1.2612e-41 [torch.FloatTensor of size 4x5] 3.3845e+15 0.0000e+00 [torch.FloatTensor of size 1x2] 3.3845e+15 0.0000e+00 [torch.FloatTensor of size 2] 0.0000e+00 3.3845e+15 0.0000e+00 3.3418e+15 0.0000e+00 3.3845e+15 0.0000e+00 3.3846e+15 0.0000e+00 0.0000e+00 [torch.FloatTensor of size 2x5]
普通索引内存分析
普通索引后的结果和原Tensor的内存共享
print(a[a>1]) import copy b = copy.deepcopy(a) a[a>1]=10 print(a,b)
3.3845e+15 3.3846e+15 3.3845e+15 3.3845e+15 3.3418e+15 3.3845e+15 3.3846e+15 1.5035e+38 [torch.FloatTensor of size 8] 10.0000 0.0000 10.0000 0.0000 10.0000 0.0000 10.0000 0.0000 10.0000 0.0000 10.0000 0.0000 10.0000 0.0000 0.0000 0.0000 10.0000 0.0000 0.0000 0.0000 [torch.FloatTensor of size 4x5] 3.3845e+15 0.0000e+00 3.3846e+15 0.0000e+00 3.3845e+15 0.0000e+00 3.3845e+15 0.0000e+00 3.3418e+15 0.0000e+00 3.3845e+15 0.0000e+00 3.3846e+15 0.0000e+00 0.0000e+00 0.0000e+00 1.5035e+38 8.5479e-43 1.5134e-43 1.2612e-41 [torch.FloatTensor of size 4x5]
索引函数gather介绍
方的介绍:
如果input是一个n维的tensor,size为
(x0,x1…,xi−1,xi,xi+1,…,xn−1),dim为i,然后index必须也为n维tensor,size为
(x0,x1,…,xi−1,y,xi+1,…,xn−1),其中y >= 1,最后输出的out与index的size是一样的。
意思就是按照一个指定的轴(维数)收集值
对于一个三维向量来说:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
参数:
input (Tensor) – 源tensor
dim (int) – 指定的轴数(维数)
index (LongTensor) – 需要聚集起来的数据的索引
out (Tensor, optional) – 目标tensor
简单来说,就是在Tensor(input)的众多维度中针对某一维度(dim参数),使用一维Tensor(index)进行索引,并对其他维度进行遍历。
a = t.arange(16).view(4,4) index = t.LongTensor([[0,1,2,3]]) print(a) print(index) print(a.gather(0,index)) # 逆操作scatter_,注意是inplace的 b = t.zeros(4,4) b.scatter_(0,index,a.gather(0,index)) print(b)
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 [torch.FloatTensor of size 4x4] 0 1 2 3 [torch.LongTensor of size 1x4] 0 5 10 15 [torch.FloatTensor of size 1x4] 0 0 0 0 0 5 0 0 0 0 10 0 0 0 0 15 [torch.FloatTensor of size 4x4]
二、高阶索引
和普通索引不同,高阶索引前后一般不会共享内存,后面介绍Tensor内存结构时会提到。
x = t.arange(0,27).view(3,3,3) print(x) print(x[[1,2],[1,2],[2,0]]) # x[1,1,2]和x[2,2,0] print(x[[2,1,0],[0],[0]]) # x[2,0,0]和x[1,0,0]和x[0,0,0]
(0 ,.,.) = 0 1 2 3 4 5 6 7 8 (1 ,.,.) = 9 10 11 12 13 14 15 16 17 (2 ,.,.) = 18 19 20 21 22 23 24 25 26 [torch.FloatTensor of size 3x3x3] 14 24 [torch.FloatTensor of size 2] 18 9 0 [torch.FloatTensor of size 3]