模型微调

微调流程

  1. 在源数据集(source dataset)上预训练一个网络模型,即源模型(source model)
  2. 创建一个新的网络模型,即目标模型(target model)
  • 目标模型复制了源模型上除了输出层外的所有模型设计及参数。
  • 我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。此外,还假设源模型的输出层与源数据集的标签紧密相关,因此在目标模型中不与采用。
  1. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层参数。
  2. 在目标数据集上训练目标模型。
  • 从头训练输出层,其余层的参数均基于源模型的参数微调得到。
    模型迁移

训练特定层

若仅需改变最后一层模型参数,不改变其他层(特征提取层)参数,则先冻结其他层参数梯度,再对模型输出部分的全连接层进行修改。

import torchvision.models as models
# 加载一个预训练模型
model = models.resnet18(pretrained=True)

pretrained=True:使用预训练好的权重,默认状态pretrained=False,即不使用预训练权重。


def set_param_requires_grad(model, feature_extracting):
  if feature_extracting:
    for param in model.parameters():
      # param.requires_grad默认为True
      param.requires_grad=False

feature_extract = True
set_param_requires_grad(model,feature_extract)
# 修改模型
# 在之后的训练中,model只会在fc层进行梯度回传
model.fc = nn.Linear(in_featuers=512, out_features=4, bias=Tre)

注意事项

  • 通常PyTorch模型的扩展为.pt或.pth,程序运行时会首先检查默认路径中是否有已经下载的模型权重,一旦权重被下载,下次加载就不需要下载了。

  • 一般情况下预训练模型的下载会比较慢,我们可以直接查看自己的模型里面model_urls,然后手动下载

    • 预训练模型的权重在Linux和Mac的默认下载路径是用户根目录下的.cache文件夹。在Windows下就是C:\Users<username>.cache\torch\hub\checkpoint。我们可以通过使用 torch.utils.model_zoo.load_url()设置权重的下载地址。
  • 如果觉得麻烦,还可以将自己的权重下载下来放到同文件夹下,然后再将参数加载网络。

self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
  • 如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,要不然可能会报错。

参考:
https://datawhalechina.github.io/thorough-pytorch/第六章/6.3 模型微调-torchvision.html

posted @ 2022-04-06 22:18  ArdenWang  阅读(138)  评论(0编辑  收藏  举报