pytorch model

网络定义

import torch as torch
import torch.nn as nn
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/max-pool stages followed by three Linear layers.

    Takes single-channel images. For a 32x32 input the feature map entering the
    fully connected stage is 16x5x5 = 400 values, which is what ``fc1`` expects.
    Produces 10 raw scores per sample.

    NOTE(review): no activation functions are used anywhere, so fc1 -> fc2 -> fc3
    compose into a single affine map. Insert ``nn.ReLU`` layers if a classic
    LeNet is intended (that also changes the printed module structure).
    """

    def __init__(self):
        super().__init__()

        def stage(*named_layers):
            # Build an nn.Sequential whose children carry explicit names
            # ('conv1', 'pool1', ...) so they appear in print(model) and
            # named_modules() under those names.
            seq = nn.Sequential()
            for layer_name, layer in named_layers:
                seq.add_module(layer_name, layer)
            return seq

        # 1 -> 6 channels with a 5x5 kernel, then 2x2 max-pool with stride 2
        self.layer1 = stage(('conv1', nn.Conv2d(1, 6, 5)),
                            ('pool1', nn.MaxPool2d(2, 2)))
        # 6 -> 16 channels with a 5x5 kernel, then 2x2 max-pool with stride 2
        self.layer2 = stage(('conv2', nn.Conv2d(6, 16, 5)),
                            ('pool2', nn.MaxPool2d(2, 2)))
        # Fully connected head: 400 -> 120 -> 84 -> 10
        self.layer3 = stage(('fc1', nn.Linear(16 * 5 * 5, 120)),
                            ('fc2', nn.Linear(120, 84)),
                            ('fc3', nn.Linear(84, 10)))

    def forward(self, x):
        """Run both conv stages, flatten per sample, then apply the Linear stack."""
        features = self.layer2(self.layer1(x))
        # Collapse every dimension except the batch dimension before the Linear layers
        flat = features.view(features.size(0), -1)
        return self.layer3(flat)

# Smoke test: push one random (batch=1, channels=1, 32x32) input through the network
sample = torch.randn(1, 1, 32, 32)
model = LeNet()
logits = model(sample)
for obj in (model, logits):
    print(obj)

输出如下:

LeNet(
  (layer1): Sequential(
    (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer3): Sequential(
    (fc1): Linear(in_features=400, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
  )
)
tensor([[ 0.0211,  0.1407, -0.1831, -0.1182,  0.0221,  0.1467, -0.0523, -0.0663,
         -0.0351, -0.0434]], grad_fn=<AddmmBackward>)

def set_bn_momentum(model, momentum=0.1):
    """Set ``momentum`` on every ``nn.BatchNorm2d`` inside ``model``.

    Args:
        model: any ``nn.Module``. ``model.modules()`` is walked recursively, so
            nested BatchNorm2d layers (and ``model`` itself) are included.
        momentum: value assigned to each matching layer's ``momentum`` attribute.

    Other layer types (including BatchNorm1d/BatchNorm3d, which are not
    subclasses of BatchNorm2d) are left untouched.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = momentum

