pytorch学习001- -如何保存模型

保存和加载模型

只保存模型的参数

保存

torch.save(model.state_dict(),'xxx.pth')

加载

model = net()	#首先要先定义网络模型
state_dict = torch.load('xxx.pth')	# 读取pth文件中的参数
model.load_state_dict(state_dict['model'])	#将参数导入模型

这种方法操作比较麻烦,但是比较节省内存。

official example

class MyModule(torch.nn.Module):


m = MyModule()
m.state_dict()

torch.save(m.state_dict(), 'mymodule.pt')
m_state_dict = torch.load('mymodule.pt')
new_m = MyModule()
new_m.load_state_dict(m_state_dict)

附加保存一些其他信息

torch.save({
             'epoch': epoch + opt.start,
            'model_state_dict': model.state_dict()
             'optimizer_state_dict': optimizer.state_dict(),
             'loss': epoch_losses.avg},
        }, os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch + opt.start)))

保存整个模型

保存

torch.save(net, 'xxx.pt')

加载

test = torch.load('xxx.pt')	#注意其中pt文件的路径

这种方式是将整个的网络模型进行保存,使用不便,但是加载方便,适合于简单测试。

官方文档

posted @ 2022-01-13 15:23  Keep_Silent  阅读(19)  评论(0编辑  收藏  举报