pytorch加载时出现“RuntimeError: Error(s) in loading state_dict for Sequential:”时的解决方案
RuntimeError: Error(s) in loading state_dict for Sequential:
该错误通常与使用了nn.DataParallel进行训练有关
是指模型中的参数key中字符串与torch.load获取的key中字符串不匹配
因此,我们只需要修改torch.load获取的dict,令其匹配。
例如:
我torch.save时,参数key中字符串前自动添加了'module.'
因此,在torch.load后,需要去掉'module.'
方法如下:
model = Model() model_para_dict_temp = torch.load('xxx.pth') model_para_dict = {} for key_i in model_para_dict_temp.keys(): model_para_dict[key_i[7:]] = model_para_dict_temp[key_i] # 删除掉前7个字符'module.' del model_para_dict_temp model.load_state_dict(model_para_dict)