PyTorch Study Notes (9): Saving and Loading Neural Network Models

I. Saving and loading network models

1. Saving method 1

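import torch
import torchvision

# weights=None builds the architecture with randomly initialized parameters
vgg16 = torchvision.models.vgg16(weights=None)
# Saving method 1: model structure + model parameters
torch.save(vgg16, "vgg16_method1.pth")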

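Running the code above writes the model file vgg16_method1.pth to the current working directory (normally the same directory as the script).

Loading code:

import torch

# Method 1 -> load a model that was saved with method 1
# The whole object is unpickled, so only load files you trust. weights_only=False
# is needed on PyTorch versions where torch.load defaults to weights_only=True.
model = torch.load("vgg16_method1.pth", weights_only=False)
print(model)

Output: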

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

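Method 1 saves both the network structure and the model's parameters.

Note: a model without pretrained weights still has parameters; they are simply in their freshly initialized state.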
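To try method 1 without writing a VGG16-sized file, the same round trip can be sketched on a tiny model. The model, the file name tiny_method1.pth, and the temporary directory are stand-ins chosen for illustration, not part of the original notes. One thing worth knowing: method 1 pickles the module object, and the pickle stores a reference to the model's class rather than its code, so for a self-defined model the class definition has to be importable when the file is loaded.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in for vgg16: small enough to save and load instantly.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
path = os.path.join(tempfile.mkdtemp(), "tiny_method1.pth")

# Method 1: pickle the whole module object (structure + parameters).
torch.save(model, path)

# weights_only=False allows full unpickling; it is needed on PyTorch versions
# where torch.load defaults to weights_only=True. Only use it on trusted files.
restored = torch.load(path, weights_only=False)

# The restored object is already a usable module with the saved parameters.
same_weights = torch.equal(restored[0].weight, model[0].weight)
print(type(restored).__name__, same_weights)
```

Since the restored object is a complete nn.Module, it can be called directly, for example restored(torch.randn(1, 4)).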

2. Saving method 2

This method saves only the model parameters (officially recommended):

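import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=None)
# Saving method 2: model parameters only, as a state_dict
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

Loading code:

import torch

# Method 2 -> loading the file returns the parameter dictionary, not a model
model = torch.load("vgg16_method2.pth")
print(model)

Output: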

OrderedDict([('features.0.weight', tensor([[[[-9.0302e-02,  6.1546e-02, -1.7735e-02],
          [ 1.1606e-01, -1.7557e-02, -5.4266e-02],
          [-3.0833e-02,  2.3019e-02,  2.2968e-02]],

         [[-3.5706e-02, -3.8619e-02,  2.7329e-02],
          [ 1.0525e-02,  7.0172e-02, -4.3097e-02],
          [-7.9473e-03, -2.8735e-02, -4.3932e-02]],

         [[ 6.6814e-02, -6.1849e-02, -9.8496e-02],
          [-5.7835e-02,  3.3374e-02,  3.2937e-02],
          [-4.3170e-02, -3.1252e-02,  1.1314e-01]]],


        [[[ 6.6068e-02, -6.5313e-02, -8.0335e-02],
          [-1.5587e-02,  1.1784e-02, -8.8468e-03],
          [ 7.2871e-02,  7.5150e-02, -7.2230e-02]],

         [[-3.7871e-02,  1.8217e-02,  1.1531e-01],
          [ 5.7616e-02, -1.2748e-01,  2.3816e-02],
          [-4.1781e-02, -2.1523e-02,  6.2196e-02]],

         [[-2.0698e-03,  8.8641e-02,  3.1991e-02],
          [-8.9041e-02, -1.1210e-01, -7.8223e-04],
          [-2.9659e-02, -1.5199e-01,  3.9977e-06]]],

      ......
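Printing a whole state_dict dumps every tensor value, which is hard to read. A gentler way to look at it is to list only the parameter names and shapes. The sketch below does this on a small made-up model (an assumption for illustration, standing in for vgg16); the same loop should work on the dictionary loaded from vgg16_method2.pth.

```python
import torch
from torch import nn

# Hypothetical small model standing in for vgg16.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 10),
)

state_dict = model.state_dict()

# Collect (name, shape) pairs instead of printing every tensor value.
summary = [(name, tuple(tensor.shape)) for name, tensor in state_dict.items()]
for name, shape in summary:
    print(f"{name:<10} {shape}")
```

Layers without parameters (ReLU, Flatten) do not appear in the dictionary, which is why the keys skip some indices.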

The files written by the two methods are not the same size.
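The size difference can be checked directly with os.path.getsize. The following is a minimal sketch on a small stand-in model; the model, the file names, and the temporary directory are assumptions for illustration, and the exact byte counts depend on the PyTorch version, so none are quoted here.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in model; with vgg16 the files are much larger.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))

tmp_dir = tempfile.mkdtemp()
path_full = os.path.join(tmp_dir, "model_method1.pth")
path_state = os.path.join(tmp_dir, "model_method2.pth")

torch.save(model, path_full)                # method 1: structure + parameters
torch.save(model.state_dict(), path_state)  # method 2: parameters only

size_full = os.path.getsize(path_full)
size_state = os.path.getsize(path_state)
print(f"method 1: {size_full} bytes")
print(f"method 2: {size_state} bytes")
print(f"difference: {size_full - size_state} bytes")
```

Both files hold the same tensors, so any difference comes from the extra information method 1 stores about the module object itself.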

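The object loaded from a method 2 file is a dictionary (an OrderedDict), so printing it shows the parameter values as well. To see the actual network structure, first build the model and then load the parameters into it with the following code:

import torch
import torchvision

# Method 2 -> rebuild the structure, then load the saved parameters into it
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)  # prints the full model structure, identical to the output of method 1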
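The full method 2 workflow (save the parameters, rebuild the architecture, load the parameters back) can be sketched end to end on a tiny model. The build_model helper and the file name tiny_method2.pth are hypothetical; in these notes that role is played by torchvision.models.vgg16. The check at the end compares the outputs of the original and restored models on the same input.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    """Hypothetical factory that recreates the architecture from code."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


original = build_model()
path = os.path.join(tempfile.mkdtemp(), "tiny_method2.pth")
torch.save(original.state_dict(), path)

# Rebuild the structure, then copy the saved parameters into it.
restored = build_model()
state_dict = torch.load(path, map_location="cpu")  # load tensors onto the CPU
result = restored.load_state_dict(state_dict)
restored.eval()  # put layers such as Dropout into inference mode

x = torch.randn(3, 4)
with torch.no_grad():
    outputs_match = torch.allclose(original(x), restored(x))

print(result)  # reports any missing or unexpected keys
print(outputs_match)
```

load_state_dict returns a small report of missing and unexpected keys, which is a convenient way to notice when the rebuilt architecture does not match the saved file.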

 

posted @ helloWorldhelloWorld