PyTorch：将 list / tensor 转换为 one-hot

def test_onehot():
    """Demonstrate turning a probability matrix into a one-hot matrix.

    For every row of ``v``, the position of the largest value is set to 1 and
    all other positions to 0, using ``Tensor.scatter_``. Both the input tensor
    and the resulting one-hot tensor are printed; nothing is returned.
    """
    v = torch.tensor([[0.1, 0.2, 0.7],
                      [0.1, 0.6, 0.3],
                      [0.1, 0.5, 0.4],
                      [0.8, 0.1, 0.1], ])

    print('v', v.size(), v)
    # All-zero integer tensor with the same shape as v.
    result = torch.zeros_like(v, dtype=torch.long)
    # Target dimension (the axis along which the max is taken).
    dim = -1
    # Index of the max entry along `dim`, with `dim` re-inserted so the index
    # has the same rank as `result`, as scatter_ requires.
    index = v.argmax(dim).unsqueeze(dim)
    # Write a 1 at each max position. The source is sized from `index` instead
    # of a hard-coded row count, so the demo keeps working if `v` changes shape.
    result.scatter_(dim, index, torch.ones_like(index))

    print('result', result.size(), result)

输出

v torch.Size([4, 3]) tensor([[0.1000, 0.2000, 0.7000],
        [0.1000, 0.6000, 0.3000],
        [0.1000, 0.5000, 0.4000],
        [0.8000, 0.1000, 0.1000]])
result torch.Size([4, 3]) tensor([[0, 0, 1],
        [0, 1, 0],
        [0, 1, 0],
        [1, 0, 0]])

自用的两个方法（动作列表转 one-hot、动作概率 tensor 转 one-hot）

def list_onehot(actions: list, n: int) -> torch.Tensor:
    """
    Convert a list of action values to a one-hot tensor.

    actions: list of action values (one row is produced per entry)
    n: total number of actions (number of columns)

    Returns a ``torch.long`` tensor of shape ``(len(actions), n)``. Entry
    ``[i, k]`` is 1 when ``actions[i] == k`` and 0 otherwise, so an action
    outside ``range(n)`` produces an all-zero row. The tensor is placed on the
    GPU when CUDA is available, otherwise on the CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Column labels 0..n-1, broadcast against one action per row.
    classes = torch.arange(n, device=device)
    # reshape(-1, 1) also handles an empty list: the comparison then yields
    # shape (0, n) rather than the shape-(0,) tensor a nested-list build gives.
    index = torch.as_tensor(actions, device=device).reshape(-1, 1)
    return (index == classes).long()

def max_onehot(props: torch.Tensor, dim=-1) -> torch.Tensor:
    """
    Convert an action-probability tensor to a one-hot tensor.

    props: action-probability tensor
    dim: target dimension along which the maximum is selected

    Returns a ``torch.long`` tensor with the same shape as ``props``. It holds
    a 1 at the argmax position along ``dim`` and 0 everywhere else. The result
    is moved to the GPU when CUDA is available.
    """
    # Argmax with `dim` kept, so the index has the same rank as `props`, as
    # scatter_ requires.
    index = props.argmax(dim, keepdim=True)
    # Build on props' own device so result, index and src all agree for
    # scatter_. The source is sized from `index`; the original used an
    # undefined `self.batchSize` here and raised NameError.
    result = torch.zeros_like(props, dtype=torch.long)
    result.scatter_(dim, index, torch.ones_like(index))
    if torch.cuda.is_available():
        result = result.cuda()
    return result
posted @ 2021-11-22 10:31  太晓  阅读(512)  评论(0)  编辑  收藏  举报