PyTorch Basics

Modifying a Model

import torchvision.models as models
net = models.resnet50()
# Inspect the model definition
print(net)
# output
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    ......
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

As shown above, the last layer (fc) of ResNet50 outputs 1000 nodes by default.
To apply this model to a 10-class classification task, the number of output nodes in the final layer must be changed to 10.
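
The simplest change is to swap the fc layer for a single Linear layer with 10 output nodes (a minimal sketch; 2048 is the in_features of the original fc layer shown above):

import torch.nn as nn

# Replace the default 1000-way classifier head with a 10-way Linear layer
net.fc = nn.Linear(2048, 10)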

import torch.nn as nn
from collections import OrderedDict
# A single fully connected layer may be too little, so add one more layer.
classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(2048, 128)),
    ('relu', nn.ReLU()),
    ('dropout', nn.Dropout(0.5)),
    ('fc2', nn.Linear(128, 10)),
    ('output', nn.Softmax(dim=1))
]))
# Replace net's fc layer with the custom classifier
net.fc = classifier

Printing net again shows that the final fc layer has been replaced with the classifier defined above:

  (fc): Sequential(
    (fc1): Linear(in_features=2048, out_features=128, bias=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (fc2): Linear(in_features=128, out_features=10, bias=True)
    (output): Softmax(dim=1)
  )
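
As a quick sanity check (a minimal sketch; 224x224 is the standard ImageNet input resolution ResNet50 expects), a dummy forward pass should now return 10 values per sample:

import torch

net.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # a batch of one RGB image
    out = net(dummy)
print(out.shape)
# torch.Size([1, 10])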

Saving and Loading PyTorch Models

Saving a Model

  • Model file formats: pt, pth, pkl
import os
import torch

# ID of the GPU to use
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
net.cuda()
# Save the whole model; the file extension can be pt, pth, or pkl
torch.save(net, './model.pt')
# Save only the weights (state_dict)
torch.save(net.state_dict(), './weight.pt')

Loading a Model

# Load the whole model
loaded_model = torch.load('./model.pt')
# Load the weights onto the model; alternatively, read them into a variable first
# and then pass that variable to load_state_dict in two steps
loaded_model.load_state_dict(torch.load('./weight.pt'))
loaded_model.cuda()
loaded_dict = torch.load('./weight.pt')
print(loaded_dict.keys())
# odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 
......
'fc.fc1.weight', 'fc.fc1.bias', 'fc.fc2.weight', 'fc.fc2.bias'])
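
If only the weights were saved (no full model file), the usual pattern is to rebuild the same architecture first and then call load_state_dict. A minimal sketch, assuming the modified ResNet50 defined above; map_location='cpu' is an assumption for the case where no GPU is available:

import torch
import torch.nn as nn
import torchvision.models as models
from collections import OrderedDict

# Rebuild the same architecture the weights were saved from
net = models.resnet50()
net.fc = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(2048, 128)),
    ('relu', nn.ReLU()),
    ('dropout', nn.Dropout(0.5)),
    ('fc2', nn.Linear(128, 10)),
    ('output', nn.Softmax(dim=1))
]))
# map_location='cpu' allows loading on a machine without a GPU
state_dict = torch.load('./weight.pt', map_location='cpu')
net.load_state_dict(state_dict)
net.eval()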