一、背景
在使用torch的时候,可以通过bool类型对数组进行检索操作。传统的list或者dict都是使用下标和关键字检索。而在torch中可以使用bool类型进行检索,它的的目标主要是以下功能:
- 替换torch中的某个值
二、使用
torch在bool检索的情况下就是将为检索位置为True的地方用另一个数据进行替换。
import torch
x = torch.Tensor([1, 2, 3, 4, 5])
# print(x)
noise_labels = torch.randint(len(x), x.shape)
print(noise_labels)
labels = x.clone()
print(labels)
probability_matrix = torch.full(x.shape, 0.15)
# print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
labels[masked_indices] = noise_labels[masked_indices] # 将True的部分进行修改
print(labels)
# output:
"""
masked_indices第四个位置为True,因此修改labels中第四个位置,由于噪声数据第四个的位置是1,因此labels中的数据为1
tensor([3, 2, 0, 1, 1])
tensor([1., 2., 3., 4., 5.])
tensor([False, False, False, False, True])
tensor([1., 2., 3., 4., 1.])
"""
import torch
x = torch.Tensor([1, 2, 3, 4, 5])
# print(x)
noise_labels = torch.randint(len(x)+1999, x.shape)
print(noise_labels)
labels = x.clone()
print(labels)
probability_matrix = torch.full(x.shape, 0.15)
# print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
labels[masked_indices] = noise_labels[masked_indices]
print(labels)
# output:
"""
tensor([1516, 408, 274, 426, 126])
tensor([1., 2., 3., 4., 5.])
tensor([False, False, False, True, False])
tensor([ 1., 2., 3., 426., 5.])
"""