PyTorch学习系列(十四)——保存训练好的模型

PyTorch学习系列(十四)——保存训练好的模型

PyTorch提供了两种保存训练好的模型的方法。 
第一种是只保存模型参数,这也是推荐的方法:

#保存
torch.save(the_model.state_dict(), PATH)
#读取
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二种方法保存整个模型:

#保存
torch.save(the_model, PATH)
#读取
the_model = torch.load(PATH)

参考 
[1] http://pytorch.org/docs/notes/serialization.html

posted @ 2017-12-31 09:26  菜鸡一枚  阅读(1694)  评论(0)    收藏  举报