torch.save / torch.load: Saving and Loading Models

https://pytorch123.com/ThirdSection/SaveModel/ (this link covers the topic in great detail!)

1. # Save the entire network
torch.save(net, PATH)

# Save only the network's parameters (faster, uses less disk space)
torch.save(net.state_dict(), PATH)

#--------------------------------------------------

# The loading methods corresponding to the two saving methods above are:

# load the entire network object
model = torch.load(PATH)

# load only the parameters into an already-constructed model of the same architecture
model.load_state_dict(torch.load(PATH))
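
As a minimal round-trip sketch of both methods (the tiny Net class and the file names below are illustrative, not from the original post):

import torch
import torch.nn as nn

# a tiny illustrative model; any nn.Module works the same way
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

net = Net()

# Method 1: save / load the whole network object (pickles the entire module)
torch.save(net, 'net_full.pth')
model_full = torch.load('net_full.pth')   # recent PyTorch versions may require weights_only=False here

# Method 2: save / load only the parameters (recommended: faster, smaller file)
torch.save(net.state_dict(), 'net_params.pth')
model = Net()                              # the same architecture must be built first
model.load_state_dict(torch.load('net_params.pth'))
model.eval()                               # switch to evaluation mode before inference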

2. However, in experiments you often need to save more information, such as the optimizer's parameters, in which case you can use the method below. For reference, a model's state_dict and an optimizer's state_dict look like this:

 

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
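
The listing above is the kind of output you get by iterating over the two state_dicts, roughly like this (a sketch assuming model and optimizer are an existing network and its SGD optimizer):

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

A checkpoint that bundles the epoch number, the model weights, the best loss so far, the optimizer state, and the loss function's alpha / gamma hyperparameters can then be saved to a single file: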
# save everything needed to resume training in one checkpoint file
torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
            'optimizer': optimizer.state_dict(), 'alpha': loss.alpha, 'gamma': loss.gamma},
           checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')
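
To resume training later, such a checkpoint can be loaded back like this (a sketch; the file name is illustrative, and model and optimizer must already be constructed with the same classes used when saving):

checkpoint = torch.load('m-checkpoint.pth.tar')   # illustrative path, not from the original post
start_epoch = checkpoint['epoch']
best_loss = checkpoint['best_loss']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])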

