模型微调
微调流程
- 在源数据集(source dataset)上预训练一个网络模型,即源模型(source model)
- 创建一个新的网络模型,即目标模型(target model)
- 目标模型复制了源模型上除了输出层外的所有模型设计及参数。
- 我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。此外,还假设源模型的输出层与源数据集的标签紧密相关,因此在目标模型中不与采用。
- 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层参数。
- 在目标数据集上训练目标模型。
- 从头训练输出层,其余层的参数均基于源模型的参数微调得到。
训练特定层
若仅需改变最后一层模型参数,不改变其他层(特征提取层)参数,则先冻结其他层参数梯度,再对模型输出部分的全连接层进行修改。
import torchvision.models as models
# 加载一个预训练模型
model = models.resnet18(pretrained=True)
pretrained=True
:使用预训练好的权重,默认状态pretrained=False
,即不使用预训练权重。
def set_param_requires_grad(model, feature_extracting):
if feature_extracting:
for param in model.parameters():
# param.requires_grad默认为True
param.requires_grad=False
feature_extract = True
set_param_requires_grad(model,feature_extract)
# 修改模型
# 在之后的训练中,model只会在fc层进行梯度回传
model.fc = nn.Linear(in_featuers=512, out_features=4, bias=Tre)
注意事项
-
通常PyTorch模型的扩展为.pt或.pth,程序运行时会首先检查默认路径中是否有已经下载的模型权重,一旦权重被下载,下次加载就不需要下载了。
-
一般情况下预训练模型的下载会比较慢,我们可以直接查看自己的模型里面model_urls,然后手动下载
- 预训练模型的权重在Linux和Mac的默认下载路径是用户根目录下的.cache文件夹。在Windows下就是C:\Users<username>.cache\torch\hub\checkpoint。我们可以通过使用 torch.utils.model_zoo.load_url()设置权重的下载地址。
-
如果觉得麻烦,还可以将自己的权重下载下来放到同文件夹下,然后再将参数加载网络。
self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
- 如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,要不然可能会报错。
参考:
https://datawhalechina.github.io/thorough-pytorch/第六章/6.3 模型微调-torchvision.html