pytorch加载时出现“RuntimeError: Error(s) in loading state_dict for Sequential:”时的解决方案
RuntimeError: Error(s) in loading state_dict for Sequential:
该错误通常与使用了nn.DataParallel进行训练有关
是指模型中的参数key中字符串与torch.load获取的key中字符串不匹配
因此,我们只需要修改torch.load获取的dict,令其匹配。
例如:
我torch.save时,参数key中字符串前自动添加了'module.'
因此,在torch.load后,需要去掉'module.'
方法如下:
model = Model() model_para_dict_temp = torch.load('xxx.pth') model_para_dict = {} for key_i in model_para_dict_temp.keys(): model_para_dict[key_i[7:]] = model_para_dict_temp[key_i] # 删除掉前7个字符'module.' del model_para_dict_temp model.load_state_dict(model_para_dict)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 上周热点回顾(3.3-3.9)
· AI 智能体引爆开源社区「GitHub 热点速览」
· 写一个简单的SQL生成工具