pytorch加载模型

1.加载全部模型:

net.load_state_dict(torch.load(net_para_pth))

2.加载部分模型

net_para_pth = './result/5826.pth'
pretrained_dict = torch.load(net_para_pth)
model_dict = net.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

3.改变某一层参数后加载

将该层名称改一下,然后用2中方法加载,比如,要将conv5的out_channels由256改为512。

将conv_5改为conv_5_chg,就可以顺利加载了,不改会报错哟

 

算是权宜之计了,还有什么好方法,希望多多指教

 

4.单GPU/CPU加载多GPU训练的网络

正常情况下,多GPU保存模型应该加上.module,然后加载时即使是单GPU也不会出问题,但是如果保存时忘记加,加载时就需要多一道手续

参考:https://blog.csdn.net/CV_YOU/article/details/86670188

def load_GPU(model, model_path, mapLoc='cpu'):
    state_dict = torch.load(model_path, map_location=mapLoc)
    # create new OrderedDict that does not contain `module.`
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    return model

 也就是重新定义一个OrderedDict,然后将state_dict键值中的.module去掉

 


posted on 2019-11-05 10:24  江南烟雨尘  阅读(4502)  评论(4编辑  收藏  举报

导航