pytorch学习001- -如何保存模型
保存和加载模型
只保存模型的参数
保存
torch.save(model.state_dict(),'xxx.pth')
加载
model = net() #首先要先定义网络模型
state_dict = torch.load('xxx.pth') # 读取pth文件中的参数
model.load_state_dict(state_dict['model']) #将参数导入模型
这种方法操作比较麻烦,但是比较节省内存。
official example
class MyModule(torch.nn.Module):
m = MyModule()
m.state_dict()
torch.save(m.state_dict(), 'mymodule.pt')
m_state_dict = torch.load('mymodule.pt')
new_m = MyModule()
new_m.load_state_dict(m_state_dict)
附加保存一些其他信息
torch.save({
'epoch': epoch + opt.start,
'model_state_dict': model.state_dict()
'optimizer_state_dict': optimizer.state_dict(),
'loss': epoch_losses.avg},
}, os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch + opt.start)))
保存整个模型
保存
torch.save(net, 'xxx.pt')
加载
test = torch.load('xxx.pt') #注意其中pt文件的路径
这种方式是将整个的网络模型进行保存,使用不便,但是加载方便,适合于简单测试。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律