PyTorch：list / tensor 转 one-hot
def test_onehot():
    """Demonstrate converting a probability tensor into a one-hot tensor.

    Prints a 4x3 probability tensor, then a ``torch.long`` tensor of the
    same shape with 1 at each row's argmax position and 0 elsewhere.
    """
    v = torch.tensor([[0.1, 0.2, 0.7],
                      [0.1, 0.6, 0.3],
                      [0.1, 0.5, 0.4],
                      [0.8, 0.1, 0.1], ])
    print('v', v.size(), v)
    # All-zero tensor with the same shape as v.
    result = torch.zeros_like(v, dtype=torch.long)
    # Target dimension: the last one, i.e. one hot position per row.
    dim = -1
    # Argmax along `dim`, with `dim` re-inserted as size 1 so the index has
    # the same rank as `result`, which scatter_ requires.
    index = v.argmax(dim).unsqueeze(dim)
    # Write 1 at each index. The source is sized from `index` rather than a
    # hard-coded row count, so it stays valid if `v` changes shape.
    result.scatter_(dim, index, torch.ones_like(index))
    print('result', result.size(), result)
输出
v torch.Size([4, 3]) tensor([[0.1000, 0.2000, 0.7000],
[0.1000, 0.6000, 0.3000],
[0.1000, 0.5000, 0.4000],
[0.8000, 0.1000, 0.1000]])
result torch.Size([4, 3]) tensor([[0, 0, 1],
[0, 1, 0],
[0, 1, 0],
[1, 0, 0]])
自用的两个方法
def list_onehot(actions: list, n: int) -> torch.Tensor:
    """Convert a list of action values to a one-hot ``torch.long`` tensor.

    Args:
        actions: action values; row ``i`` of the result encodes ``actions[i]``.
        n: total number of actions, i.e. the width of each one-hot row.

    Returns:
        Tensor of shape ``(len(actions), n)`` where ``result[i][k]`` is 1 when
        ``k == actions[i]`` and 0 otherwise, so an action outside
        ``range(n)`` yields an all-zero row. The tensor is moved to the
        default CUDA device when CUDA is available.
    """
    rows = [[int(k == action) for k in range(n)] for action in actions]
    # reshape keeps the result 2-D when `actions` is empty; torch.tensor([])
    # on its own would have shape (0,) instead of (0, n).
    result = torch.tensor(rows, dtype=torch.long).reshape(len(actions), n)
    if torch.cuda.is_available():
        result = result.cuda()
    return result
def max_onehot(props: torch.Tensor, dim=-1) -> torch.Tensor:
    """Convert an action-probability tensor to a one-hot ``torch.long`` tensor.

    Args:
        props: action probability tensor (any shape).
        dim: dimension along which the maximum is taken.

    Returns:
        ``torch.long`` tensor with the shape of ``props``, holding 1 at the
        argmax position along ``dim`` and 0 elsewhere. The tensor is on the
        default CUDA device when CUDA is available.
    """
    if torch.cuda.is_available():
        # Move the input first so result, index and source all live on the
        # same device; scatter_ rejects mixed-device arguments.
        props = props.cuda()
    result = torch.zeros_like(props, dtype=torch.long)
    # keepdim=True gives the index the same rank as `result`, as scatter_ needs.
    index = props.argmax(dim, keepdim=True)
    # Size the source from the index so any batch size and any `dim` work.
    result.scatter_(dim, index, torch.ones_like(index))
    return result