PyTorch model notes
- Network definition
- model.named_children() returns names and modules
- model.modules() can be used for parameter initialization
- Other references
- model.parameters() || torch.optim.SGD(params, lr=<required>, momentum=0, dampening=0, weight_decay=0, nesterov=False)
- Printing the total parameter count
- net.parameters() / net.named_parameters() to display network parameters
- checkpoint = torch.load(model_path, map_location='cpu'): omitting map_location='cpu' can double GPU memory usage!
- Fine-tuning after changing the number of classes: usually only the last layer mismatches the class count, so just skip loading the class-dependent layers
- When pth key names don't match the model and weights fail to load: rebuild the keys into a new collections.OrderedDict() and call model.base.load_state_dict(new_state_dict, strict=False)  # to skip ASPP
- Print the model's layer names and the names in the loaded pth to check that they correspond; sometimes the pth keys carry an extra "module." prefix
- Loading when the network expects the "module." prefix but the saved pth lacks it: add the prefix
- Optimizer setup
- Freezing some layers' parameters
Network definition
import torch
import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        layer1 = nn.Sequential()
        layer1.add_module('conv1', nn.Conv2d(1, 6, 5))
        layer1.add_module('pool1', nn.MaxPool2d(2, 2))
        self.layer1 = layer1
        layer2 = nn.Sequential()
        layer2.add_module('conv2', nn.Conv2d(6, 16, 5))
        layer2.add_module('pool2', nn.MaxPool2d(2, 2))
        self.layer2 = layer2
        layer3 = nn.Sequential()
        layer3.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
        layer3.add_module('fc2', nn.Linear(120, 84))
        layer3.add_module('fc3', nn.Linear(84, 10))
        self.layer3 = layer3

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(x.size(0), -1)  # flatten before entering the fully connected layers
        x = self.layer3(x)
        return x

# sanity-check with dummy data
y = torch.randn(1, 1, 32, 32)
model = LeNet()
out = model(y)
print(model)
print(out)
Output:
LeNet(
(layer1): Sequential(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(layer2): Sequential(
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(layer3): Sequential(
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
)
tensor([[ 0.0211, 0.1407, -0.1831, -0.1182, 0.0221, 0.1467, -0.0523, -0.0663,
-0.0351, -0.0434]], grad_fn=<AddmmBackward>)
model.named_children() returns names and modules
print("*"*50)
for name, module in model.named_children():
print(name)
print(module)
This prints:
layer1
Sequential(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
layer2
Sequential(
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
layer3
Sequential(
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
This can be used inside forward() to pass the input through the children one by one (see the fuller sketch after this snippet):
def forward(self, x):
    for name, module in self.named_children():
        x = module(x)
    return x
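Note that for the LeNet above this loop by itself would fail: it skips the flatten between layer2 and layer3. A sketch (assuming the LeNet defined earlier; the features dict is my own addition) that restores the flatten and also records each stage's output:

def forward(self, x):
    features = {}
    for name, module in self.named_children():
        if name == 'layer3':  # flatten before the fully connected stage
            x = x.view(x.size(0), -1)
        x = module(x)
        features[name] = x  # keep each stage's output for inspection
    return x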
model.modules() can be used for parameter initialization
print("#"*200)
cnt = 0
for name in model.modules():
cnt += 1
print('-------------------------------------------------------cnt=',cnt)
print(name)
Output:
########################################################################################################################################################################################################
-------------------------------------------------------cnt= 1
LeNet(
(layer1): Sequential(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(layer2): Sequential(
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(layer3): Sequential(
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
)
-------------------------------------------------------cnt= 2
Sequential(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
-------------------------------------------------------cnt= 3
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
-------------------------------------------------------cnt= 4
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
-------------------------------------------------------cnt= 5
Sequential(
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
-------------------------------------------------------cnt= 6
Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
-------------------------------------------------------cnt= 7
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
-------------------------------------------------------cnt= 8
Sequential(
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
-------------------------------------------------------cnt= 9
Linear(in_features=400, out_features=120, bias=True)
-------------------------------------------------------cnt= 10
Linear(in_features=120, out_features=84, bias=True)
-------------------------------------------------------cnt= 11
Linear(in_features=84, out_features=10, bias=True)
Using model.modules() for parameter initialization:
cnt = 0
for m in model.modules():
    cnt += 1
    print('-------------------------------------------------------cnt=', cnt)
    print(m)
    if isinstance(m, nn.Conv2d):
        print('------------------isinstance(m, nn.Conv2d)------------------')
        print(m.weight)
        print(m.bias)
        print('--end----------------isinstance(m, nn.Conv2d)------------end------')
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
The parameter portion of the output:
------------------isinstance(m, nn.Conv2d)------------------
Parameter containing:
tensor([[[[-0.1561, -0.0194, -0.0260, -0.0042, 0.1716],
[ 0.1181, -0.1380, -0.0448, 0.0674, -0.1972],
[-0.0197, 0.0359, 0.1186, 0.0876, -0.0395],
[-0.0619, 0.0095, -0.0702, 0.0122, 0.1573],
[ 0.1170, 0.1758, -0.1655, 0.1489, -0.0956]]],
...
[[[-0.1337, -0.0562, -0.0624, 0.0885, -0.0640],
[-0.0302, -0.1192, -0.0637, 0.0083, 0.0181],
[ 0.1388, -0.1690, 0.1132, 0.1686, -0.1189],
[-0.0246, -0.1649, -0.1817, -0.0330, -0.0430],
[ 0.0672, -0.0671, 0.0469, 0.1284, 0.1420]]]], requires_grad=True)
Parameter containing:
tensor([ 0.0548, 0.0547, 0.1328, -0.0452, 0.1668, -0.1915],
requires_grad=True)
--end----------------isinstance(m, nn.Conv2d)------------end------
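The loop above interleaves printing with initialization for demonstration. A cleaner standalone sketch of the same idea (the Conv2d bias init is my own addition; the rest mirrors the code above):

def init_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

init_weights(model)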
Using model.modules() to set BN momentum and to freeze BN:
def set_bn_momentum(model, momentum=0.1):
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = momentum

def fix_bn(model):
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
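One caveat worth a sketch: model.train() puts every submodule, BatchNorm included, back into training mode, so fix_bn has to be re-applied after each call to model.train(), e.g. at the start of every epoch:

model.train()  # resets all submodules, including BN, to training mode
fix_bn(model)  # so re-freeze BN right afterwards; running stats stay fixed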
Other references:
https://blog.csdn.net/MrR1ght/article/details/105246412
model.children(): returns an iterator over the immediate child modules
model.modules(): returns an iterator over all modules, including the model itself (not just the children)
model.named_children(): returns an iterator over the immediate children as (name, module) pairs
model.named_modules(): returns an iterator over all modules, including the model itself, as (name, module) pairs
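A quick sketch to see the difference on the LeNet above (the counts follow from the module tree printed earlier):

print(len(list(model.children())))  # 3: layer1, layer2, layer3
print(len(list(model.modules())))   # 11: LeNet itself + 3 Sequentials + 7 leaf layers
for name, m in model.named_modules():
    print(name if name else '<root>', type(m).__name__)
# <root> LeNet
# layer1 Sequential
# layer1.conv1 Conv2d
# layer1.pool1 MaxPool2d
# ...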
model.parameters() || torch.optim.SGD(params, lr=<required>, momentum=0, dampening=0, weight_decay=0, nesterov=False)
Parameters:
params (iterable): an iterable of parameters to optimize, or dicts defining parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
Curious about model.parameters(), I printed it:
print(model.parameters())
which gives:
<generator object Module.parameters at 0x7f1d2272d728>
It is a generator (lazy, pointer-like), so unpack it to force evaluation:
print(*model.parameters())
This time a long stream of numbers comes out; an excerpt:
Parameter containing:
tensor([[[[-0.1751, 0.1829, 0.1973, 0.0780, 0.1220],
[-0.0497, 0.0943, 0.0827, 0.1829, 0.0239],
[-0.1044, 0.1268, 0.0716, -0.0100, 0.1991],
[-0.0730, 0.1762, -0.0787, 0.0686, -0.0069],
[ 0.1316, 0.0897, -0.1068, 0.0744, 0.0524]]],
[[[-0.1034, -0.1946, -0.1312, 0.1076, 0.0129],
[ 0.0450, 0.0552, 0.1448, -0.1283, -0.1868],
[-0.0260, -0.1928, 0.0519, -0.0493, -0.1028],
[-0.0936, 0.1719, -0.0997, 0.0008, 0.0871],
[ 0.0995, -0.1274, 0.0388, 0.0779, 0.0006]]],
[[[ 0.1846, -0.0723, 0.0649, -0.0169, -0.1595],
[ 0.0145, -0.1893, 0.0784, -0.0886, -0.0044],
[ 0.1914, -0.1009, -0.0736, -0.0992, -0.1618],
[-0.0291, 0.0997, 0.0549, 0.1267, -0.1661],
[-0.1333, 0.0168, 0.0648, 0.1047, -0.1506]]],
...
-4.0503e-03, 9.4014e-02, -8.5686e-02, 7.7082e-02]],
requires_grad=True) Parameter containing:
tensor([-0.0106, 0.0448, -0.0001, -0.0914, -0.0310, -0.0628, 0.0899, -0.0047,
-0.0390, -0.0291], requires_grad=True)
Custom parameter groups:
optimizer = torch.optim.SGD(params=[
    {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
    {'params': model.classifier.parameters(), 'lr': opts.lr},
], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
Another pattern seen elsewhere:
def get_1x_lr_params(self):
    modules = [self.backbone]
    for i in range(len(modules)):
        for m in modules[i].named_modules():
            if self.freeze_bn:
                if isinstance(m[1], nn.Conv2d):
                    for p in m[1].parameters():
                        if p.requires_grad:
                            yield p
            else:
                if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                        or isinstance(m[1], nn.BatchNorm2d):
                    for p in m[1].parameters():
                        if p.requires_grad:
                            yield p

def get_10x_lr_params(self):
    modules = [self.aspp, self.decoder]
    for i in range(len(modules)):
        for m in modules[i].named_modules():
            if self.freeze_bn:
                if isinstance(m[1], nn.Conv2d):
                    for p in m[1].parameters():
                        if p.requires_grad:
                            yield p
            else:
                if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                        or isinstance(m[1], nn.BatchNorm2d):
                    for p in m[1].parameters():
                        if p.requires_grad:
                            yield p
train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
# Define Optimizer
optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                            weight_decay=args.weight_decay, nesterov=args.nesterov)
Printing the total parameter count
params = list(model.parameters())
k = 0
for i in params:
    l = 1
    print("layer shape: " + str(list(i.size())))
    for j in i.size():
        l *= j
    print("layer param count: " + str(l))
    k = k + l
print("total param count: " + str(k))
Output (excerpt):
layer param count: 256
layer shape: [256, 2048, 3, 3]
layer param count: 4718592
layer shape: [256]
layer param count: 256
layer shape: [256]
layer param count: 256
layer shape: [256, 2048, 1, 1]
layer param count: 524288
layer shape: [256]
layer param count: 256
layer shape: [256]
layer param count: 256
layer shape: [256, 1280, 1, 1]
layer param count: 327680
layer shape: [256]
layer param count: 256
layer shape: [256]
layer param count: 256
layer shape: [256, 304, 3, 3]
layer param count: 700416
layer shape: [256]
layer param count: 256
layer shape: [256]
layer param count: 256
layer shape: [26, 256, 1, 1]
layer param count: 6656
layer shape: [26]
layer param count: 26
total param count: 58755258
net.parameters() / net.named_parameters() to display network parameters
for parameters in net.parameters():
    print(parameters)
Output (excerpt):
Parameter containing:
tensor([[[[-0.0104, -0.0555, 0.1417],
[-0.3281, -0.0367, 0.0208],
[-0.0894, -0.0511, -0.1253]]],
[[[-0.1724, 0.2141, -0.0895],
[ 0.0116, 0.1661, -0.1853],
[-0.1190, 0.1292, -0.2451]]],
Second, net.named_parameters() also gives the names:
for name, parameters in net.named_parameters():
    print(name, ':', parameters.size())
Output (excerpt):
module.backbone.conv1.weight : torch.Size([64, 3, 7, 7])
module.backbone.bn1.weight : torch.Size([64])
module.backbone.bn1.bias : torch.Size([64])
module.backbone.layer1.0.conv1.weight : torch.Size([64, 64, 1, 1])
module.backbone.layer1.0.bn1.weight : torch.Size([64])
module.backbone.layer1.0.bn1.bias : torch.Size([64])
module.backbone.layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
module.backbone.layer1.0.bn2.weight : torch.Size([64])
module.backbone.layer1.0.bn2.bias : torch.Size([64])
module.backbone.layer1.0.conv3.weight : torch.Size([256, 64, 1, 1])
module.backbone.layer1.0.bn3.weight : torch.Size([256])
module.backbone.layer1.0.bn3.bias : torch.Size([256])
module.backbone.layer1.0.downsample.0.weight : torch.Size([256, 64, 1, 1])
module.backbone.layer1.0.downsample.1.weight : torch.Size([256])
module.backbone.layer1.0.downsample.1.bias : torch.Size([256])
module.backbone.layer1.1.conv1.weight : torch.Size([64, 256, 1, 1])
module.backbone.layer1.1.bn1.weight : torch.Size([64])
module.backbone.layer1.1.bn1.bias : torch.Size([64])
module.backbone.layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
module.backbone.layer1.1.bn2.weight : torch.Size([64])
module.backbone.layer1.1.bn2.bias : torch.Size([64])
module.backbone.layer1.1.conv3.weight : torch.Size([256, 64, 1, 1])
module.backbone.layer1.1.bn3.weight : torch.Size([256])
module.backbone.layer1.1.bn3.bias : torch.Size([256])
module.backbone.layer1.2.conv1.weight : torch.Size([64, 256, 1, 1])
module.backbone.layer1.2.bn1.weight : torch.Size([64])
module.backbone.layer1.2.bn1.bias : torch.Size([64])
module.backbone.layer1.2.conv2.weight : torch.Size([64, 64, 3, 3])
module.backbone.layer1.2.bn2.weight : torch.Size([64])
module.backbone.layer1.2.bn2.bias : torch.Size([64])
module.backbone.layer1.2.conv3.weight : torch.Size([256, 64, 1, 1])
module.backbone.layer1.2.bn3.weight : torch.Size([256])
module.backbone.layer1.2.bn3.bias : torch.Size([256])
module.backbone.layer2.0.conv1.weight : torch.Size([128, 256, 1, 1])
module.backbone.layer2.0.bn1.weight : torch.Size([128])
module.backbone.layer2.0.bn1.bias : torch.Size([128])
module.backbone.layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
module.backbone.layer2.0.bn2.weight : torch.Size([128])
module.backbone.layer2.0.bn2.bias : torch.Size([128])
module.backbone.layer2.0.conv3.weight : torch.Size([512, 128, 1, 1])
module.backbone.layer2.0.bn3.weight : torch.Size([512])
module.backbone.layer2.0.bn3.bias : torch.Size([512])
Example added 2021-01-21:
if verbose:
    print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
    for i, (name, p) in enumerate(model.named_parameters()):
        name = name.replace('module_list.', '')
        print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
              (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
Output:
layer name gradient parameters shape mu sigma
0 0.Conv2d.weight True 864 [32, 3, 3, 3] -8.67e-05 0.112
1 0.BatchNorm2d.weight True 32 [32] 1 0
2 0.BatchNorm2d.bias True 32 [32] 0 0
3 1.Conv2d.weight True 18432 [64, 32, 3, 3] 0.000242 0.034
4 1.BatchNorm2d.weight True 64 [64] 1 0
5 1.BatchNorm2d.bias True 64 [64] 0 0
6 2.Conv2d.weight True 2048 [32, 64, 1, 1] 0.000321 0.072
7 2.BatchNorm2d.weight True 32 [32] 1 0
8 2.BatchNorm2d.bias True 32 [32] 0 0
9 3.Conv2d.weight True 18432 [64, 32, 3, 3] 0.000188 0.0339
10 3.BatchNorm2d.weight True 64 [64] 1 0
11 3.BatchNorm2d.bias True 64 [64] 0 0
12 5.Conv2d.weight True 73728 [128, 64, 3, 3] -2.89e-05 0.0241
13 5.BatchNorm2d.weight True 128 [128] 1 0
14 5.BatchNorm2d.bias True 128 [128] 0 0
15 6.Conv2d.weight True 8192 [64, 128, 1, 1] 0.000717 0.0513
16 6.BatchNorm2d.weight True 64 [64] 1 0
17 6.BatchNorm2d.bias True 64 [64] 0 0
18 7.Conv2d.weight True 73728 [128, 64, 3, 3] 0.000112 0.0241
19 7.BatchNorm2d.weight True 128 [128] 1 0
20 7.BatchNorm2d.bias True 128 [128] 0 0
21 9.Conv2d.weight True 8192 [64, 128, 1, 1] -0.000266 0.051
22 9.BatchNorm2d.weight True 64 [64] 1 0
23 9.BatchNorm2d.bias True 64 [64] 0 0
24 10.Conv2d.weight True 73728 [128, 64, 3, 3] -0.0002 0.0241
25 10.BatchNorm2d.weight True 128 [128] 1 0
26 10.BatchNorm2d.bias True 128 [128] 0 0
27 12.Conv2d.weight True 294912 [256, 128, 3, 3] 4.95e-05 0.017
28 12.BatchNorm2d.weight True 256 [256] 1 0
29 12.BatchNorm2d.bias True 256 [256] 0 0
30 13.Conv2d.weight True 32768 [128, 256, 1, 1] -9.5e-05 0.036
31 13.BatchNorm2d.weight True 128 [128] 1 0
32 13.BatchNorm2d.bias True 128 [128] 0 0
33 14.Conv2d.weight True 294912 [256, 128, 3, 3] -2.46e-06 0.017
34 14.BatchNorm2d.weight True 256 [256] 1 0
35 14.BatchNorm2d.bias True 256 [256] 0 0
36 16.Conv2d.weight True 32768 [128, 256, 1, 1] 0.000225 0.0359
37 16.BatchNorm2d.weight True 128 [128] 1 0
38 16.BatchNorm2d.bias True 128 [128] 0 0
39 17.Conv2d.weight True 294912 [256, 128, 3, 3] -5.76e-05 0.017
40 17.BatchNorm2d.weight True 256 [256] 1 0
41 17.BatchNorm2d.bias True 256 [256] 0 0
42 19.Conv2d.weight True 32768 [128, 256, 1, 1] -6.58e-05 0.036
43 19.BatchNorm2d.weight True 128 [128] 1 0
44 19.BatchNorm2d.bias True 128 [128] 0 0
45 20.Conv2d.weight True 294912 [256, 128, 3, 3] -1.72e-06 0.017
46 20.BatchNorm2d.weight True 256 [256] 1 0
47 20.BatchNorm2d.bias True 256 [256] 0 0
48 22.Conv2d.weight True 32768 [128, 256, 1, 1] 0.000157 0.036
49 22.BatchNorm2d.weight True 128 [128] 1 0
50 22.BatchNorm2d.bias True 128 [128] 0 0
51 23.Conv2d.weight True 294912 [256, 128, 3, 3] 2.92e-05 0.017
52 23.BatchNorm2d.weight True 256 [256] 1 0
53 23.BatchNorm2d.bias True 256 [256] 0 0
54 25.Conv2d.weight True 32768 [128, 256, 1, 1] 0.000226 0.0361
55 25.BatchNorm2d.weight True 128 [128] 1 0
56 25.BatchNorm2d.bias True 128 [128] 0 0
57 26.Conv2d.weight True 294912 [256, 128, 3, 3] 1.74e-05 0.017
58 26.BatchNorm2d.weight True 256 [256] 1 0
59 26.BatchNorm2d.bias True 256 [256] 0 0
60 28.Conv2d.weight True 32768 [128, 256, 1, 1] 0.000182 0.036
61 28.BatchNorm2d.weight True 128 [128] 1 0
62 28.BatchNorm2d.bias True 128 [128] 0 0
63 29.Conv2d.weight True 294912 [256, 128, 3, 3] 5.26e-07 0.017
64 29.BatchNorm2d.weight True 256 [256] 1 0
65 29.BatchNorm2d.bias True 256 [256] 0 0
66 31.Conv2d.weight True 32768 [128, 256, 1, 1] -0.000297 0.0361
67 31.BatchNorm2d.weight True 128 [128] 1 0
68 31.BatchNorm2d.bias True 128 [128] 0 0
69 32.Conv2d.weight True 294912 [256, 128, 3, 3] 4.21e-05 0.017
70 32.BatchNorm2d.weight True 256 [256] 1 0
71 32.BatchNorm2d.bias True 256 [256] 0 0
72 34.Conv2d.weight True 32768 [128, 256, 1, 1] 2.84e-05 0.036
73 34.BatchNorm2d.weight True 128 [128] 1 0
74 34.BatchNorm2d.bias True 128 [128] 0 0
75 35.Conv2d.weight True 294912 [256, 128, 3, 3] -4.58e-05 0.017
76 35.BatchNorm2d.weight True 256 [256] 1 0
77 35.BatchNorm2d.bias True 256 [256] 0 0
78 37.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] 2.59e-06 0.012
79 37.BatchNorm2d.weight True 512 [512] 1 0
80 37.BatchNorm2d.bias True 512 [512] 0 0
81 38.Conv2d.weight True 131072 [256, 512, 1, 1] -2.42e-05 0.0255
82 38.BatchNorm2d.weight True 256 [256] 1 0
83 38.BatchNorm2d.bias True 256 [256] 0 0
84 39.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] 6.23e-06 0.012
85 39.BatchNorm2d.weight True 512 [512] 1 0
86 39.BatchNorm2d.bias True 512 [512] 0 0
87 41.Conv2d.weight True 131072 [256, 512, 1, 1] 3.8e-05 0.0255
88 41.BatchNorm2d.weight True 256 [256] 1 0
89 41.BatchNorm2d.bias True 256 [256] 0 0
90 42.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] -1.15e-05 0.012
91 42.BatchNorm2d.weight True 512 [512] 1 0
92 42.BatchNorm2d.bias True 512 [512] 0 0
93 44.Conv2d.weight True 131072 [256, 512, 1, 1] 1.25e-07 0.0254
94 44.BatchNorm2d.weight True 256 [256] 1 0
95 44.BatchNorm2d.bias True 256 [256] 0 0
96 45.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] -1.02e-05 0.012
97 45.BatchNorm2d.weight True 512 [512] 1 0
98 45.BatchNorm2d.bias True 512 [512] 0 0
99 47.Conv2d.weight True 131072 [256, 512, 1, 1] 0.00018 0.0255
100 47.BatchNorm2d.weight True 256 [256] 1 0
101 47.BatchNorm2d.bias True 256 [256] 0 0
102 48.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] -1.22e-05 0.012
103 48.BatchNorm2d.weight True 512 [512] 1 0
104 48.BatchNorm2d.bias True 512 [512] 0 0
105 50.Conv2d.weight True 131072 [256, 512, 1, 1] -2.25e-05 0.0255
106 50.BatchNorm2d.weight True 256 [256] 1 0
107 50.BatchNorm2d.bias True 256 [256] 0 0
108 51.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] 6.82e-06 0.012
109 51.BatchNorm2d.weight True 512 [512] 1 0
110 51.BatchNorm2d.bias True 512 [512] 0 0
111 53.Conv2d.weight True 131072 [256, 512, 1, 1] -6.9e-05 0.0255
112 53.BatchNorm2d.weight True 256 [256] 1 0
113 53.BatchNorm2d.bias True 256 [256] 0 0
114 54.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] 1.89e-06 0.012
115 54.BatchNorm2d.weight True 512 [512] 1 0
116 54.BatchNorm2d.bias True 512 [512] 0 0
117 56.Conv2d.weight True 131072 [256, 512, 1, 1] 0.00015 0.0255
118 56.BatchNorm2d.weight True 256 [256] 1 0
119 56.BatchNorm2d.bias True 256 [256] 0 0
120 57.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] 2.61e-05 0.012
121 57.BatchNorm2d.weight True 512 [512] 1 0
122 57.BatchNorm2d.bias True 512 [512] 0 0
123 59.Conv2d.weight True 131072 [256, 512, 1, 1] -0.000128 0.0256
124 59.BatchNorm2d.weight True 256 [256] 1 0
125 59.BatchNorm2d.bias True 256 [256] 0 0
126 60.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] -1.97e-06 0.012
127 60.BatchNorm2d.weight True 512 [512] 1 0
128 60.BatchNorm2d.bias True 512 [512] 0 0
129 62.Conv2d.weight True 4.71859e+06 [1024, 512, 3, 3] 1.53e-06 0.0085
130 62.BatchNorm2d.weight True 1024 [1024] 1 0
131 62.BatchNorm2d.bias True 1024 [1024] 0 0
132 63.Conv2d.weight True 524288 [512, 1024, 1, 1] -1.84e-05 0.0181
133 63.BatchNorm2d.weight True 512 [512] 1 0
134 63.BatchNorm2d.bias True 512 [512] 0 0
135 64.Conv2d.weight True 4.71859e+06 [1024, 512, 3, 3] 2.17e-07 0.0085
136 64.BatchNorm2d.weight True 1024 [1024] 1 0
137 64.BatchNorm2d.bias True 1024 [1024] 0 0
138 66.Conv2d.weight True 524288 [512, 1024, 1, 1] 2.39e-05 0.018
139 66.BatchNorm2d.weight True 512 [512] 1 0
140 66.BatchNorm2d.bias True 512 [512] 0 0
141 67.Conv2d.weight True 4.71859e+06 [1024, 512, 3, 3] -1.41e-06 0.00851
142 67.BatchNorm2d.weight True 1024 [1024] 1 0
143 67.BatchNorm2d.bias True 1024 [1024] 0 0
144 69.Conv2d.weight True 524288 [512, 1024, 1, 1] -1.94e-05 0.018
145 69.BatchNorm2d.weight True 512 [512] 1 0
146 69.BatchNorm2d.bias True 512 [512] 0 0
147 70.Conv2d.weight True 4.71859e+06 [1024, 512, 3, 3] 1.07e-06 0.00851
148 70.BatchNorm2d.weight True 1024 [1024] 1 0
149 70.BatchNorm2d.bias True 1024 [1024] 0 0
150 72.Conv2d.weight True 524288 [512, 1024, 1, 1] 3.62e-05 0.0181
151 72.BatchNorm2d.weight True 512 [512] 1 0
152 72.BatchNorm2d.bias True 512 [512] 0 0
153 73.Conv2d.weight True 4.71859e+06 [1024, 512, 3, 3] 4.51e-06 0.0085
154 73.BatchNorm2d.weight True 1024 [1024] 1 0
155 73.BatchNorm2d.bias True 1024 [1024] 0 0
156 75.Conv2d.weight True 524288 [512, 1024, 1, 1] 2.73e-05 0.018
157 75.BatchNorm2d.weight True 512 [512] 1 0
158 75.BatchNorm2d.bias True 512 [512] 0 0
159 76.Conv2d.weight True 4.71859e+06 [1024, 512, 3, 3] 2.64e-06 0.0085
160 76.BatchNorm2d.weight True 1024 [1024] 1 0
161 76.BatchNorm2d.bias True 1024 [1024] 0 0
162 77.Conv2d.weight True 524288 [512, 1024, 1, 1] -3.97e-05 0.018
163 77.BatchNorm2d.weight True 512 [512] 1 0
164 77.BatchNorm2d.bias True 512 [512] 0 0
165 78.Conv2d.weight True 4.71859e+06 [1024, 512, 3, 3] -7.67e-08 0.00851
166 78.BatchNorm2d.weight True 1024 [1024] 1 0
167 78.BatchNorm2d.bias True 1024 [1024] 0 0
168 79.Conv2d.weight True 524288 [512, 1024, 1, 1] -3.87e-05 0.018
169 79.BatchNorm2d.weight True 512 [512] 1 0
170 79.BatchNorm2d.bias True 512 [512] 0 0
171 80.Conv2d.weight True 4.71859e+06 [1024, 512, 3, 3] -3.03e-06 0.00851
172 80.BatchNorm2d.weight True 1024 [1024] 1 0
173 80.BatchNorm2d.bias True 1024 [1024] 0 0
174 81.Conv2d.weight True 76800 [75, 1024, 1, 1] -7.83e-06 0.0181
175 81.Conv2d.bias True 75 [75] -2.94 1.31
176 84.Conv2d.weight True 131072 [256, 512, 1, 1] -7.19e-05 0.0255
177 84.BatchNorm2d.weight True 256 [256] 1 0
178 84.BatchNorm2d.bias True 256 [256] 0 0
179 87.Conv2d.weight True 196608 [256, 768, 1, 1] -7.45e-06 0.0208
180 87.BatchNorm2d.weight True 256 [256] 1 0
181 87.BatchNorm2d.bias True 256 [256] 0 0
182 88.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] -4.27e-06 0.012
183 88.BatchNorm2d.weight True 512 [512] 1 0
184 88.BatchNorm2d.bias True 512 [512] 0 0
185 89.Conv2d.weight True 131072 [256, 512, 1, 1] 4.33e-05 0.0255
186 89.BatchNorm2d.weight True 256 [256] 1 0
187 89.BatchNorm2d.bias True 256 [256] 0 0
188 90.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] 2.65e-06 0.012
189 90.BatchNorm2d.weight True 512 [512] 1 0
190 90.BatchNorm2d.bias True 512 [512] 0 0
191 91.Conv2d.weight True 131072 [256, 512, 1, 1] -4.59e-05 0.0255
192 91.BatchNorm2d.weight True 256 [256] 1 0
193 91.BatchNorm2d.bias True 256 [256] 0 0
194 92.Conv2d.weight True 1.17965e+06 [512, 256, 3, 3] -3.91e-06 0.012
195 92.BatchNorm2d.weight True 512 [512] 1 0
196 92.BatchNorm2d.bias True 512 [512] 0 0
197 93.Conv2d.weight True 38400 [75, 512, 1, 1] 0.000177 0.0255
198 93.Conv2d.bias True 75 [75] -2.94 1.3
199 96.Conv2d.weight True 32768 [128, 256, 1, 1] -0.000275 0.036
200 96.BatchNorm2d.weight True 128 [128] 1 0
201 96.BatchNorm2d.bias True 128 [128] 0 0
202 99.Conv2d.weight True 49152 [128, 384, 1, 1] 9.75e-05 0.0295
203 99.BatchNorm2d.weight True 128 [128] 1 0
204 99.BatchNorm2d.bias True 128 [128] 0 0
205 100.Conv2d.weight True 294912 [256, 128, 3, 3] -1.5e-06 0.017
206 100.BatchNorm2d.weight True 256 [256] 1 0
207 100.BatchNorm2d.bias True 256 [256] 0 0
208 101.Conv2d.weight True 32768 [128, 256, 1, 1] -6.49e-05 0.036
209 101.BatchNorm2d.weight True 128 [128] 1 0
210 101.BatchNorm2d.bias True 128 [128] 0 0
211 102.Conv2d.weight True 294912 [256, 128, 3, 3] 1.01e-05 0.017
212 102.BatchNorm2d.weight True 256 [256] 1 0
213 102.BatchNorm2d.bias True 256 [256] 0 0
214 103.Conv2d.weight True 32768 [128, 256, 1, 1] 0.000229 0.036
215 103.BatchNorm2d.weight True 128 [128] 1 0
216 103.BatchNorm2d.bias True 128 [128] 0 0
217 104.Conv2d.weight True 294912 [256, 128, 3, 3] -1.62e-05 0.017
218 104.BatchNorm2d.weight True 256 [256] 1 0
219 104.BatchNorm2d.bias True 256 [256] 0 0
220 105.Conv2d.weight True 19200 [75, 256, 1, 1] 0.00016 0.0361
221 105.Conv2d.bias True 75 [75] -2.94 1.32
n_p = sum(x.numel() for x in model.parameters())  # number of parameters
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number of gradients
aa = model.parameters()         # aa is a generator (lazy, pointer-like)
bb = list(model.parameters())   # bb materializes it into a list
print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g))
Output:
Model Summary: 222 layers, 6.1626e+07 parameters, 6.1626e+07 gradients
checkpoint = torch.load(model_path, map_location='cpu'): omitting map_location='cpu' can double GPU memory usage!
# checkpoint = torch.load(model_path)  # restores tensors onto the GPU they were saved from
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)
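Why this matters: without map_location, torch.load restores every tensor onto the device it was saved from, so the checkpoint tensors and the live model briefly coexist in GPU memory. A sketch of the safe order (assuming the 'state_dict' key used above):

checkpoint = torch.load(model_path, map_location='cpu')  # deserialize onto the CPU
model.load_state_dict(checkpoint['state_dict'], strict=False)
model = model.to('cuda')  # only one copy of the weights moves to the GPU
del checkpoint            # free the CPU-side copy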
Fine-tuning after changing the number of classes: usually only the last layer mismatches the class count, so just skip loading the class-dependent layers:
Example 1
model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG.DATASET.N_CLASSES)
state_dict = torch.load(path_model)

import collections
new_state_dict = collections.OrderedDict()
for k, v in state_dict.items():
    name = k.replace('base.', '')
    if 'aspp' in name:
        name = name + '_2'
    new_state_dict[name] = v

print("    Init:", CONFIG.MODEL.INIT_MODEL)
for m in model.base.state_dict().keys():
    if m not in new_state_dict.keys():
        print("    Skip init:", m)
model.base.load_state_dict(new_state_dict, strict=False)
Example 2
import os
from collections import OrderedDict

pretrained_model = torch.load(os.path.join(model_dir, '{}.pth'.format(pth)))
# net.load_state_dict(pretrained_model['net'], strict=strict)
print("#######################################################################################################")
for name, parameters in net.named_parameters():
    print(name, ':', parameters.size())
d = OrderedDict()
for key, value in pretrained_model['net'].items():
    tmp = key[11:]  # drop the "module.net." prefix (11 characters)
    d[tmp] = value
net.load_state_dict(d, strict=strict)
print("#######################################################################################################")
When pth key names don't match the model and weights fail to load: rebuild the keys into a new collections.OrderedDict() and load with strict=False (here, to skip ASPP)
model = DeepLabV3Plus_ResNet101_MSC(n_classes=CONFIG.DATASET.N_CLASSES)
state_dict = torch.load(CONFIG.MODEL.INIT_MODEL)

# rebuild the state dict so its keys match the model's ################
import collections
new_state_dict = collections.OrderedDict()
for k, v in state_dict.items():
    name = k.replace('base.', '')
    if 'aspp' in name:
        name = name + '_2'
    new_state_dict[name] = v
####################################################################

print("    Init:", CONFIG.MODEL.INIT_MODEL)
for m in model.base.state_dict().keys():
    if m not in new_state_dict.keys():
        print("    Skip init:", m)
model.base.load_state_dict(new_state_dict, strict=False)  # to skip ASPP

model = nn.DataParallel(model)
model.to(device)
Print the model's layer names and the names in the loaded pth to check that they correspond; sometimes the pth keys carry an extra "module." prefix
model = crnn.CRNN(32, 1, nclass, 256)  # e.g. model = crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    model = model.cuda()
for m in model.state_dict().keys():
    print("==:: ", m)

load_model_ = torch.load(model_path)
for k, v in load_model_.items():
    print(k, " ::shape", v.shape)
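If the comparison shows the checkpoint keys carry an extra "module." prefix, a common fix (a sketch building on load_model_ above) is to strip it while rebuilding the dict:

stripped = {k.replace('module.', '', 1): v for k, v in load_model_.items()}
model.load_state_dict(stripped)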
Loading when the network expects the "module." prefix but the saved pth lacks it: add the prefix
state = torch.load(resume_from, map_location=to_use_device)
import collections
d = collections.OrderedDict()
for key, value in state['state_dict'].items():
    # tmp = key[7:]  # (the reverse direction: strip "module." instead)
    d["module." + key] = value
_model.load_state_dict(d)
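Alternatively, if _model is already wrapped in nn.DataParallel (an assumption about the training setup, not shown above), the unprefixed dict can be loaded directly into the underlying module instead of renaming every key:

_model.module.load_state_dict(state['state_dict'])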
Optimizer setup
https://github.com/wuzuowuyou/DeepLabV3Plus-Pytorch/blob/master/main.py
# Set up optimizer
optimizer = torch.optim.SGD(params=[
    {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
    {'params': model.classifier.parameters(), 'lr': opts.lr},
], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
# optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
# torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
if opts.lr_policy == 'poly':
    scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
elif opts.lr_policy == 'step':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
# optimizer: the model here just includes pretrained, head and auxlayer
params_list = list()
if hasattr(self.model, 'pretrained'):
    params_list.append({'params': self.model.pretrained.parameters(), 'lr': args.lr})
if hasattr(self.model, 'exclusive'):
    for module in self.model.exclusive:
        params_list.append({'params': getattr(self.model, module).parameters(), 'lr': args.lr * 10})
self.optimizer = torch.optim.SGD(params_list,
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
# lr scheduling
self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                 max_iters=args.max_iters,
                                 power=0.9,
                                 warmup_factor=args.warmup_factor,
                                 warmup_iters=args.warmup_iters,
                                 warmup_method=args.warmup_method)
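For completeness, a sketch of where the scheduler steps inside the loop (train_loader, criterion, model, optimizer, scheduler are assumed names): poly-style schedulers like PolyLR/WarmupPolyLR are driven per iteration, so they step every batch, whereas StepLR is usually stepped once per epoch.

for images, targets in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(images), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()  # per-iteration scheduler: step after every batch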
Freezing some layers' parameters
# do not train certain layers
frozen_layers = [net.cnn, net.rnn, net.layer0, net.layer0_1]
for layer in frozen_layers:
    for name, value in layer.named_parameters():
        value.requires_grad = False
params = filter(lambda p: p.requires_grad, net.parameters())
def get_fine_tune_params(net, finetune_stage):
    """
    Collect the parameters that should be optimized.
    Args:
        net:
    Returns: the parameters that need optimizing
    """
    # aa = net.backbone
    # aaa = net.module.backbone
    all_stage = ['backbone', 'neck', 'head']
    for stage_ in all_stage:
        if stage_ not in finetune_stage:
            stage_now = getattr(net.module, stage_)  # safer than eval("net.module." + stage_)
            for name, value in stage_now.named_parameters():
                value.requires_grad = False

get_fine_tune_params(net, train_options['fine_tune_stage'])
# ===> solver and lr scheduler
optimizer = build_optimizer(filter(lambda p: p.requires_grad, net.parameters()), cfg['optimizer'])
Check that the freeze worked:
for name, param in net.named_parameters():
    if param.requires_grad:
        print(name)
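A compact cross-check (a sketch reusing the counting idiom from earlier) is to compare trainable vs. total parameters before and after freezing:

n_total = sum(p.numel() for p in net.parameters())
n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('trainable/total params: %d / %d' % (n_trainable, n_total))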