————致力于用代码改变世界

关于深度学习模型不收敛问题解决办法

1. 问题重现

笔者在训练Vgg16网络时出现不收敛问题,具体描述为训练集准确率和测试集准确率一直稳定于某一值,如下图所示。

image

2. 可能的原因

2.1 数据问题

  • 噪声数据。不平衡的数据集、含有噪声或异常值的数据可能导致模型难以学习,尝试更换数据集,出现这种问题比较难办。

  • 数据预处理。确保数据质量,包括数据清洗、标准化、归一化等。
    示例:
    transforms.Normalize(mean=0.4,std=0.1)、
    transforms.Resize(size=(228,228),interpolation=InterpolationMode.BICUBIC)

  • 数据增强。可以使用以下方法扩充数据集。
    随机旋转(RandomRotation):随机旋转图像一定角度。
    随机裁剪(RandomCrop):从图像中随机裁剪出指定大小的子图。
    水平翻转(HorizontalFlip):水平翻转图像。
    垂直翻转(VerticalFlip):垂直翻转图像。
    颜色变换(ColorJitter):随机改变图像的色彩。
    调整亮度(AdjustBrightness):调整图像的亮度。
    调整对比度(AdjustContrast):调整图像的对比度。
    调整饱和度(AdjustSaturation):调整图像的饱和度。

2.2 模型设计问题

  • 更换模型:模型过于复杂或过于简单都可能导致不收敛。读者可以尝试将复杂的模型改为简单的模型,例如下图所示。笔者并不建议直接更换整个网络框架,因为这样工作量太大并且可能会出现新的问题。
    image

  • 简化:简化是指化简模型的某一部分而整体框架不变。★ 例如在Vgg16中(上图左半部分),第一个Convolution+ReLU后输出的特征通道数为64,接着第二个Convolution+ReLU,.....等等,那么是否可以简化网络使每个Convolution+ReLU输出的特征图通道变成32?或16?或8?,这样或许有效。 ★ Vgg16每次降采样后特征图大小都是原来的1/2,那么是否可以采用降采样更大的倍数呢?即使用更大的步长(stride)。

2.3 激活函数问题

不同的模型采用不同的激活函数可能会有不同的效果,若出现不收敛状况后,不妨改变一下激活函数?
例如将Sigmod激活函数改成ReLu?或LeakyReLU?
image

2.4 超参数问题

  • 调整DataLoader的batch_size参数,例如16、32、64、128、256等等,都尝试一下看看有没有效果。试着将DataLoader的shuffle参数改为True?
    示例:
    DataLoader(dataset=data,batch_size=16,shuffle=True,drop_last=False)

  • 学习率过高或过低。★ 笔者经常使用的学习率一般为0.05、0.01、0.005、0.001、0.0005、0.0001,尝试改变试试,或许有效果呢? ★ 尝试使用动态学习率,例如余弦退火学习率CosineAnnealingLR(下图)或其他?
    image
    示例:
    torch.optim.lr_scheduler.CosineAnnealingLR(optimer, T_max=20)

  • 分析学习率。可以结合学习率变化和验证集准确性调整学习率上下界,具体分析如下图所示。
    image

2.5 梯度问题

梯度消失或爆炸是导致模型不收敛的常见原因。通常做法是在卷积层(Convolution)后添加批量归一化(BatchNorm)再添加激活函数(例如ReLU等)。
示例:
nn.Conv2d(in_channels=,out_channels=,kernel_size=,stride=,padding=),
nn.BatchNorm2d(num_features=),
nn.ReLU(inplace=True),

2.6 优化器选择不当

不同的优化算法适用于不同类型的问题,错误的选择可能会阻碍模型的学习过程。改变不同的优化器,例如:

  1. 随机梯度下降法(Stochastic Gradient Descent,SGD)
  2. SGDM(带动量的SGD:SGD with momentum)
  3. 加速梯度(Nesterov Accelerated Gradient,NAG)
  4. 自适应动量优化(Adaptive Moment Estimation,Adam)
  5. ..等等

示例:
torch.optim.Adam(model.parameters(),lr=LEARN_RATE)

2.7 损失函数选择不当

  • 二分类使用nn.BCELoss(),多分类使用nn.CrossEntropyLoss() 等等。
  • 正则化技术。如L1/L2正则化、Dropout等可以防止过拟合。

3. 总结

以上是笔者学习过程中的经验,一般使用一种或多种便可解决不收敛问题,需要活学活用。

posted @ 2024-11-03 15:04  hello_nullptr  阅读(35)  评论(0编辑  收藏  举报