学习笔记11:预训练模型
什么是预训练网络
预训练模型就是之前用较大的数据集训练出来的模型,这个模型通过微调,在另外类似的数据集上训练。
一般预训练模型规模比较大,训练起来占用大量的内存资源。
微调预训练网络
我们采用vgg16作为预训练模型,来实现上一篇中四种天气的识别。
我们可以先来看一下vgg16的网络架构:
首先是一系列的卷积层和池化层
然后是一个全局池化层,全局池化层可以取代view
全局池化层之后,是分类器,而我们要改的就是这个分类器
分类器需要改的地方就只有最后的输出维度
模型加载及修改代码
model = models.vgg16(pretrained = True) # 加载模型,pretrained参数设置为True
for p in model.features.parameters():
p.requries_grad = False # 卷积层不变
model.classifier[-1].out_features = 4 # 分类器最后一个全连接层的输出维度改为4
注意训练的时候尽量使用gpu,不然的话内存可能会不够