网络训练细节
经典网络的加载和初始化:
pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用网络结构,并且提供了预训练模型,可通过调用来读取网络结构和预训练模型(模型参数)。往往为了加快学习进度,训练的初期直接加载pretrain模型中预先训练好的参数。加载model如下所示:
import torchvision.models as models #加载网络结构和预训练参数 #参数pretrained在默认情况下是False,表示只加载网络结构而不加载预训练参数来初始化 resnet34 = models.resnet34(pretrained=True) #打印网络结构 print(resnet34) #PyTorch中通用的用一个模型的参数初始化另一个模型的层 #path_params.pkl为预训练模型参数的保存路径 #调用model的load_state_dict方法用预训练的模型参数来初始化自己定义的新网络结构,该方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度) resnet18.load_state_dict(torch.load(path_params.pkl)) #当新定义的网络(model_dict)和预训练网络(pretrained_dict)的层名不严格相等时,需要先将pretrained_dict里不属于model_dict的键剔除掉 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #再用预训练模型参数更新model_dict,最后用load_state_dict方法初始化自己定义的新网络结构。 #打印网络结构 print(resnet18) #注意:cnn = resnet18.load_state_dict(torch.load( path_params.pkl )) #是错误的,这样cnn将是nonetype #按键值将对应模型参数加载到pre_dict pre_dict = resnet18.state_dict() #打印模型各层名及参数 for k, v in pre_dict.items(): print(k) #model是自己定义好的新网络模型,将pretrained_dict和model_dict中命名一致的层加入pretrained_dict(包括参数)。 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
更改网络结构进行训练:
1. 参数修改
对于简单的参数修改,这里以resnet
预训练模型举例,resnet
源代码在Github。 resnet
网络最后一层分类层fc
是对1000种类型进行划分,对于自己的数据集,如果只有9类,修改的代码如下:
import torchvision.models as models #调用模型 model = models.resnet50(pretrained=True) #提取fc层中固定的参数 fc_features = model.fc.in_features #修改类别为9 model.fc = nn.Linear(fc_features, 9)
2. 增减卷积层
前一种方法只适用于简单的参数修改,有时候往往要修改网络中的层次结构,这时只能用参数覆盖的方法,即自己先定义一个类似的网络,再将预训练中的参数提取到自己的网络中来。这里以resnet
预训练模型举例。
3. 训练特定层,冻结其它层
另一种使用预训练模型的方法是对它进行部分训练。具体做法是,将模型起始的一些层的权重保持不变,重新训练后面的层,得到新的权重。在这个过程中,可多次进行尝试,从而能够依据结果找到frozen layers和retrain layers之间的最佳搭配。
如何使用预训练模型,是由数据集大小和新旧数据集(预训练的数据集和自己要解决的数据集)之间数据的相似度来决定的。
Pytorch保存与加载网络模型的两种方式:(推荐第二种)
1. 保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net;
torch.save(model_object, 'model.pkl') # 保存整个神经网络的结构和模型参数
重载:model = torch.load('model.pkl') #重载并初始化新的神经网络对象。
2. 保存神经网络的训练模型参数,save的对象是net.state_dict()。
torch.save(model_object.state_dict(), 'params.pkl') # 只保存神经网络的模型参数
需要首先导入对应的网络,通过model_object.load_state_dict(torch.load('params.pkl'))完成模型参数的重载和初始化新定义的网络。
dropout的应用:
批量归一化:
推荐学习:
本文的内容主要参考: https://www.cnblogs.com/wmlj/p/9917827.html