PyTorch模型保存与加载
保存与加载整个模型
保存整个模型,包括网络结构和权重参数,保存后的文件用torch.load()加载后的类型是定义的网络结构类,如class CNN:
torch.save(model, "model.pkl")
加载整个模型:
model = torch.load("model.pkl")
只保存与加载模型参数
只保存模型参数,保存后的文件使用torch.load()加载后类型是collections.OrderedDict:
torch.save(model.state_dict(), "model_parameter.pkl")
由于模型文件中只保存了参数、没有网络结构,所以加载模型时需要先指定网络结构,复制训练时定义的网络结构即可:
model = Model() # 使用训练时定义的模型网络结构
model.load_state_dict((torch.load("model_parameter.pkl")))
保存与加载自定义模型
可以自定义模型中保存哪些信息,例如网络结构、模型权重参数、优化器参数等:
custom_model = {'net': CNN(),
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}
torch.save(custom_model, 'custom_model.pkl')
保存后的文件使用torch.load()后可以通过字典取值方式获取net、model_state_dict等键值内容。
custom_model = torch.load('custom_model.pkl')
model = custom_model['net']
model.load_state_dict(custom_model['model_state_dict'])
# predict
model(data)