pytorch保存模型及加载模型

Class TestModle(nn.Module):
	def __init__(self):
		self.conv = nn.Conv(3, 6, 5)
		self.pool = nn.MaxPool2d(2, 2)
		...
	def forward(self, x):
		...
	....

假如有这样一个模型

一、使用状态字典保存模型参数(官方推荐用法)

保存模型

torch.save(model.state_dict(), PATH)
模型一般选择pt后缀结尾

载入模型

由于我们仅仅保存模型的权重参数,没有模型的结构是无法载入参数的

model = TestModel()
model.load_state_dict(torch.load(PATH))
model.eval()
  • 注意:这里我们载入参数后使用eval()后再对输入的内容进行推理,因为eval会把模型内的标准化和dropout等功能给禁用了。才能输出正确的推理结果

二、保存整个模型,载入模型(无需加载模型的结构)

保存模型
torhc.save(model, PATH)
加载模型
model = torch.load(PATH)

以上是本人常用的两种方法,实测有效。但是由于pytorch并不熟练,如果想了解更多,可移步到这
https://zhuanlan.zhihu.com/p/82038049

posted @   waterdoor  阅读(43)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示