pytorch 保存模型问题
pytorch模型保存两种方式:
- 保存参数字典
- 模型保存
torch.save(net.state_dict(),PATH)
- 模型载入
net=model()#初始实例化
net.load_state_dict(torch.load(PATH))
- 保存模型
- 模型保存
torch.save(net,PATH)
- 模型载入
net=torch.load(PATH)#不需要初始化
pytorch使用中发现net.state_dict只保存所有module层的偏置与权重值,不保存零散变量值。
验证代码
#encoding:utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F
#define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1=nn.Conv2d(3,6,5)
self.pool=nn.MaxPool2d(2,2)
self.conv2=nn.Conv2d(6,16,5)
self.fc1=nn.Linear(16*5*5,120)
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10)
self.qx=0
def forward(self,x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x
def main():
# Initialize model
model = TheModelClass()
#Initialize optimizer
optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
#print model's state_dict
print('Model.state_dict:')
for param_tensor in model.state_dict():
#打印 key value字典
print(param_tensor,'\t',model.state_dict()[param_tensor].size())
#print optimizer's state_dict
print('Optimizer,s state_dict:')
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])
if __name__=='__main__':
main()
model的state_dict输出中并没有qx变量值。
参考链接https://blog.csdn.net/qq_41845478/article/details/116023691
本文作者:心比天高xzh
本文链接:https://www.cnblogs.com/xzh-personal-issue/p/17087793.html
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
分类:
Python
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步