代码笔记11 PyTorch加载部分模型参数到另一个模型

1

 首先,加载是有条件的:只有当两个模型中参数的名字(即 state_dict 的 key)相同,并且对应张量的形状一致时,才能把参数加载进来。

2

 加载的方法有两种:一种是 load_state_dict(strict=False),另一种是先过滤出同名参数,再用 update 更新目标模型的 state_dict。
代码里的方法是对backbone单独做一个Module类,这样不容易出错。
代码中展示了如何把train_net中的backbone参数加载到test_net中的两种办法

import torch
import torch.nn as nn
import torch.nn.functional as F


class backbone(nn.Module):
    """Two-stage convolutional feature extractor.

    Each stage applies Conv2d -> GroupNorm -> ReLU -> 2x2 max pooling
    (stride 2). The module is kept as its own ``nn.Module`` so that its
    parameters appear under one common prefix in the ``state_dict`` of any
    network that embeds it.
    """

    def __init__(self):
        super(backbone, self).__init__()
        # Stage 1: 1 input channel -> 3 feature channels.
        self.backbone_conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=2, stride=1, padding=1)
        self.normal1 = nn.GroupNorm(num_groups=1, num_channels=3)
        # Stage 2: 3 feature channels -> 3 feature channels.
        self.backbone_conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=2, stride=1, padding=1)
        self.normal2 = nn.GroupNorm(num_groups=1, num_channels=3)

    def forward(self, input):
        """Run both stages on ``input``.

        Args:
            input: tensor with 1 channel (``backbone_conv1`` is built with
                ``in_channels=1``).

        Returns:
            A ``(pool2, id2)`` tuple: the stage-2 pooled features and the
            max-pooling indices of stage 2.
        """
        # Stage 1
        conv1 = self.normal1(self.backbone_conv1(input))
        pool1, _ = F.max_pool2d(F.relu(conv1), kernel_size=2, stride=2, return_indices=True)

        # Stage 2: must go through backbone_conv2; pool1 has 3 channels, so
        # feeding it to backbone_conv1 (in_channels=1) fails at runtime.
        conv2 = self.normal2(self.backbone_conv2(pool1))
        pool2, id2 = F.max_pool2d(F.relu(conv2), kernel_size=2, stride=2, return_indices=True)

        return pool2, id2

class train(nn.Module):
    """Source network: a ``backbone`` followed by three 2x2 convolutions.

    The backbone is registered as ``backbone_RGB`` and the head layers as
    ``train_conv1`` .. ``train_conv3``; these attribute names become the
    key prefixes in the network's ``state_dict``.
    """

    def __init__(self):
        super(train, self).__init__()
        self.backbone_RGB = backbone()
        # All three head layers share the same configuration.
        head_cfg = dict(in_channels=3, out_channels=3, kernel_size=2, stride=1, padding=0)
        self.train_conv1 = nn.Conv2d(**head_cfg)
        self.train_conv2 = nn.Conv2d(**head_cfg)
        self.train_conv3 = nn.Conv2d(**head_cfg)

    def forward(self, input):
        """Return the head output for ``input``; pooling indices are discarded."""
        features, _ = self.backbone_RGB(input)
        head_layers = (self.train_conv1, self.train_conv2, self.train_conv3)
        for layer in head_layers:
            features = layer(features)
        return features

class test(nn.Module):
    """Target network: a ``backbone`` followed by three 3x3 convolutions.

    The backbone is registered as ``backbone_RGB``, the same attribute name
    used by ``train``, so its ``state_dict`` keys match between the two
    networks. The head layers are named ``test_conv1`` .. ``test_conv3``
    and therefore share no keys with ``train``'s ``train_conv*`` layers.
    """

    def __init__(self):
        super(test, self).__init__()
        self.backbone_RGB = backbone()
        self.test_conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=0)
        self.test_conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=0)
        self.test_conv3 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=0)

    def forward(self, input):
        """Return the head output for ``input``; pooling indices are discarded."""
        # Stage 1
        features, _ = self.backbone_RGB(input)
        # Use this class's own head layers: the train_conv* attributes are
        # not defined here, so referencing them raises AttributeError.
        x1 = self.test_conv1(features)
        x2 = self.test_conv2(x1)
        x3 = self.test_conv3(x2)

        return x3

def _print_state_dict(model, show_values=True):
    """Print each state_dict key of ``model``, optionally followed by its tensor."""
    for key, tensor in model.state_dict().items():
        print(key)
        if show_values:
            print(tensor)


_SEPARATOR = "------------------------------------------------------------"

# Stand-alone backbone: list parameter names only.
backbone_net = backbone()
_print_state_dict(backbone_net, show_values=False)

print(_SEPARATOR)

# Source network: names and values.
train_net = train()
_print_state_dict(train_net)

print(_SEPARATOR)

# Target network before any loading: names and values.
test_net = test()
_print_state_dict(test_net)

# Method 1: load_state_dict with strict=False, so that keys present in only
# one of the two networks do not raise and only same-named entries are copied.
source_state = train_net.state_dict()
test_net.load_state_dict(source_state, strict=False)

_print_state_dict(test_net)

# Method 2: merge the overlapping entries into the target's own state dict.
model_dict = test_net.state_dict()  # names and tensors of test_net
# 1. keep only the source entries whose key also exists in the target
pretrained_dict = {}
for key, tensor in train_net.state_dict().items():
    if key in model_dict:
        pretrained_dict[key] = tensor
# 2. overwrite those entries, then load the complete dict back
model_dict.update(pretrained_dict)
test_net.load_state_dict(model_dict)


具体的也可以看看这几个博客,我也是从这里面学的[1],包括如何冻结参数[2]

References

[1]https://blog.csdn.net/qq_41314786/article/details/112569854
[2]https://blog.csdn.net/weixin_44815943/article/details/113180588?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0-113180588-blog-112569854.pc_relevant_default&spm=1001.2101.3001.4242.1&utm_relevant_index=3

posted @ 2022-05-25 23:16  The1912  阅读(983)  评论(0)  编辑  收藏  举报