pytorch one hot 的转换

转onehot

one_hot = F.one_hot(label.long(), num_classes=n_classes)

转回来

label = torch.argmax(one_hot, -1)
posted @ 2021-01-12 10:20  mrbean  阅读(841)  评论(0编辑  收藏  举报