tensor索引与numpy类似,支持冒号,和数字直接索引

import torch

a = torch.Tensor(2, 3, 4)
a
# 输出:
      tensor([[[9.2755e-39, 1.0561e-38, 9.7347e-39, 1.1112e-38],
             [1.0194e-38, 8.4490e-39, 1.0102e-38, 9.0919e-39],
             [1.0102e-38, 8.9082e-39, 8.4489e-39, 1.0102e-38]],
    
            [[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
             [9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
             [1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]]])

# 冒号索引与数字索引
a[:1, :2, 1]
# 输出:
      tensor([[1.0561e-38, 8.4490e-39]])

# 通过-1索引
a[-1]
# 输出:
      tensor([[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
            [9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
            [1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]])

...(三个点)索引

用于维度过多,且取中间多个维度所有数据的情况

# 生成多维数据
a = torch.rand(1,2,3,2,4,5)
a
# 输出:
     tensor([[[[[[0.1954, 0.1918, 0.3053, 0.3649, 0.3637],
                [0.8467, 0.0205, 0.2187, 0.8438, 0.1754],
                [0.7076, 0.7047, 0.1852, 0.5374, 0.7024],
                [0.5630, 0.4526, 0.0662, 0.9463, 0.9294]],
    
               [[0.6917, 0.5505, 0.5770, 0.3819, 0.9541],
                [0.8957, 0.2530, 0.4858, 0.1866, 0.2542],
                [0.3745, 0.2125, 0.5537, 0.5642, 0.2284],
                [0.2634, 0.1147, 0.1793, 0.0277, 0.9800]]], 

              ...

              [[[0.9949, 0.2210, 0.3365, 0.0852, 0.4387],
                [0.6440, 0.6391, 0.9141, 0.2288, 0.6203],
                [0.0474, 0.7894, 0.4362, 0.9752, 0.7546],
                [0.1234, 0.0246, 0.1436, 0.0053, 0.3405]],
    
               [[0.8174, 0.9021, 0.0420, 0.2045, 0.2140],
                [0.4844, 0.6342, 0.2965, 0.9299, 0.2284],
                [0.1420, 0.1834, 0.0581, 0.8467, 0.8987],
                [0.8012, 0.1526, 0.4293, 0.3928, 0.5437]]]]]]) 

# 取第一维和最后一维的0索引数据,中间所有维度数据全部取出
a[0, ..., 0]
# 输出:
      tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
              [0.6917, 0.8957, 0.3745, 0.2634]],
    
             [[0.4374, 0.0534, 0.6809, 0.7086],
              [0.2231, 0.6680, 0.8643, 0.9057]],
    
             [[0.8169, 0.0649, 0.5923, 0.3802],
              [0.2562, 0.0095, 0.8557, 0.6828]]],
    
    
            [[[0.1514, 0.3948, 0.6452, 0.6332],
              [0.8872, 0.7304, 0.6853, 0.9814]],
    
             [[0.5736, 0.5195, 0.9711, 0.5575],
              [0.6778, 0.9334, 0.5647, 0.1006]],
    
             [[0.9949, 0.6440, 0.0474, 0.1234],
              [0.8174, 0.4844, 0.1420, 0.8012]]]])

# 上面等价于
a[0,:,:,:,:,0]
# 输出:
      tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
              [0.6917, 0.8957, 0.3745, 0.2634]],
    
             [[0.4374, 0.0534, 0.6809, 0.7086],
              [0.2231, 0.6680, 0.8643, 0.9057]],
    
             [[0.8169, 0.0649, 0.5923, 0.3802],
              [0.2562, 0.0095, 0.8557, 0.6828]]],
    
    
            [[[0.1514, 0.3948, 0.6452, 0.6332],
              [0.8872, 0.7304, 0.6853, 0.9814]],
    
             [[0.5736, 0.5195, 0.9711, 0.5575],
              [0.6778, 0.9334, 0.5647, 0.1006]],
    
             [[0.9949, 0.6440, 0.0474, 0.1234],
              [0.8174, 0.4844, 0.1420, 0.8012]]]])
可以看出,使用...可以节省操作。

masked_select

# 生成随机数据
a = torch.randn(3, 4)
a
# 输出:
    tensor([[ 0.8710,  0.8862, -0.4620, -0.9985],
            [ 0.4734, -0.7182, -0.1516,  0.0209],
            [ 0.5089, -0.8130, -0.4519, -0.6190]])

# 大于0.5的数据返回True
mask = a.ge(0.5)
mask
# 输出:
    tensor([[ True,  True, False, False],
            [False, False, False, False],
            [ True, False, False, False]])

# 通过上面生成的bool数据,利用masked_select来选择大于0.5的数据
torch.masked_select(a, mask)
# 输出:
    tensor([0.8710, 0.8862, 0.5089])  

take

a
# 输出:
      tensor([[ 0.8710,  0.8862, -0.4620, -0.9985],
            [ 0.4734, -0.7182, -0.1516,  0.0209],
            [ 0.5089, -0.8130, -0.4519, -0.6190]])

# 先将数据打平展开为一维,再选取展开后对应索引[0, 5, 8, 11]的数据
torch.take(a, torch.tensor([0, 5, 8, 11]))
# 输出:
      tensor([ 0.8710, -0.7182,  0.5089, -0.6190])
posted on 2020-05-31 16:19  jaysonteng  阅读(2672)  评论(0编辑  收藏  举报