RuntimeError: scatter(): Expected dtype int64 for index

RuntimeError: scatter(): Expected dtype int64 for index

跑代码时出现报错

RuntimeError: scatter(): Expected dtype int64 for index
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)

问题出在scatter的函数参数上:第2个(算上self是第3个)参数应该是tensor,且其包含的元素类型应该是Long型,而原来程序内是Int型。

应该改为:

true_dist.scatter_(1, target.data.unsqueeze(1).long(), self.confidence)

转化为long型.

posted @ 2024-01-15 10:19  咖啡陪你  阅读(301)  评论(0编辑  收藏  举报