Pytorch中的掩码:dtype=torch.uint8
在pytorch中,dtype=uint8的数据类型往往可以用作掩码,0表示舍弃对应项,1表示选取对应项。通过设置不同的0或1的值,对另外的tensor进行选择性选取:
例如:
t = torch.rand(4,2)
"""
tensor([[0.5492, 0.2083],
[0.3635, 0.5198],
[0.8294, 0.9869],
[0.2987, 0.0279]])
"""
# 注意以下mask数据类型是 uint8
mask = torch.ones(4,dtype=torch.uint8)
mask[2] = 0
print(mask)
print(t[mask, :]) # 选取tensor t的第一个维度(由mask所在的位置决定的,这是numpy的花式索引的知识)的第0,1,3个行,以及这三个行对应的所有列;舍弃t的第2行。
"""
tensor([1, 1, 0, 1], dtype=torch.uint8) # 因为是uint8类型的,所以当它被另外一个tensor当作索引时,1代表选取,0代表舍弃。uint8只能是0或者1
tensor([[0.5492, 0.2083],
[0.3635, 0.5198],
[0.2987, 0.0279]])
"""
# 注意, 以下数据类型是long,可以和上面的uint8做一下对比
idx = torch.ones(3,dtype=torch.long)
idx[1] = 0
print(idx)
print(t[idx, :])
"""
tensor([1, 0, 1]) # 因为是long类型的,所以当它被另外一个tensor当作索引时,1代表选取对应维度的第一个,0代表选取维度的第0个。当然也可以是其他的整数。
tensor([[0.3635, 0.5198],
[0.5492, 0.2083],
[0.3635, 0.5198]])
"""
再举个例子:当mask掩码的形状和另一个tensor idx的形状相同,mask作为索引的时候,0或1直接相当于舍去或选取idx相应位置的值