在pytorch中进行预训练模型的加载和模型的fine-tune操作

联系方式:
e-mail: FesianXu@163.com
QQ: 973926198
github: https://github.com/FesianXu

如有谬误,请联系指正。转载请注明出处



我们在使用pytorch的时候,经常有需要使用一些通用的模型模块作为子模块,比如著名的resnet,densenet,alexnet,inception等等,在使用这些模型的时候,通常我们希望可以加载该模型在别的数据集,如ImageNet上进行训练后的权值参数,以便加速整个模型训练过程[1]。在此为了简便,称之为这些预训练模型为基模型,加载基模型参数这个过程按照需求大致可以分为两类:

  1. 整个模型的预训练参数加载
  2. 部分模型的预训练参数加载

在对模型的参数进行fine-tune[2-3]的时候,按照需求也可以大致分为两类:

  1. 固定整个基模型的参数,调节其他模型的参数
  2. 固定部分基模型的参数,调节其他模型的参数

我们接下来基于pytorch框架[4]对其进行讨论。

基模型参数加载

从持久化模型开始

在pytorch中,保存一个模型的参数特别容易,用torch.save()即可,例如:

model = CNNNet(params)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
model.train()
# here we train the models, skip these codes
saved_dict = {
	'model': model.state_dict(),
	'opt': opt.state_dict()
}
torch.save(saved_dict, './model.pth.tar')

我们发现,torch.save()保存的是一个字典,其中的keys可以自定义。这里有一点要注意的是,如果你用的优化器是例如Adam优化器[5-6]这类内部有参数需要持久化的,最好也将其保存下来。

加载模型吧

如果是需要加载整个模型,直接用torch.load()model.load_state_dict()即可,如:

model = CNNNet(params)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
# yes you also need to define the model and optimizer
checkpoint = torch.load('./model.pth.tar')
# here, checkpoint is a dict with the keys you defined before
model.load_state_dict(checkpoint['model'])
opt.load_state_dict(checkpoint['opt'])

这个过程中torch.load()只是负责读取模型参数,而用model.load_state_dict()进行加载,这个加载是按照名字进行索引的,如果名字对不上或者是参数的形状,类型对不上,就会报错。我们可以打印出其名字进行观察,如:

for name in checkpoint ['model'].keys():
    print(name)

输出如:

stgcn.weight_model.conv_models.0.conv_layer.weight
stgcn.weight_model.conv_models.0.conv_layer.bias
stgcn.weight_model.conv_models.1.conv_layer.weight
stgcn.weight_model.conv_models.1.conv_layer.bias
stgcn.weight_model.conv_models.1.batch_norm.weight
stgcn.weight_model.conv_models.1.batch_norm.bias

如果定义的模型和持久化的模型的参数名,形状,类型都能完全符合,就能正确加载。我们同时还注意到,变量的名字是由pytorch自行命名的,其命名根据就是你的变量名字,比如:

import torch.nn as nn
class model_A(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10,10)
        
class model_B(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_model = model_A()
        self.fc = nn.Linear(10,1)
model = model_B()

那么如果你打印prin(model_B),你就会发现子模块的名字为

sub_model.fc.weight
sub_model.fc.bias
fc.weight
fc.bias

我们观察到名字是以你的变量标识符命名的,这点和TensorFlow的命名机制完全不同,请注意。
同时我们还观察到,只要是权值weight其命名后缀都是weight,同样偏置bias的后缀是bias,因此根据此,可以单独对权值进行L2正则[7],具体过程见[8]。

部分加载模型

根据上面的分析,我们便发现只要过滤掉不需要加载的模型的名字,即可实现部分模型加载了,例子如:

model = CNNNet()
checkpoint = torch.load('./model.pth.tar')
for name, params in model.stgcn.st_gcn_networks.named_parameters():
	 params_name = 'stgcn.st_gcn_networks.'+name
    if params_name in model.state_dict():
    	model.state_dict()[params_name].copy_(checkpoint['model'][params_name])

我们发现,通过这个代码,我们可以仅对model.stgcn.st_gcn_networks这个子模块的参数进行加载,而其他的参数保持初始化情况不变。

模型Fine-Tune

在模型的Fine-Tune(微调)或者联合调试过程中,我们经常需要固定某个模型的参数,而去调整其他模型的参数,主要方法有两个:

  1. 通过切断某个模块的梯度流,但是这个会导致该模型前面的所有模型也没有梯度。
  2. 通过设置某个模块的所有变量的requires_grad=False
  3. 在优化器内设置需要进行梯度更新的变量。

笔者在实践过程中最常用的是第三种方法,暂时只介绍第三种方法。代码很简单,例子如下:

trainable_vars = list(model.stgcn.weight_model.parameters())+ \
                 list(model.stgcn.fcn.parameters())+ \
                 list(model.stgcn.data_bn.parameters())+ \
                 list(model.stgcn.dim_map.parameters())+ \
                 list(model.aux_cls.parameters())                
opt = torch.optim.SGD(trainable_vars , lr=1e-4, momentum=0.9)

简单粗暴,但是其实很好用,当需要训练的变量很多,而需要固定的变量很少的时候,可以用对整个模型参数求补的方式求得,这里不多介绍了。
当需要对整个模型进行微调时,只需要:

opt = torch.optim.SGD(model.parameters(), lr=1e-6, momentum=0.9)

给每一层或者每个模型设置不同的学习率

在模型训练过程中,有些模块,比如对抗生成网络GAN[9]的生成器和判别器经常需要设置不同的学习率,以求得更好的效果或者不同模型之间的平衡。详细内容见我以前文章[8]中所述,这里不再累述。

Pytorch内置的模型

pytorch内置有一些经常使用的模型和其在大规模数据集上的预训练参数,只需要安装了torchvision便可轻松使用,模型有:

  1. resnet: resnet18,resnet34,resnet50,resnet101,resnet152
  2. vgg: vgg11, vgg13, vgg16, vgg19
  3. alexnet
  4. densenet
  5. inception
  6. squeezenet

具体模型定义见:Github click me
使用方法很简单,如:

import torchvision.models as models
resnet18 = models.resnet18(pretrained=False)

在这里如果指定pretrained=True可以联网加载预训练模型,但是由于大陆因为你懂得原因,所以需要你懂得辅助工具,建议读者自行去下载模型的文件后手动加载。模型文件的地址可以在模型定义文件中找到,如[https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py]中的resnet模型:

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

Reference

[1]. Kaiming He, Ross Girshick, Piotr Dollár. Rethinking ImageNet Pre-training[J]. arXiv preprint, https://arxiv.org/abs/1811.08883
[2]. 迁移学习与fine-tuning有什么区别?
[3]. Fine tuning
[4]. pytorch
[5]. Adam 算法
[6]. Kingma D P, Ba J. Adam: A method for stochastic optimization[J]. arXiv preprint arXiv:1412.6980, 2014.
[7]. 曲线拟合问题与L2正则
[8]. pytorch中的L2和L1正则化,自定义优化器设置等操作
[9]. Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.

posted @ 2018-12-13 10:13  FesianXu  阅读(121)  评论(0编辑  收藏  举报