pytorch加载时出现“RuntimeError: Error(s) in loading state_dict for Sequential:”时的解决方案

RuntimeError: Error(s) in loading state_dict for Sequential:

该错误通常与使用了nn.DataParallel进行训练有关

是指模型中的参数key中字符串与torch.load获取的key中字符串不匹配

因此,我们只需要修改torch.load获取的dict,令其匹配。

 

例如:

我torch.save时,参数key中字符串前自动添加了'module.'

因此,在torch.load后,需要去掉'module.'

方法如下:

model = Model()
model_para_dict_temp = torch.load('xxx.pth')
model_para_dict = {}
for key_i in model_para_dict_temp.keys():
    model_para_dict[key_i[7:]] = model_para_dict_temp[key_i]  # 删除掉前7个字符'module.'
del model_para_dict_temp
model.load_state_dict(model_para_dict)

 

posted @   Real_Tourist  阅读(2897)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 上周热点回顾(3.3-3.9)
· AI 智能体引爆开源社区「GitHub 热点速览」
· 写一个简单的SQL生成工具
点击右上角即可分享
微信分享提示