one_hot编码

 

src0 = torch.tensor([[ 0.], [1.], [2.], [3.], [4.], [5.], [1.], [2.], [3.], [3.], [0.], [1.], [4.]])

src = np.array(src0).squeeze()

torch.eye(6)[src,:]
 
 
 

 

posted @ 2021-09-17 01:14  呦呦南山  阅读(23)  评论(0编辑  收藏  举报