代码笔记11 PyTorch加载部分模型参数到另一个模型
1. 首先,加载是有条件的:只有两个模型中参数的名字(即 state_dict 中的键)相同,才能加载进来;并且同名参数的形状也必须一致,否则会报错。
2. 加载的方法有两种:直接调用 load_state_dict 并设置 strict=False;或者先用 update 把同名参数合并进目标模型的 state_dict,再调用 load_state_dict。
代码里把 backbone 单独写成一个 Module 类,并在两个网络中都用同一个属性名 backbone_RGB 保存它,这样两边 backbone 参数的名字自然一致,不容易出错。
下面的代码展示了把 train_net 中的 backbone 参数加载到 test_net 中的两种办法。
import torch
import torch.nn as nn
import torch.nn.functional as F
class backbone(nn.Module):
    """Two-stage feature extractor: (conv -> GroupNorm -> ReLU -> 2x2 max-pool) x 2.

    Takes a 1-channel image batch and produces 3-channel feature maps. When a
    parent module stores an instance as an attribute, these parameters appear
    in the parent's state_dict under that attribute's prefix.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 input channel -> 3 channels.
        self.backbone_conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=2, stride=1, padding=1)
        self.normal1 = nn.GroupNorm(num_groups=1, num_channels=3)
        # Stage 2: 3 channels -> 3 channels.
        self.backbone_conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=2, stride=1, padding=1)
        self.normal2 = nn.GroupNorm(num_groups=1, num_channels=3)

    def forward(self, input):
        """Run both stages.

        Args:
            input: a 4-D batch with 1 channel (``backbone_conv1`` has
                ``in_channels=1``); assumed layout (N, 1, H, W) -- TODO confirm.

        Returns:
            ``(pool2, id2)``: the stage-2 pooled features and the argmax
            indices returned by ``F.max_pool2d(..., return_indices=True)``.
        """
        # Stage 1 (its pooling indices are not needed, so they are not requested)
        conv1 = self.normal1(self.backbone_conv1(input))
        pool1 = F.max_pool2d(F.relu(conv1), kernel_size=2, stride=2)
        # Stage 2 must use backbone_conv2: backbone_conv1 accepts only 1 input
        # channel, while pool1 carries 3.
        conv2 = self.normal2(self.backbone_conv2(pool1))
        pool2, id2 = F.max_pool2d(F.relu(conv2), kernel_size=2, stride=2, return_indices=True)
        return pool2, id2
class train(nn.Module):
    """Training network: a ``backbone`` feature extractor plus three 2x2 convs.

    The backbone is registered under the attribute name ``backbone_RGB``, so
    its parameters appear in this module's state_dict with the
    ``backbone_RGB.`` prefix; the head layers appear as ``train_conv1..3``.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in assignment order; keep the backbone
        # first so the state_dict key order stays the same.
        self.backbone_RGB = backbone()
        head_kwargs = dict(in_channels=3, out_channels=3, kernel_size=2, stride=1, padding=0)
        self.train_conv1 = nn.Conv2d(**head_kwargs)
        self.train_conv2 = nn.Conv2d(**head_kwargs)
        self.train_conv3 = nn.Conv2d(**head_kwargs)

    def forward(self, input):
        """Return the head output; the backbone's pooling indices are discarded."""
        features, _unused_indices = self.backbone_RGB(input)
        for layer in (self.train_conv1, self.train_conv2, self.train_conv3):
            features = layer(features)
        return features
class test(nn.Module):
    """Test-time network: a ``backbone`` feature extractor plus three 3x3 convs.

    The backbone is stored as ``backbone_RGB`` -- the same attribute name the
    ``train`` network uses -- so the backbone's state_dict keys coincide between
    the two networks and its weights can be copied from a trained ``train``.
    The head layers (``test_conv1..3``) have no counterpart in ``train``.
    """

    def __init__(self):
        super().__init__()
        self.backbone_RGB = backbone()
        # Each 3x3 conv with no padding shrinks height and width by 2, so the
        # backbone output must be at least 7x7 for the head to produce output.
        self.test_conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=0)
        self.test_conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=0)
        self.test_conv3 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=0)

    def forward(self, input):
        """Run the backbone, then this network's own three ``test_conv`` layers.

        The pooling indices returned by the backbone are not used.
        """
        features, _indices = self.backbone_RGB(input)
        # Use the test_conv* layers defined on this class; train_conv* exist
        # only on the ``train`` network and would raise AttributeError here.
        x1 = self.test_conv1(features)
        x2 = self.test_conv2(x1)
        x3 = self.test_conv3(x2)
        return x3
def _print_state_dict(net, show_values=True):
    """Print every state_dict key of ``net`` and, optionally, its tensor value."""
    for name, tensor in net.state_dict().items():
        print(name)
        if show_values:
            print(tensor)


SEPARATOR = "------------------------------------------------------------"

# Inspect the parameter names of each network. train_net and test_net both
# hold their backbone as ``backbone_RGB``, so those keys are identical.
backbone_net = backbone()
_print_state_dict(backbone_net, show_values=False)
print(SEPARATOR)
train_net = train()
_print_state_dict(train_net)
print(SEPARATOR)
test_net = test()
_print_state_dict(test_net)

# Method 1: load_state_dict with strict=False.
# Keys present in only one dict are skipped instead of raising
# (train_conv* are reported as unexpected, test_conv* as missing).
# Keys that do match by name must still match in shape, or loading raises.
print(SEPARATOR)
load_result = test_net.load_state_dict(train_net.state_dict(), strict=False)
print("missing keys:", load_result.missing_keys)
print("unexpected keys:", load_result.unexpected_keys)
_print_state_dict(test_net)

# Method 2: filter the source state_dict, merge it into the target's, load strictly.
model_dict = test_net.state_dict()  # every key test_net expects
# 1. keep only entries whose name AND shape match the target, so a same-named
#    layer with a different size cannot break the strict load below
pretrained_dict = {
    k: v
    for k, v in train_net.state_dict().items()
    if k in model_dict and v.shape == model_dict[k].shape
}
# 2. overwrite the matching entries; every other key keeps test_net's own value
model_dict.update(pretrained_dict)
# strict loading is safe: model_dict contains exactly test_net's keys
test_net.load_state_dict(model_dict)

# Sanity check: every transferred tensor now equals its source tensor.
loaded = test_net.state_dict()
print(SEPARATOR)
print("transferred keys:", list(pretrained_dict))
print(
    "all transferred tensors equal:",
    all(torch.equal(loaded[k], v) for k, v in pretrained_dict.items()),
)
具体细节可以参考以下博客:本文的方法学自[1],如何冻结参数可参考[2]。
References
[1]https://blog.csdn.net/qq_41314786/article/details/112569854
[2]https://blog.csdn.net/weixin_44815943/article/details/113180588?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0-113180588-blog-112569854.pc_relevant_default&spm=1001.2101.3001.4242.1&utm_relevant_index=3