小灰灰深度学习之关于三维张量的一些索引
首先要感谢CSDN中http://t.csdn.cn/XyT4e这篇文章(我接下来写的内容,也和这篇文章基本一样)
下面是我实际操作得到的结果:
我们看第一种情况的代码:
import torch b = torch.arange(1, 61).reshape(3, 4, 5) idx1 = torch.tensor([0, 0, 2]).unsqueeze(-1).repeat(1, 4) bb = b[idx1, : print(bb)
我们先来看一下张量b的内容:
''' 张量b内容: tensor([[[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], [[21, 22, 23, 24, 25], [26, 27, 28, 29, 30], [31, 32, 33, 34, 35], [36, 37, 38, 39, 40]], [[41, 42, 43, 44, 45], [46, 47, 48, 49, 50], [51, 52, 53, 54, 55], [56, 57, 58, 59, 60]]])) '''
接下来我们看一下索引得到的张量bb的内容(这个有点长):
''' 张量bb的内容为: tensor([[[[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], [[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], [[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], [[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]], [[[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], [[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], [[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], [[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]], [[[41, 42, 43, 44, 45], [46, 47, 48, 49, 50], [51, 52, 53, 54, 55], [56, 57, 58, 59, 60]], [[41, 42, 43, 44, 45], [46, 47, 48, 49, 50], [51, 52, 53, 54, 55], [56, 57, 58, 59, 60]], [[41, 42, 43, 44, 45], [46, 47, 48, 49, 50], [51, 52, 53, 54, 55], [56, 57, 58, 59, 60]], [[41, 42, 43, 44, 45], [46, 47, 48, 49, 50], [51, 52, 53, 54, 55], [56, 57, 58, 59, 60]]]]) '''
首先b是(3,4,5)的张量,然后b[idx1, :, ]这里索引就是将张量b的后面两个维度(4, 5)当作一个整体。然后根据idx1中的内容进行索引。又因为idx1的shape是(3, 4)
所以索引后的bb的shape为(3, 4, 4, 5)。然后也就得到了那个结果。
第二种情况的代码为(这里我们的张量b仍选用第一种情况的张量b):
import torch b = torch.arange(1, 61).reshape(3, 4, 5) idx1 = torch.tensor([0, 0, 2]).unsqueeze(-1).repeat(1, 4) idx2 = torch.randint(0, 4, (3, 4), dtype = torch.long) cc = b[idx1, idx2] cc.shape,cc
此时的输出结果为:
''' idx2得到的随机值为: tensor([[1, 2, 1, 3], [2, 1, 1, 2], [2, 1, 0, 3]]) cc.shape: torch.Size([3, 4, 5]) 张量cc为: tensor([[[ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10], [16, 17, 18, 19, 20]], [[11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10], [ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15]], [[51, 52, 53, 54, 55], [46, 47, 48, 49, 50], [41, 42, 43, 44, 45], [56, 57, 58, 59, 60]]]) '''
此时我们可以看到idx1与idx2他们都是(3,4)d的矩阵,所以对于b[idx1, idx2]会先将idx1与idx2组合起来然后作为索引,去索引张量b中的第三个维度[5]。然后就得到结果了。
1.如果我们要多维度的索引,我们需要保证dim(idx1) = dim(idx2)
2.多个维度一起索引时,我们先把两个维度叠加在一起,然后根据新的索引去索引