Pytorch-修改预训练参数
我自己改进的模型为model(model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)),原模型为resnet50。
1.查看模型参数
现模型:
1 model_dict = model.state_dict() 2 for k,v in model_dict.items(): 3 print(k)
预训练模型参数
1 pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 2 for k,v in pretrained_dict.items(): 3 print(k)
2.将预训练参数赋给自己改进的模型
改进的模型参数和原模型参数一致时:
1 import torch.utils.model_zoo as model_zoo 2 3 model_urls = { 4 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 5 } 6 7 model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
Tip:如果两个模型参数完全一致的话,strict=True,如果两个模型参数不一致的话,当strict=False预训练模型会把具有相同参数名称的值赋给改进的参数,不相同的则不赋值。
改进的模型参数和原模型参数不一致时,使用部分预训练模型参数初始化网络 :
1 model_dict = model.state_dict() #取出自己模型的网络参数 2 pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 3 4 model_dict['classifiers.3.fc.weight'] = pretrained_dict['fc.weight'][:2] 5 model_dict['classifiers.3.fc.bias'] = pretrained_dict['fc.bias'][:2]