pytorch保存模型及加载模型

Class TestModle(nn.Module):
	def __init__(self):
		self.conv = nn.Conv(3, 6, 5)
		self.pool = nn.MaxPool2d(2, 2)
		...
	def forward(self, x):
		...
	....

假如有这样一个模型

一、使用状态字典保存模型参数(官方推荐用法)

保存模型

torch.save(model.state_dict(), PATH)
模型一般选择pt后缀结尾

载入模型

由于我们仅仅保存模型的权重参数,没有模型的结构是无法载入参数的

model = TestModel()
model.load_state_dict(torch.load(PATH))
model.eval()
  • 注意:这里我们载入参数后使用eval()后再对输入的内容进行推理,因为eval会把模型内的标准化和dropout等功能给禁用了。才能输出正确的推理结果

二、保存整个模型,载入模型(无需加载模型的结构)

保存模型
torhc.save(model, PATH)
加载模型
model = torch.load(PATH)

以上是本人常用的两种方法,实测有效。但是由于pytorch并不熟练,如果想了解更多,可移步到这
https://zhuanlan.zhihu.com/p/82038049

posted @ 2023-07-11 17:37  waterdoor  阅读(29)  评论(0编辑  收藏  举报