Pytorch中的掩码:dtype=torch.uint8

在pytorch中,dtype=uint8的数据类型往往可以用作掩码0表示舍弃对应项1表示选取对应项。通过设置不同的0或1的值,对另外的tensor进行选择性选取:

例如:

t = torch.rand(42)
"""
tensor([[0.5492, 0.2083],
        [0.3635, 0.5198],
        [0.8294, 0.9869],
        [0.2987, 0.0279]])
"""

# 注意以下mask数据类型是 uint8 mask = torch.ones(4,dtype=torch.uint8) mask[2] = 0 print(mask) print(t[mask, :]) # 选取tensor t的第一个维度(由mask所在的位置决定的,这是numpy的花式索引的知识)的第0,1,3个行,以及这三个行对应的所有列;舍弃t的第2行。 """ tensor([1, 1, 0, 1], dtype=torch.uint8) # 因为是uint8类型的,所以当它被另外一个tensor当作索引时,1代表选取,0代表舍弃。uint8只能是0或者1 tensor([[0.5492, 0.2083], [0.3635, 0.5198], [0.2987, 0.0279]]) """ # 注意, 以下数据类型是long,可以和上面的uint8做一下对比 idx = torch.ones(3,dtype=torch.long) idx[1] = 0 print(idx) print(t[idx, :]) """ tensor([1, 0, 1]) # 因为是long类型的,所以当它被另外一个tensor当作索引时,1代表选取对应维度的第一个,0代表选取维度的第0个。当然也可以是其他的整数。 tensor([[0.3635, 0.5198], [0.5492, 0.2083], [0.3635, 0.5198]]) """
再举个例子:当mask掩码的形状和另一个tensor idx的形状相同,mask作为索引的时候,0或1直接相当于舍去或选取idx相应位置的值

 

 

 
posted @ 2021-12-05 20:07  cold_moon的笔记  阅读(879)  评论(0编辑  收藏  举报