学习笔记11:预训练模型

什么是预训练网络

预训练模型就是之前用较大的数据集训练出来的模型,这个模型通过微调,在另外类似的数据集上训练。
一般预训练模型规模比较大,训练起来占用大量的内存资源。

微调预训练网络

我们采用vgg16作为预训练模型,来实现上一篇中四种天气的识别。
我们可以先来看一下vgg16的网络架构:
首先是一系列的卷积层和池化层

然后是一个全局池化层,全局池化层可以取代view
全局池化层之后,是分类器,而我们要改的就是这个分类器
分类器需要改的地方就只有最后的输出维度

模型加载及修改代码

model = models.vgg16(pretrained = True) # 加载模型,pretrained参数设置为True

for p in model.features.parameters():
    p.requries_grad = False             # 卷积层不变

model.classifier[-1].out_features = 4   # 分类器最后一个全连接层的输出维度改为4

注意训练的时候尽量使用gpu,不然的话内存可能会不够

posted @ 2021-01-30 11:47  pbc的成长之路  阅读(540)  评论(0编辑  收藏  举报