PyTorch模型保存与加载

保存与加载整个模型

保存整个模型,包括网络结构和权重参数,保存后的文件用torch.load()加载后的类型是定义的网络结构类,如class CNN:

torch.save(model, "model.pkl")

加载整个模型:

model = torch.load("model.pkl")

只保存与加载模型参数

只保存模型参数,保存后的文件使用torch.load()加载后类型是collections.OrderedDict:

torch.save(model.state_dict(), "model_parameter.pkl")

由于模型文件中只保存了参数、没有网络结构,所以加载模型时需要先指定网络结构,复制训练时定义的网络结构即可:

model = Model() # 使用训练时定义的模型网络结构
model.load_state_dict((torch.load("model_parameter.pkl")))

保存与加载自定义模型

可以自定义模型中保存哪些信息,例如网络结构、模型权重参数、优化器参数等:

custom_model = {'net': CNN(),
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
                }

torch.save(custom_model, 'custom_model.pkl')

保存后的文件使用torch.load()后可以通过字典取值方式获取net、model_state_dict等键值内容。

custom_model = torch.load('custom_model.pkl')
model = custom_model['net']
model.load_state_dict(custom_model['model_state_dict'])

# predict
model(data)

参考

https://zhuanlan.zhihu.com/p/73893187

posted @ 2022-12-27 17:12  Init0ne  阅读(180)  评论(0编辑  收藏  举报