pytorch加载时出现“RuntimeError: Error(s) in loading state_dict for Sequential:”时的解决方案

RuntimeError: Error(s) in loading state_dict for Sequential:

该错误通常与使用了nn.DataParallel进行训练有关

是指模型中的参数key中字符串与torch.load获取的key中字符串不匹配

因此,我们只需要修改torch.load获取的dict,令其匹配。

 

例如:

我torch.save时,参数key中字符串前自动添加了'module.'

因此,在torch.load后,需要去掉'module.'

方法如下:

model = Model()
model_para_dict_temp = torch.load('xxx.pth')
model_para_dict = {}
for key_i in model_para_dict_temp.keys():
    model_para_dict[key_i[7:]] = model_para_dict_temp[key_i]  # 删除掉前7个字符'module.'
del model_para_dict_temp
model.load_state_dict(model_para_dict)

 

posted @ 2022-07-24 16:09  Real_Tourist  阅读(3051)  评论(0)    收藏  举报