def fix_bn(model):
    """Switch every ``nn.BatchNorm2d`` inside ``model`` to eval mode.

    Other modules keep their current train/eval mode. This does not touch
    ``requires_grad``, so the BN affine weight/bias can still be optimized.
    A later ``model.train()`` call puts these layers back into training mode,
    so call this after ``model.train()``.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

model.named_children() 返回 (名字, 子模块) 对,只包含直接子模块,不递归

print("*" * 50)
# named_children() yields (attribute name, module) for direct children only
for child_name, child in model.named_children():
    print(child_name)
    print(child)

打印如下:

layer1
Sequential(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
layer2
Sequential(
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
layer3
Sequential(
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

也可以在 forward 中用 named_children() 依次把输入送入每个子模块(注意:这样不会执行 LeNet 中 layer2 与 layer3 之间的 view 展平操作,直接用在 LeNet 上会因形状不匹配而报错):

  def forward(self, x):
      """Feed ``x`` through each direct child module in registration order.

      Returns the output of the last child (``x`` unchanged if there are no
      children). Nothing runs between children, so any reshaping step (such as
      LeNet's ``view`` before its Linear layers) must itself be a child module,
      e.g. ``nn.Flatten()``.
      """
      for _, module in self.named_children():
          x = module(x)
      return x

model.modules() 会递归返回所有模块(包括模型本身),可用于参数初始化。先打印看看它遍历了哪些模块:

print("#" * 200)
# modules() walks the whole module tree recursively, starting with the model itself
for cnt, module in enumerate(model.modules(), start=1):
    print('-------------------------------------------------------cnt=', cnt)
    print(module)

输出如下:

########################################################################################################################################################################################################
-------------------------------------------------------cnt= 1
LeNet(
  (layer1): Sequential(
    (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer3): Sequential(
    (fc1): Linear(in_features=400, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
  )
)
-------------------------------------------------------cnt= 2
Sequential(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
-------------------------------------------------------cnt= 3
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
-------------------------------------------------------cnt= 4
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
-------------------------------------------------------cnt= 5
Sequential(
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
-------------------------------------------------------cnt= 6
Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
-------------------------------------------------------cnt= 7
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
-------------------------------------------------------cnt= 8
Sequential(
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
-------------------------------------------------------cnt= 9
Linear(in_features=400, out_features=120, bias=True)
-------------------------------------------------------cnt= 10
Linear(in_features=120, out_features=84, bias=True)
-------------------------------------------------------cnt= 11
Linear(in_features=84, out_features=10, bias=True)

用 model.modules() 按模块类型初始化参数(下面先打印每个 Conv2d 的默认权重,再重新初始化):

# Visit every module (recursively, starting with the model itself), print it,
# and re-initialize convolution / normalization weights according to layer type.
# NOTE: Conv2d biases and nn.Linear layers keep PyTorch's default initialization here.
for cnt, module in enumerate(model.modules(), start=1):
    print('-------------------------------------------------------cnt=', cnt)
    print(module)
    if isinstance(module, nn.Conv2d):
        # Show the default-initialized values before they are overwritten below.
        print('------------------isinstance(name, nn.Conv2d)------------------')
        print(module.weight)
        print(module.bias)
        print('--end----------------isinstance(name, nn.Conv2d)------------end------')
        nn.init.kaiming_normal_(module.weight)
    elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
        # Norm layers built with affine=False have weight/bias set to None;
        # nn.init.constant_ would fail on them, so skip the missing tensors.
        if module.weight is not None:
            nn.init.constant_(module.weight, 1)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

其中参数部分输出如下:

------------------isinstance(name, nn.Conv2d)------------------
Parameter containing:
tensor([[[[-0.1561, -0.0194, -0.0260, -0.0042,  0.1716],
          [ 0.1181, -0.1380, -0.0448,  0.0674, -0.1972],
          [-0.0197,  0.0359,  0.1186,  0.0876, -0.0395],
          [-0.0619,  0.0095, -0.0702,  0.0122,  0.1573],
          [ 0.1170,  0.1758, -0.1655,  0.1489, -0.0956]]],
       ...
  [[[-0.1337, -0.0562, -0.0624,  0.0885, -0.0640],
          [-0.0302, -0.1192, -0.0637,  0.0083,  0.0181],
          [ 0.1388, -0.1690,  0.1132,  0.1686, -0.1189],
          [-0.0246, -0.1649, -0.1817, -0.0330, -0.0430],
          [ 0.0672, -0.0671,  0.0469,  0.1284,  0.1420]]]], requires_grad=True)
Parameter containing:
tensor([ 0.0548,  0.0547,  0.1328, -0.0452,  0.1668, -0.1915],
       requires_grad=True)
--end----------------isinstance(name, nn.Conv2d)------------end------

model.modules()用于设置bn参数和冻结bn

def set_bn_momentum(model, momentum=0.1):
    """Assign ``momentum`` to every nn.BatchNorm2d found anywhere in ``model``.

    The search uses ``model.modules()``, so it is recursive and includes
    ``model`` itself. Non-BatchNorm2d layers are not modified.
    """
    bn_layers = (mod for mod in model.modules() if isinstance(mod, nn.BatchNorm2d))
    for bn in bn_layers:
        bn.momentum = momentum

def fix_bn(model):
    """Put every nn.BatchNorm2d in ``model`` into eval mode; other layers keep their mode.

    Only the train/eval flag changes; parameters and ``requires_grad`` are not touched.
    """
    bn_layers = filter(lambda mod: isinstance(mod, nn.BatchNorm2d), model.modules())
    for bn in bn_layers:
        bn.eval()

其他的可以参考:

https://blog.csdn.net/MrR1ght/article/details/105246412
model.children():返回模型直接子模块的迭代器(只有一层,不递归)
model.modules():返回模型所有模块的迭代器(递归遍历,不仅包含子模块,还包含当前模块本身)
model.named_children():返回直接子模块的迭代器,每一项是 (名字, 模块)
model.named_modules():返回所有模块(递归,包含当前模块本身)的迭代器,每一项是 (名字, 模块),当前模块本身的名字是空字符串 ''

model.parameters() 通常作为 torch.optim.SGD 的 params 参数传入:torch.optim.SGD(params, lr=<学习率>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

参数:

params (iterable) – 待优化参数的 iterable,或者由定义参数组的 dict 组成的 iterable
lr (float) – 学习率
momentum (float, 可选) – 动量因子(默认:0)
weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认:0)
dampening (float, 可选) – 动量的抑制因子(默认:0)
nesterov (bool, 可选) – 使用Nesterov动量(默认:False)

例子:

# loss_fn, input and target are placeholders; they are not defined in these notes.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()  # clear gradients accumulated by any previous backward()
loss = loss_fn(model(input), target)
loss.backward()
optimizer.step()

这里对model.parameters()比较好奇
于是我打印:

param_gen = model.parameters()
print(param_gen)  # prints the generator object's repr, not the tensors it yields

打印出这玩意:
<generator object Module.parameters at 0x7f1d2272d728>
它其实是一个生成器(generator)对象,而不是指针;需要迭代它才能拿到参数,于是我再这样打印:

# Drain the generator so each Parameter is passed to print() as its own argument
all_params = list(model.parameters())
print(*all_params)

这回输出了一大串数字,部分如下:

Parameter containing:
tensor([[[[-0.1751,  0.1829,  0.1973,  0.0780,  0.1220],
          [-0.0497,  0.0943,  0.0827,  0.1829,  0.0239],
          [-0.1044,  0.1268,  0.0716, -0.0100,  0.1991],
          [-0.0730,  0.1762, -0.0787,  0.0686, -0.0069],
          [ 0.1316,  0.0897, -0.1068,  0.0744,  0.0524]]],

        [[[-0.1034, -0.1946, -0.1312,  0.1076,  0.0129],
          [ 0.0450,  0.0552,  0.1448, -0.1283, -0.1868],
          [-0.0260, -0.1928,  0.0519, -0.0493, -0.1028],
          [-0.0936,  0.1719, -0.0997,  0.0008,  0.0871],
          [ 0.0995, -0.1274,  0.0388,  0.0779,  0.0006]]],

        [[[ 0.1846, -0.0723,  0.0649, -0.0169, -0.1595],
          [ 0.0145, -0.1893,  0.0784, -0.0886, -0.0044],
          [ 0.1914, -0.1009, -0.0736, -0.0992, -0.1618],
          [-0.0291,  0.0997,  0.0549,  0.1267, -0.1661],
          [-0.1333,  0.0168,  0.0648,  0.1047, -0.1506]]],
            ...
     -4.0503e-03,  9.4014e-02, -8.5686e-02,  7.7082e-02]],
       requires_grad=True) Parameter containing:
tensor([-0.0106,  0.0448, -0.0001, -0.0914, -0.0310, -0.0628,  0.0899, -0.0047,
        -0.0390, -0.0291], requires_grad=True)

自定义参数组(为不同子模块设置不同的学习率)

# Two parameter groups: backbone at 0.1 * opts.lr, classifier at opts.lr.
param_groups = [
    {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
    {'params': model.classifier.parameters(), 'lr': opts.lr},
]
optimizer = torch.optim.SGD(
    params=param_groups,
    lr=opts.lr,
    momentum=0.9,
    weight_decay=opts.weight_decay,
)

还看到另外的写法:

    def get_1x_lr_params(self):
        """Yield trainable parameters of selected layer types inside ``self.backbone``.

        Every module of ``self.backbone`` is visited recursively (the backbone
        itself included). Parameters with ``requires_grad=True`` are yielded from:

        * ``nn.Conv2d`` layers only, when ``self.freeze_bn`` is truthy;
        * ``nn.Conv2d``, ``SynchronizedBatchNorm2d`` and ``nn.BatchNorm2d``
          layers otherwise.

        Parameters of any other layer type are not yielded.
        """
        for module in self.backbone.modules():
            if self.freeze_bn:
                wanted = isinstance(module, nn.Conv2d)
            else:
                # Conv2d is tested on its own first so SynchronizedBatchNorm2d is
                # only looked up for non-conv modules (same short-circuit order).
                wanted = isinstance(module, nn.Conv2d) or isinstance(
                    module, (SynchronizedBatchNorm2d, nn.BatchNorm2d))
            if wanted:
                yield from (p for p in module.parameters() if p.requires_grad)

    def get_10x_lr_params(self):
        """Yield trainable parameters of selected layer types in ``self.aspp`` and ``self.decoder``.

        Every module of ``self.aspp`` and then ``self.decoder`` is visited
        recursively (each root included). Parameters with ``requires_grad=True``
        are yielded from:

        * ``nn.Conv2d`` layers only, when ``self.freeze_bn`` is truthy;
        * ``nn.Conv2d``, ``SynchronizedBatchNorm2d`` and ``nn.BatchNorm2d``
          layers otherwise.

        Parameters of any other layer type are not yielded.
        """
        for root in (self.aspp, self.decoder):
            for module in root.modules():
                if self.freeze_bn:
                    wanted = isinstance(module, nn.Conv2d)
                else:
                    # Conv2d is tested on its own first so SynchronizedBatchNorm2d is
                    # only looked up for non-conv modules (same short-circuit order).
                    wanted = isinstance(module, nn.Conv2d) or isinstance(
                        module, (SynchronizedBatchNorm2d, nn.BatchNorm2d))
                if wanted:
                    yield from (p for p in module.parameters() if p.requires_grad)

# get_1x_lr_params() trains at args.lr; get_10x_lr_params() at ten times that.
base_lr = args.lr
train_params = [
    {'params': model.get_1x_lr_params(), 'lr': base_lr},
    {'params': model.get_10x_lr_params(), 'lr': base_lr * 10},
]

# Define Optimizer
optimizer = torch.optim.SGD(
    train_params,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
    nesterov=args.nesterov,
)

打印网络总参数量

 # Print the shape and element count of every Parameter tensor, then the total.
 # Each entry is one tensor (e.g. a weight or a bias), not a whole layer.
 # Buffers such as BatchNorm running statistics are not included by parameters().
 params = list(model.parameters())
 total = 0
 for p in params:
     n = p.numel()  # product of all dimensions (1 for a 0-dim tensor)
     print("该层的结构:" + str(list(p.size())))
     print("该层参数和:" + str(n))
     total += n
 print("总参数数量和:" + str(total))

打印如下:

该层参数和:256
该层的结构:[256, 2048, 3, 3]
该层参数和:4718592
该层的结构:[256]
该层参数和:256
该层的结构:[256]
该层参数和:256
该层的结构:[256, 2048, 1, 1]
该层参数和:524288
该层的结构:[256]
该层参数和:256
该层的结构:[256]
该层参数和:256
该层的结构:[256, 1280, 1, 1]
该层参数和:327680
该层的结构:[256]
该层参数和:256
该层的结构:[256]
该层参数和:256
该层的结构:[256, 304, 3, 3]
该层参数和:700416
该层的结构:[256]
该层参数和:256
该层的结构:[256]
该层参数和:256
该层的结构:[26, 256, 1, 1]
该层参数和:6656
该层的结构:[26]
该层参数和:26
总参数数量和:58755258

net.parameters() net.named_parameters() 显示网络参数

# Print every Parameter tensor (values included) in registration order
for param in net.parameters():
    print(param)

输出如下:

Parameter containing:
tensor([[[[-0.0104, -0.0555,  0.1417],
          [-0.3281, -0.0367,  0.0208],
          [-0.0894, -0.0511, -0.1253]]],


        [[[-0.1724,  0.2141, -0.0895],
          [ 0.0116,  0.1661, -0.1853],
          [-0.1190,  0.1292, -0.2451]]],

2

# Dotted parameter name followed by its shape (Tensor.shape is the same torch.Size as .size())
for param_name, param in net.named_parameters():
    print(param_name, ':', param.shape)

输出如下:

module.backbone.conv1.weight : torch.Size([64, 3, 7, 7])
module.backbone.bn1.weight : torch.Size([64])
module.backbone.bn1.bias : torch.Size([64])
module.backbone.layer1.0.conv1.weight : torch.Size([64, 64, 1, 1])
module.backbone.layer1.0.bn1.weight : torch.Size([64])
module.backbone.layer1.0.bn1.bias : torch.Size([64])
module.backbone.layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
module.backbone.layer1.0.bn2.weight : torch.Size([64])
module.backbone.layer1.0.bn2.bias : torch.Size([64])
module.backbone.layer1.0.conv3.weight : torch.Size([256, 64, 1, 1])
module.backbone.layer1.0.bn3.weight : torch.Size([256])
module.backbone.layer1.0.bn3.bias : torch.Size([256])
module.backbone.layer1.0.downsample.0.weight : torch.Size([256, 64, 1, 1])
module.backbone.layer1.0.downsample.1.weight : torch.Size([256])
module.backbone.layer1.0.downsample.1.bias : torch.Size([256])
module.backbone.layer1.1.conv1.weight : torch.Size([64, 256, 1, 1])
module.backbone.layer1.1.bn1.weight : torch.Size([64])
module.backbone.layer1.1.bn1.bias : torch.Size([64])
module.backbone.layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
module.backbone.layer1.1.bn2.weight : torch.Size([64])
module.backbone.layer1.1.bn2.bias : torch.Size([64])
module.backbone.layer1.1.conv3.weight : torch.Size([256, 64, 1, 1])
module.backbone.layer1.1.bn3.weight : torch.Size([256])
module.backbone.layer1.1.bn3.bias : torch.Size([256])
module.backbone.layer1.2.conv1.weight : torch.Size([64, 256, 1, 1])
module.backbone.layer1.2.bn1.weight : torch.Size([64])
module.backbone.layer1.2.bn1.bias : torch.Size([64])
module.backbone.layer1.2.conv2.weight : torch.Size([64, 64, 3, 3])
module.backbone.layer1.2.bn2.weight : torch.Size([64])
module.backbone.layer1.2.bn2.bias : torch.Size([64])
module.backbone.layer1.2.conv3.weight : torch.Size([256, 64, 1, 1])
module.backbone.layer1.2.bn3.weight : torch.Size([256])
module.backbone.layer1.2.bn3.bias : torch.Size([256])
module.backbone.layer2.0.conv1.weight : torch.Size([128, 256, 1, 1])
module.backbone.layer2.0.bn1.weight : torch.Size([128])
module.backbone.layer2.0.bn1.bias : torch.Size([128])
module.backbone.layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
module.backbone.layer2.0.bn2.weight : torch.Size([128])
module.backbone.layer2.0.bn2.bias : torch.Size([128])
module.backbone.layer2.0.conv3.weight : torch.Size([512, 128, 1, 1])
module.backbone.layer2.0.bn3.weight : torch.Size([512])
module.backbone.layer2.0.bn3.bias : torch.Size([512])

2021.01.21补充例子:

if verbose:
    # One row per Parameter: index, name, requires_grad, element count, shape, mean, std.
    print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
    for i, (name, p) in enumerate(model.named_parameters()):
        name = name.replace('module_list.', '')
        # Integer conversions (%d) so large counts print exactly (1179648),
        # not in rounded exponent form (1.17965e+06) as %g did.
        print('%5d %40s %9s %12d %20s %10.3g %10.3g' %
              (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

输出如下:

layer                                     name  gradient   parameters                shape         mu      sigma
    0                          0.Conv2d.weight      True          864        [32, 3, 3, 3]  -8.67e-05      0.112
    1                     0.BatchNorm2d.weight      True           32                 [32]          1          0
    2                       0.BatchNorm2d.bias      True           32                 [32]          0          0
    3                          1.Conv2d.weight      True        18432       [64, 32, 3, 3]   0.000242      0.034
    4                     1.BatchNorm2d.weight      True           64                 [64]          1          0
    5                       1.BatchNorm2d.bias      True           64                 [64]          0          0
    6                          2.Conv2d.weight      True         2048       [32, 64, 1, 1]   0.000321      0.072
    7                     2.BatchNorm2d.weight      True           32                 [32]          1          0
    8                       2.BatchNorm2d.bias      True           32                 [32]          0          0
    9                          3.Conv2d.weight      True        18432       [64, 32, 3, 3]   0.000188     0.0339
   10                     3.BatchNorm2d.weight      True           64                 [64]          1          0
   11                       3.BatchNorm2d.bias      True           64                 [64]          0          0
   12                          5.Conv2d.weight      True        73728      [128, 64, 3, 3]  -2.89e-05     0.0241
   13                     5.BatchNorm2d.weight      True          128                [128]          1          0
   14                       5.BatchNorm2d.bias      True          128                [128]          0          0
   15                          6.Conv2d.weight      True         8192      [64, 128, 1, 1]   0.000717     0.0513
   16                     6.BatchNorm2d.weight      True           64                 [64]          1          0
   17                       6.BatchNorm2d.bias      True           64                 [64]          0          0
   18                          7.Conv2d.weight      True        73728      [128, 64, 3, 3]   0.000112     0.0241
   19                     7.BatchNorm2d.weight      True          128                [128]          1          0
   20                       7.BatchNorm2d.bias      True          128                [128]          0          0
   21                          9.Conv2d.weight      True         8192      [64, 128, 1, 1]  -0.000266      0.051
   22                     9.BatchNorm2d.weight      True           64                 [64]          1          0
   23                       9.BatchNorm2d.bias      True           64                 [64]          0          0
   24                         10.Conv2d.weight      True        73728      [128, 64, 3, 3]    -0.0002     0.0241
   25                    10.BatchNorm2d.weight      True          128                [128]          1          0
   26                      10.BatchNorm2d.bias      True          128                [128]          0          0
   27                         12.Conv2d.weight      True       294912     [256, 128, 3, 3]   4.95e-05      0.017
   28                    12.BatchNorm2d.weight      True          256                [256]          1          0
   29                      12.BatchNorm2d.bias      True          256                [256]          0          0
   30                         13.Conv2d.weight      True        32768     [128, 256, 1, 1]   -9.5e-05      0.036
   31                    13.BatchNorm2d.weight      True          128                [128]          1          0
   32                      13.BatchNorm2d.bias      True          128                [128]          0          0
   33                         14.Conv2d.weight      True       294912     [256, 128, 3, 3]  -2.46e-06      0.017
   34                    14.BatchNorm2d.weight      True          256                [256]          1          0
   35                      14.BatchNorm2d.bias      True          256                [256]          0          0
   36                         16.Conv2d.weight      True        32768     [128, 256, 1, 1]   0.000225     0.0359
   37                    16.BatchNorm2d.weight      True          128                [128]          1          0
   38                      16.BatchNorm2d.bias      True          128                [128]          0          0
   39                         17.Conv2d.weight      True       294912     [256, 128, 3, 3]  -5.76e-05      0.017
   40                    17.BatchNorm2d.weight      True          256                [256]          1          0
   41                      17.BatchNorm2d.bias      True          256                [256]          0          0
   42                         19.Conv2d.weight      True        32768     [128, 256, 1, 1]  -6.58e-05      0.036
   43                    19.BatchNorm2d.weight      True          128                [128]          1          0
   44                      19.BatchNorm2d.bias      True          128                [128]          0          0
   45                         20.Conv2d.weight      True       294912     [256, 128, 3, 3]  -1.72e-06      0.017
   46                    20.BatchNorm2d.weight      True          256                [256]          1          0
   47                      20.BatchNorm2d.bias      True          256                [256]          0          0
   48                         22.Conv2d.weight      True        32768     [128, 256, 1, 1]   0.000157      0.036
   49                    22.BatchNorm2d.weight      True          128                [128]          1          0
   50                      22.BatchNorm2d.bias      True          128                [128]          0          0
   51                         23.Conv2d.weight      True       294912     [256, 128, 3, 3]   2.92e-05      0.017
   52                    23.BatchNorm2d.weight      True          256                [256]          1          0
   53                      23.BatchNorm2d.bias      True          256                [256]          0          0
   54                         25.Conv2d.weight      True        32768     [128, 256, 1, 1]   0.000226     0.0361
   55                    25.BatchNorm2d.weight      True          128                [128]          1          0
   56                      25.BatchNorm2d.bias      True          128                [128]          0          0
   57                         26.Conv2d.weight      True       294912     [256, 128, 3, 3]   1.74e-05      0.017
   58                    26.BatchNorm2d.weight      True          256                [256]          1          0
   59                      26.BatchNorm2d.bias      True          256                [256]          0          0
   60                         28.Conv2d.weight      True        32768     [128, 256, 1, 1]   0.000182      0.036
   61                    28.BatchNorm2d.weight      True          128                [128]          1          0
   62                      28.BatchNorm2d.bias      True          128                [128]          0          0
   63                         29.Conv2d.weight      True       294912     [256, 128, 3, 3]   5.26e-07      0.017
   64                    29.BatchNorm2d.weight      True          256                [256]          1          0
   65                      29.BatchNorm2d.bias      True          256                [256]          0          0
   66                         31.Conv2d.weight      True        32768     [128, 256, 1, 1]  -0.000297     0.0361
   67                    31.BatchNorm2d.weight      True          128                [128]          1          0
   68                      31.BatchNorm2d.bias      True          128                [128]          0          0
   69                         32.Conv2d.weight      True       294912     [256, 128, 3, 3]   4.21e-05      0.017
   70                    32.BatchNorm2d.weight      True          256                [256]          1          0
   71                      32.BatchNorm2d.bias      True          256                [256]          0          0
   72                         34.Conv2d.weight      True        32768     [128, 256, 1, 1]   2.84e-05      0.036
   73                    34.BatchNorm2d.weight      True          128                [128]          1          0
   74                      34.BatchNorm2d.bias      True          128                [128]          0          0
   75                         35.Conv2d.weight      True       294912     [256, 128, 3, 3]  -4.58e-05      0.017
   76                    35.BatchNorm2d.weight      True          256                [256]          1          0
   77                      35.BatchNorm2d.bias      True          256                [256]          0          0
   78                         37.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]   2.59e-06      0.012
   79                    37.BatchNorm2d.weight      True          512                [512]          1          0
   80                      37.BatchNorm2d.bias      True          512                [512]          0          0
   81                         38.Conv2d.weight      True       131072     [256, 512, 1, 1]  -2.42e-05     0.0255
   82                    38.BatchNorm2d.weight      True          256                [256]          1          0
   83                      38.BatchNorm2d.bias      True          256                [256]          0          0
   84                         39.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]   6.23e-06      0.012
   85                    39.BatchNorm2d.weight      True          512                [512]          1          0
   86                      39.BatchNorm2d.bias      True          512                [512]          0          0
   87                         41.Conv2d.weight      True       131072     [256, 512, 1, 1]    3.8e-05     0.0255
   88                    41.BatchNorm2d.weight      True          256                [256]          1          0
   89                      41.BatchNorm2d.bias      True          256                [256]          0          0
   90                         42.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]  -1.15e-05      0.012
   91                    42.BatchNorm2d.weight      True          512                [512]          1          0
   92                      42.BatchNorm2d.bias      True          512                [512]          0          0
   93                         44.Conv2d.weight      True       131072     [256, 512, 1, 1]   1.25e-07     0.0254
   94                    44.BatchNorm2d.weight      True          256                [256]          1          0
   95                      44.BatchNorm2d.bias      True          256                [256]          0          0
   96                         45.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]  -1.02e-05      0.012
   97                    45.BatchNorm2d.weight      True          512                [512]          1          0
   98                      45.BatchNorm2d.bias      True          512                [512]          0          0
   99                         47.Conv2d.weight      True       131072     [256, 512, 1, 1]    0.00018     0.0255
  100                    47.BatchNorm2d.weight      True          256                [256]          1          0
  101                      47.BatchNorm2d.bias      True          256                [256]          0          0
  102                         48.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]  -1.22e-05      0.012
  103                    48.BatchNorm2d.weight      True          512                [512]          1          0
  104                      48.BatchNorm2d.bias      True          512                [512]          0          0
  105                         50.Conv2d.weight      True       131072     [256, 512, 1, 1]  -2.25e-05     0.0255
  106                    50.BatchNorm2d.weight      True          256                [256]          1          0
  107                      50.BatchNorm2d.bias      True          256                [256]          0          0
  108                         51.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]   6.82e-06      0.012
  109                    51.BatchNorm2d.weight      True          512                [512]          1          0
  110                      51.BatchNorm2d.bias      True          512                [512]          0          0
  111                         53.Conv2d.weight      True       131072     [256, 512, 1, 1]   -6.9e-05     0.0255
  112                    53.BatchNorm2d.weight      True          256                [256]          1          0
  113                      53.BatchNorm2d.bias      True          256                [256]          0          0
  114                         54.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]   1.89e-06      0.012
  115                    54.BatchNorm2d.weight      True          512                [512]          1          0
  116                      54.BatchNorm2d.bias      True          512                [512]          0          0
  117                         56.Conv2d.weight      True       131072     [256, 512, 1, 1]    0.00015     0.0255
  118                    56.BatchNorm2d.weight      True          256                [256]          1          0
  119                      56.BatchNorm2d.bias      True          256                [256]          0          0
  120                         57.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]   2.61e-05      0.012
  121                    57.BatchNorm2d.weight      True          512                [512]          1          0
  122                      57.BatchNorm2d.bias      True          512                [512]          0          0
  123                         59.Conv2d.weight      True       131072     [256, 512, 1, 1]  -0.000128     0.0256
  124                    59.BatchNorm2d.weight      True          256                [256]          1          0
  125                      59.BatchNorm2d.bias      True          256                [256]          0          0
  126                         60.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]  -1.97e-06      0.012
  127                    60.BatchNorm2d.weight      True          512                [512]          1          0
  128                      60.BatchNorm2d.bias      True          512                [512]          0          0
  129                         62.Conv2d.weight      True  4.71859e+06    [1024, 512, 3, 3]   1.53e-06     0.0085
  130                    62.BatchNorm2d.weight      True         1024               [1024]          1          0
  131                      62.BatchNorm2d.bias      True         1024               [1024]          0          0
  132                         63.Conv2d.weight      True       524288    [512, 1024, 1, 1]  -1.84e-05     0.0181
  133                    63.BatchNorm2d.weight      True          512                [512]          1          0
  134                      63.BatchNorm2d.bias      True          512                [512]          0          0
  135                         64.Conv2d.weight      True  4.71859e+06    [1024, 512, 3, 3]   2.17e-07     0.0085
  136                    64.BatchNorm2d.weight      True         1024               [1024]          1          0
  137                      64.BatchNorm2d.bias      True         1024               [1024]          0          0
  138                         66.Conv2d.weight      True       524288    [512, 1024, 1, 1]   2.39e-05      0.018
  139                    66.BatchNorm2d.weight      True          512                [512]          1          0
  140                      66.BatchNorm2d.bias      True          512                [512]          0          0
  141                         67.Conv2d.weight      True  4.71859e+06    [1024, 512, 3, 3]  -1.41e-06    0.00851
  142                    67.BatchNorm2d.weight      True         1024               [1024]          1          0
  143                      67.BatchNorm2d.bias      True         1024               [1024]          0          0
  144                         69.Conv2d.weight      True       524288    [512, 1024, 1, 1]  -1.94e-05      0.018
  145                    69.BatchNorm2d.weight      True          512                [512]          1          0
  146                      69.BatchNorm2d.bias      True          512                [512]          0          0
  147                         70.Conv2d.weight      True  4.71859e+06    [1024, 512, 3, 3]   1.07e-06    0.00851
  148                    70.BatchNorm2d.weight      True         1024               [1024]          1          0
  149                      70.BatchNorm2d.bias      True         1024               [1024]          0          0
  150                         72.Conv2d.weight      True       524288    [512, 1024, 1, 1]   3.62e-05     0.0181
  151                    72.BatchNorm2d.weight      True          512                [512]          1          0
  152                      72.BatchNorm2d.bias      True          512                [512]          0          0
  153                         73.Conv2d.weight      True  4.71859e+06    [1024, 512, 3, 3]   4.51e-06     0.0085
  154                    73.BatchNorm2d.weight      True         1024               [1024]          1          0
  155                      73.BatchNorm2d.bias      True         1024               [1024]          0          0
  156                         75.Conv2d.weight      True       524288    [512, 1024, 1, 1]   2.73e-05      0.018
  157                    75.BatchNorm2d.weight      True          512                [512]          1          0
  158                      75.BatchNorm2d.bias      True          512                [512]          0          0
  159                         76.Conv2d.weight      True  4.71859e+06    [1024, 512, 3, 3]   2.64e-06     0.0085
  160                    76.BatchNorm2d.weight      True         1024               [1024]          1          0
  161                      76.BatchNorm2d.bias      True         1024               [1024]          0          0
  162                         77.Conv2d.weight      True       524288    [512, 1024, 1, 1]  -3.97e-05      0.018
  163                    77.BatchNorm2d.weight      True          512                [512]          1          0
  164                      77.BatchNorm2d.bias      True          512                [512]          0          0
  165                         78.Conv2d.weight      True  4.71859e+06    [1024, 512, 3, 3]  -7.67e-08    0.00851
  166                    78.BatchNorm2d.weight      True         1024               [1024]          1          0
  167                      78.BatchNorm2d.bias      True         1024               [1024]          0          0
  168                         79.Conv2d.weight      True       524288    [512, 1024, 1, 1]  -3.87e-05      0.018
  169                    79.BatchNorm2d.weight      True          512                [512]          1          0
  170                      79.BatchNorm2d.bias      True          512                [512]          0          0
  171                         80.Conv2d.weight      True  4.71859e+06    [1024, 512, 3, 3]  -3.03e-06    0.00851
  172                    80.BatchNorm2d.weight      True         1024               [1024]          1          0
  173                      80.BatchNorm2d.bias      True         1024               [1024]          0          0
  174                         81.Conv2d.weight      True        76800     [75, 1024, 1, 1]  -7.83e-06     0.0181
  175                           81.Conv2d.bias      True           75                 [75]      -2.94       1.31
  176                         84.Conv2d.weight      True       131072     [256, 512, 1, 1]  -7.19e-05     0.0255
  177                    84.BatchNorm2d.weight      True          256                [256]          1          0
  178                      84.BatchNorm2d.bias      True          256                [256]          0          0
  179                         87.Conv2d.weight      True       196608     [256, 768, 1, 1]  -7.45e-06     0.0208
  180                    87.BatchNorm2d.weight      True          256                [256]          1          0
  181                      87.BatchNorm2d.bias      True          256                [256]          0          0
  182                         88.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]  -4.27e-06      0.012
  183                    88.BatchNorm2d.weight      True          512                [512]          1          0
  184                      88.BatchNorm2d.bias      True          512                [512]          0          0
  185                         89.Conv2d.weight      True       131072     [256, 512, 1, 1]   4.33e-05     0.0255
  186                    89.BatchNorm2d.weight      True          256                [256]          1          0
  187                      89.BatchNorm2d.bias      True          256                [256]          0          0
  188                         90.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]   2.65e-06      0.012
  189                    90.BatchNorm2d.weight      True          512                [512]          1          0
  190                      90.BatchNorm2d.bias      True          512                [512]          0          0
  191                         91.Conv2d.weight      True       131072     [256, 512, 1, 1]  -4.59e-05     0.0255
  192                    91.BatchNorm2d.weight      True          256                [256]          1          0
  193                      91.BatchNorm2d.bias      True          256                [256]          0          0
  194                         92.Conv2d.weight      True  1.17965e+06     [512, 256, 3, 3]  -3.91e-06      0.012
  195                    92.BatchNorm2d.weight      True          512                [512]          1          0
  196                      92.BatchNorm2d.bias      True          512                [512]          0          0
  197                         93.Conv2d.weight      True        38400      [75, 512, 1, 1]   0.000177     0.0255
  198                           93.Conv2d.bias      True           75                 [75]      -2.94        1.3
  199                         96.Conv2d.weight      True        32768     [128, 256, 1, 1]  -0.000275      0.036
  200                    96.BatchNorm2d.weight      True          128                [128]          1          0
  201                      96.BatchNorm2d.bias      True          128                [128]          0          0
  202                         99.Conv2d.weight      True        49152     [128, 384, 1, 1]   9.75e-05     0.0295
  203                    99.BatchNorm2d.weight      True          128                [128]          1          0
  204                      99.BatchNorm2d.bias      True          128                [128]          0          0
  205                        100.Conv2d.weight      True       294912     [256, 128, 3, 3]   -1.5e-06      0.017
  206                   100.BatchNorm2d.weight      True          256                [256]          1          0
  207                     100.BatchNorm2d.bias      True          256                [256]          0          0
  208                        101.Conv2d.weight      True        32768     [128, 256, 1, 1]  -6.49e-05      0.036
  209                   101.BatchNorm2d.weight      True          128                [128]          1          0
  210                     101.BatchNorm2d.bias      True          128                [128]          0          0
  211                        102.Conv2d.weight      True       294912     [256, 128, 3, 3]   1.01e-05      0.017
  212                   102.BatchNorm2d.weight      True          256                [256]          1          0
  213                     102.BatchNorm2d.bias      True          256                [256]          0          0
  214                        103.Conv2d.weight      True        32768     [128, 256, 1, 1]   0.000229      0.036
  215                   103.BatchNorm2d.weight      True          128                [128]          1          0
  216                     103.BatchNorm2d.bias      True          128                [128]          0          0
  217                        104.Conv2d.weight      True       294912     [256, 128, 3, 3]  -1.62e-05      0.017
  218                   104.BatchNorm2d.weight      True          256                [256]          1          0
  219                     104.BatchNorm2d.bias      True          256                [256]          0          0
  220                        105.Conv2d.weight      True        19200      [75, 256, 1, 1]    0.00016     0.0361
  221                          105.Conv2d.bias      True           75                 [75]      -2.94       1.32

# Total number of scalar parameters, and how many of them are trainable.
n_p = sum(x.numel() for x in model.parameters())  # number of parameters
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number with gradients
# model.parameters() returns a one-shot generator; list() materializes it so it
# can be indexed, measured with len() or iterated more than once.
aa = model.parameters()
bb = list(model.parameters())
# "layers" here is really the number of parameter tensors: each weight and each
# bias counts separately (the table above has one row per tensor).
print('Model Summary: %g layers, %g parameters, %g gradients' % (len(bb), n_p, n_g))

aa 是一个生成器(generator),类似指针/迭代器,只能遍历一次;需要多次使用或取长度时,用 list() 转成列表(如 bb)。

输出如下:
Model Summary: 222 layers, 6.1626e+07 parameters, 6.1626e+07 gradients

加载权重时建议写成 checkpoint = torch.load(model_path, map_location='cpu')。不加 map_location='cpu' 时,权重会先被加载到保存时所在的 GPU 上,再由 load_state_dict 复制进模型参数,显存占用会变成约 2 倍!

# Deserialize onto the CPU: without map_location the tensors are restored onto the
# device they were saved from, and load_state_dict then copies them into the model.
checkpoint = torch.load(model_path, map_location='cpu')
weights = checkpoint['state_dict']
# strict=False: missing or unexpected keys are ignored instead of raising.
model.load_state_dict(weights, strict=False)

修改类别数后还需要继续 finetune(微调)模型。一般只有最后一层的形状会因类别数不同而对不上,因此加载权重时跳过与类别数有关的层即可:

例子1

model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG.DATASET.N_CLASSES)
state_dict = torch.load(path_model)

import collections


def _rename_key(key):
    """Return the name ``model.base`` is expected to use for checkpoint ``key``.

    Every ``'base.'`` substring is removed. Keys containing ``'aspp'`` also get
    a ``'_2'`` suffix -- presumably so the ASPP weights match nothing in
    ``model.base`` and are skipped by the ``strict=False`` load; confirm
    against the model's parameter names.
    """
    stripped = key.replace('base.', '')
    if 'aspp' in stripped:
        return stripped + '_2'
    return stripped


new_state_dict = collections.OrderedDict(
    (_rename_key(key), tensor) for key, tensor in state_dict.items()
)

# NOTE(review): logs CONFIG.MODEL.INIT_MODEL although the weights were read from
# path_model -- confirm both refer to the same file.
print("    Init:", CONFIG.MODEL.INIT_MODEL)
# Report every model.base entry the remapped checkpoint does not provide.
for param_name in model.base.state_dict().keys():
    if param_name not in new_state_dict:
        print("    Skip init:", param_name)
model.base.load_state_dict(new_state_dict, strict=False)

例子2



# Load onto the CPU first; load_state_dict copies the values into `net` afterwards.
pretrained_model = torch.load(os.path.join(model_dir, '{}.pth'.format(pth)), map_location='cpu')
# net.load_state_dict(pretrained_model['net'], strict=strict)

print("#######################################################################################################")
# Parameter names (and shapes) that `net` expects.
for name, parameters in net.named_parameters():
    print(name, ':', parameters.size())

# Checkpoint keys carry a "module.net." prefix that `net`'s own names lack.
# Strip it only where present: blindly dropping the first 11 characters would
# mangle any key that does not have the prefix.
prefix = "module.net."
d = OrderedDict()
for key, value in pretrained_model['net'].items():
    tmp = key[len(prefix):] if key.startswith(prefix) else key
    d[tmp] = value

net.load_state_dict(d, strict=strict)
print("#######################################################################################################")

当 pth 中的参数名与 model 不一致、导致无法加载权重时,可以先用 new_state_dict = collections.OrderedDict() 重新映射参数名,再用 model.base.load_state_dict(new_state_dict, strict=False) 加载。strict=False 会忽略对不上的键,从而跳过 ASPP(# to skip ASPP):

    model = DeepLabV3Plus_ResNet101_MSC(n_classes=CONFIG.DATASET.N_CLASSES)
    state_dict = torch.load(CONFIG.MODEL.INIT_MODEL)

    # Remap checkpoint keys onto model.base: drop every 'base.' and append '_2'
    # to ASPP keys so those weights find no match and are not loaded.
    import collections
    new_state_dict = collections.OrderedDict()
    for ckpt_key, weight in state_dict.items():
        base_key = ckpt_key.replace('base.', '')
        new_state_dict[(base_key + '_2') if 'aspp' in base_key else base_key] = weight

    print("    Init:", CONFIG.MODEL.INIT_MODEL)
    # List the model.base entries that the remapped checkpoint does not cover.
    uncovered = (k for k in model.base.state_dict().keys() if k not in new_state_dict)
    for k in uncovered:
        print("    Skip init:", k)
    # strict=False ignores missing/unexpected keys; this is what skips ASPP.
    model.base.load_state_dict(new_state_dict, strict=False)  # to skip ASPP
    model = nn.DataParallel(model)
    model.to(device)

分别打印网络的参数名和 pth 中保存的参数名,检查两者是否一一对应。常见的情况是其中一方多出前缀“module.”(用 nn.DataParallel 包装后的模型,参数名会带这个前缀)。

model = crnn.CRNN(32, 1, nclass, 256)  # was: crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    model = model.cuda()

# Names the network expects ...
for key in model.state_dict():
    print("==:: ", key)

# ... versus names/shapes stored in the checkpoint (assumed to be a plain state_dict).
load_model_ = torch.load(model_path)
for key, tensor in load_model_.items():
    print(key, "  ::shape", tensor.shape)

加载模型时,如果网络 net 的参数名带有前缀“module.”,而保存的 pth 中的参数名没有,就需要给 pth 的每个键加上这个前缀:

        state = torch.load(resume_from, map_location=to_use_device)
        import collections
        # `_model`'s parameter names start with "module." while the checkpoint's
        # may not. Prefix every key that lacks it; keys that already carry it are
        # kept as-is instead of becoming "module.module.*".
        d = collections.OrderedDict()
        for key, value in state['state_dict'].items():
            d[key if key.startswith("module.") else "module." + key] = value

        _model.load_state_dict(d)

优化器optimizer设置

https://github.com/wuzuowuyou/DeepLabV3Plus-Pytorch/blob/master/main.py

# Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy=='poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy=='step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
# optimizer, for model just includes pretrained, head and auxlayer
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({'params': self.model.pretrained.parameters(), 'lr': args.lr})
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({'params': getattr(self.model, module).parameters(), 'lr': args.lr * 10})
        self.optimizer = torch.optim.SGD(params_list,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

冻结部分层参数

### //不训练某些层
  frozen_layers = [net.cnn, net.rnn, net.layer0, net.layer0_1]
  for layer in frozen_layers:
    for name, value in layer.named_parameters():
      value.requires_grad = False
  params = filter(lambda p: p.requires_grad, net.parameters())
def get_fine_tune_params(net, finetune_stage):
    """
    Freeze every stage of ``net`` that is not being fine-tuned.

    The model is split into the stages ``'backbone'``, ``'neck'`` and ``'head'``.
    Every parameter of a stage whose name is not in ``finetune_stage`` gets
    ``requires_grad = False``. Stages listed in ``finetune_stage`` are left
    untouched (they are not re-enabled).

    Args:
        net: the network. Either a wrapper exposing the real model as
            ``net.module`` (e.g. ``nn.DataParallel``) or the bare model itself.
            The model must have ``backbone``, ``neck`` and ``head`` attributes.
        finetune_stage: container of stage names to keep trainable.

    Returns:
        List of the parameters of ``net`` that still require gradients, i.e.
        the parameters to hand to the optimizer.
    """
    # Unwrap DataParallel-style wrappers; fall back to ``net`` when there is none.
    model = getattr(net, 'module', net)

    for stage_name in ('backbone', 'neck', 'head'):
        if stage_name in finetune_stage:
            continue
        # getattr performs the same attribute lookup as the old eval() without executing code.
        for param in getattr(model, stage_name).parameters():
            param.requires_grad = False

    return [p for p in net.parameters() if p.requires_grad]
get_fine_tune_params(net, train_options['fine_tune_stage'])
# ===> solver and lr scheduler: only parameters that are still trainable go to the optimizer
trainable = (p for p in net.parameters() if p.requires_grad)
optimizer = build_optimizer(trainable, cfg['optimizer'])

检查是否冻结成功

    # Names of the parameters that are still trainable; anything not listed is frozen.
    trainable_names = [name for name, param in net.named_parameters() if param.requires_grad]
    for name in trainable_names:
        print(name)
posted @ 2020-08-20 15:29  无左无右  阅读(1126)  评论(0编辑  收藏  举报