pytorch 保存模型问题

pytorch模型保存两种方式:

  1. 保存参数字典
  • 模型保存
    torch.save(net.state_dict(),PATH)
  • 模型载入
net=model()#初始实例化
net.load_state_dict(torch.load(PATH))
  1. 保存模型
  • 模型保存
    torch.save(net,PATH)
  • 模型载入
    net=torch.load(PATH)#不需要初始化

pytorch使用中发现net.state_dict只保存所有module层的偏置与权重值,不保存零散变量值。
验证代码

#encoding:utf-8
 
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.pool=nn.MaxPool2d(2,2)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
        self.qx=0
 
    def forward(self,x):
        x=self.pool(F.relu(self.conv1(x)))
        x=self.pool(F.relu(self.conv2(x)))
        x=x.view(-1,16*5*5)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x
 
def main():
    # Initialize model
    model = TheModelClass()
 
    #Initialize optimizer
    optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
    #print model's state_dict
    print('Model.state_dict:')
    for param_tensor in model.state_dict():
        #打印 key value字典
        print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
    #print optimizer's state_dict
    print('Optimizer,s state_dict:')
    for var_name in optimizer.state_dict():
        print(var_name,'\t',optimizer.state_dict()[var_name])
if __name__=='__main__':
    main()

model的state_dict输出中并没有qx变量值。
参考链接https://blog.csdn.net/qq_41845478/article/details/116023691

本文作者:心比天高xzh

本文链接:https://www.cnblogs.com/xzh-personal-issue/p/17087793.html

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   心比天高xzh  阅读(32)  评论(0编辑  收藏  举报
评论
收藏
关注
推荐
深色
回顶
收起
点击右上角即可分享
微信分享提示