pytorch保存与加载模型

https://zhuanlan.zhihu.com/p/38056115

链接中内容包括:

  • 保存模型与加载模型
  • 冻结一部分参数,训练另一部分参数
  • 采用不同的学习率进行训练

1.保存与加载模型

//保存整个网络
torch.save(net, PATH) 
net=torch.load(PATH)

//保存网络参数,占空间少
torch.save(net.state_dict(),PATH)
model_dict=model.load_state_dict(torch.load(PATH))

另外还可以保存优化器的信息:

torch.save({'epoch': epochID + 1,'optimizer': optimizer.state_dict()});

以字典格式保存。

 

posted @ 2021-04-03 01:42  lypbendlf  阅读(62)  评论(0编辑  收藏  举报