RuntimeError: scatter(): Expected dtype int64 for index
RuntimeError: scatter(): Expected dtype int64 for index
跑代码时出现报错
RuntimeError: scatter(): Expected dtype int64 for index
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
问题出在scatter的函数参数上:第2个(算上self是第3个)参数应该是tensor,且其包含的元素类型应该是Long型,而原来程序内是Int型。
应该改为:
true_dist.scatter_(1, target.data.unsqueeze(1).long(), self.confidence)
转化为long型.