pytorch的argmax:只改变要挑选的维度

只改变要挑选的维度,其他维度不变

A=torch.tensor([[[3,4]]])
dec_X = A.argmax(dim=2)# 只在dim=2上挑选最大值,得到索引为scalar
dec_X #相当于把最内侧的[3,4]的维度去掉了,得到结果1。其他维度不变

tensor([[1]])

posted @ 2021-11-14 17:36  zae  阅读(139)  评论(0编辑  收藏  举报