CrossEntropyLoss: RuntimeError: expected scalar type Float but found Long neural network
错误分析
这个错误通常指的是期望接受的参数类型是Float
, 但是程序员传入的是Int
。 通常会需要我们去检查传入的 input
和 target
的数据类型有没有匹配。在传入的数据中,通常 input
希望是 Float
类型,target
是 Int
类型。
但是通常也许会发现传入的参数是符合要求的,但是仍然会报这样的错误,那么这个时候就需要注意查看 CrossEntropyLoss
中传入的参数 weight
的类型,传入的参数weight
也必须是一个浮点数,即,如果你设置成 [1, 2]
也必须写成 [1.0, 2.0]
的形式。
样例
CorssEntropyLoss
的参数使用的样例代码如下:
class_weights = torch.tensor([1.0, 2.0], device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
通常, 这个参数是我们在做分类任务时,当我们期待对少数类样本投以更多关注时就可以开始设置,在异常检测的领域比较常见。
本文作者:tjdtec
本文链接:https://www.cnblogs.com/tjdtec/p/17860771.html
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步