Python报错 | RuntimeError: expected scalar type Long but found Float
报错信息
在执行nlp自定义模型的训练函数的时候,报如下错误:
RuntimeError: expected scalar type Float but found Long
错误原因
错误信息指出了问题所在:模型期望的数据类型是 float,但实际上传递给模型的数据类型是 long。
这个错误通常是由于张量数据类型不匹配引起的。在 PyTorch 中,张量数据类型非常重要,因为它们指定了张量中存储的数值的精度和类型。如果您在模型的前向传递中使用了错误的数据类型,就会出现这个错误。
例如:
import torch
import torch.nn as nn
v = torch.tensor([0])
m = nn.Linear(1, 10)
m(v)
运行结果:
因为input也就是我们的v是torch.long
类型的而weight是torch.float
类型。所以在做矩阵乘法的时候这两种类型的不一致导致了报错。
解决方案
把v的dtype显示地设置成torch.float代码就成功运行了
import torch
import torch.nn as nn
# dtype=torch.float必不可少
v = torch.tensor([0], dtype=torch.float)
m = nn.Linear(1, 10)
m(v)
运行结果:
tensor([-0.6189, -0.9843, -0.7568, 0.9157, 0.5192, -0.6109, -0.5627, -0.7755,
-0.9522, 0.7771], grad_fn=<AddBackward0>)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)