PyTorch 中的 CIFAR10 图像分类
PyTorch 中的 CIFAR10 图像分类
如何为 CIFAR10 构建高精度 CNN
在本文中,我们将深入探讨 CIFAR10 图像分类问题。为了解决这个问题,我们将使用著名的深度学习库 PyTorch .
大纲
- 设置
- 数据采集
- 数据分析
- 卷积神经网络
设置
由于我们要针对大量数据训练重型神经网络,因此建议使用 GPU 提供的在线工具,例如 谷歌实验室 ,甚至你的机器,如果它有足够的硬件。
首先,我们需要安装额外的 pip 包。如果你使用 谷歌实验室 ,这些是您必须安装的唯一软件包。
!pip 安装射线
!pip install hpbandster 配置空间
然后,我们可以导入整个项目所需的所有库。
为了让我们的网络一旦可用就可以转移到 GPU 上并在那里进行训练,我们必须定义 设备
多变的。
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
数据采集
由于我们使用的是 PyTorch,因此 __CIFAR10 数据集在 Torchvision.datasets
模块,我们可以在我们的代码中直接从那里下载它。
Code 1 • Data collection
这 转换
参数,用于训练集、验证集和测试集:
- 转换
numpy.ndarray
到一个火炬.张量
.张量可以在 GPU 上使用以加速计算。 - 标准化像素值,从
[0, 255]
至[-1, 1]
.这是让 CNN 发挥最佳性能的最佳范围。
然后它应用 随机裁剪
和 随机水平翻转
仅适用于火车组,以实现更好的性能,特别是:
随机水平翻转
将简单地以 0.5 的概率随机水平翻转图像。随机裁剪
在每边添加 4 个像素的填充后,将裁剪图像为指定尺寸 (32x32)。这padding_mode='反射'
将来自火炬文档:
“具有图像反射的焊盘,不重复边缘上的最后一个值” .
我们最终定义了一个有用的 听写 返回给定相同索引的文本标签。
IDX_TO_LABEL = {v: k for k, v in trainset.class_to_idx.items()}
数据分析
这 CIFAR10 数据集由 60000 个 32x32 彩色图像 (RGB) 组成,分为 10 个类别。训练集 50000 张图像,测试集 10000 张图像。
您可以通过运行下面的代码片段来获取有关数据集的这些信息和其他信息。
Code 2 • Samples and classes
您还可以使用下一个代码块轻松显示一些图像示例。
Code 3 • Image examples
输出应如下所示:
Image 1 • Dataset images example
数据集分布
数据集分布是机器学习的一个关键话题。不平衡的数据集(每个类别的示例数量不同)可能会阻碍我们的模型实现良好的整体准确性。每个类都应包含相同数量的示例,以使数据集完美平衡。
下面的代码片段是显示数据集分布的有效方式。
Code 4 • Dataset distribution charts
Image 2 • Training set examples distribution
Image 3 • Test set examples distribution
如上图所示,这 10 个类别已经完全平衡。
卷积神经网络
我们在这里,本文的核心是:构建 CNN。
卷积神经网络 (CNN) 是一类人工神经网络 (ANN),最常用于执行视觉任务,例如图像分类。我们将看到 CNN 在 CIFAR10 数据集上的图像分类任务中表现如何。
不过,如果您还不熟悉 CNN,我建议您退后一步,在继续阅读之前先看看这份综合指南:
[
卷积神经网络综合指南——ELI5 方式
人工智能在弥合人类能力差距方面取得了巨大的进步……
向datascience.com
网络设计
网络的结构对于构建高精度模型非常重要。我们不会深入探讨卷积层、全连接层、池化、内核以及其他我将您推荐给您的文章的主题。
网络设计类似于 ResNet,卷积层与残差块交替,最后附加一个全连接层用于分类。这种类型的网络已被证明适用于图像分类任务,例如我们正在尝试挑战的任务。
在每次卷积之后应用批量归一化,这是一种对网络进行正则化的方法,也可以减少训练时间。这种技术比 dropout 更有效,并在许多现代卷积架构中取代了它。
它的灵感来自 这个系列 的博客文章。
Code 5 • Neural **** network design
如果定义为 Class 的模型对您来说听起来很奇怪,请考虑查看有关如何定义神经网络的官方 PyTorch 文档:
[
神经网络 — PyTorch Tutorials 1.12.1+cu102 文档
可以使用该软件包构建神经网络。现在您已经了解了 , 取决于定义模型和…
pytorch.org
](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html)
网络结构
这 火炬总结
工具可以帮助我们以整洁的方式展示网络结构。
净 = ResNet()
net = net.to(设备)
摘要(净,(3,32,32),batch_size=32)
输出应如下所示:
-------------------------------------------------- --------------
层(类型)输出形状参数#
==================================================== ===============
Conv2d-1 [32, 64, 32, 32] 1,792
BatchNorm2d-2 [32, 64, 32, 32] 128
ReLU-3 [32, 64, 32, 32] 0
Conv2d-4 [32, 128, 32, 32] 73,856
BatchNorm2d-5 [32, 128, 32, 32] 256
ReLU-6 [32, 128, 32, 32] 0
MaxPool2d-7 [32, 128, 16, 16] 0
Conv2d-8 [32, 128, 16, 16] 147,584
BatchNorm2d-9 [32, 128, 16, 16] 256
ReLU-10 [32, 128, 16, 16] 0
Conv2d-11 [32, 128, 16, 16] 147,584
BatchNorm2d-12 [32, 128, 16, 16] 256
ReLU-13 [32, 128, 16, 16] 0
Conv2d-14 [32, 256, 16, 16] 295,168
BatchNorm2d-15 [32, 256, 16, 16] 512
ReLU-16 [32, 256, 16, 16] 0
MaxPool2d-17 [32, 256, 8, 8] 0
Conv2d-18 [32, 512, 8, 8] 1,180,160
BatchNorm2d-19 [32, 512, 8, 8] 1.024
ReLU-20 [32, 512, 8, 8] 0
MaxPool2d-21 [32, 512, 4, 4] 0
Conv2d-22 [32, 512, 4, 4] 2,359,808
BatchNorm2d-23 [32, 512, 4, 4] 1.024
ReLU-24 [32, 512, 4, 4] 0
Conv2d-25 [32, 512, 4, 4] 2,359,808
BatchNorm2d-26 [32, 512, 4, 4] 1.024
ReLU-27 [32, 512, 4, 4] 0
MaxPool2d-28 [32, 512, 1, 1] 0
展平-29[32, 512] 0
线性 30 [32, 10] 5,130
==================================================== ===============
总参数:6,575,370
可训练参数:6,575,370
不可训练参数:0
-------------------------------------------------- --------------
输入大小(MB):0.38
向前/向后传递大小 (MB):290.25
参数大小(MB):25.08
估计总大小(MB):315.71
-------------------------------------------------- --------------
功能
在本节中,我们将定义一些有用的函数来包装主要操作代码。
数据加载器
需要将训练集拆分为 装车机
和 有效加载器
,以便计算验证集上的一些指标。这种拆分是按照 有效尺寸
.稍后将在创建 Dataloaders 期间计算和使用用于训练和验证的采样器。
Code 6 • Dataloaders
功能 数据加载器
为 动车组
, 验证集
, 和 测试集
. DataLoader 只是一个抽象小批量概念的迭代器,在训练网络时非常有用。
它还适用于 转换
在数据收集部分定义。
火车 CIFAR10
功能 train_cifar
定义网络的训练过程。由于它也将在超参数调整阶段被调用,因此关键字参数 调音
已添加,以自定义一些特定的行为。
它通过 装车机
首先,对于训练阶段,以及 有效加载器
接下来,用于计算每个时期的验证损失和准确性:这些指标将用于评估和保存最佳模型训练参数。
一些重要的选择:
- 交叉熵损失 作为损失函数
- 亚当 优化器,具体参数稍后在 hp 调整部分设置
Code 7 • Training function
评估
以下函数封装了用于模型测试和性能评估的代码。
Code 8 • Evaluation functions
超参数调优
我们将使用 Ray Tune 进行超参数调优。搜索空间包括:
批量大小
.LR
,学习率。贝塔1
和Beta2
系数,用于计算梯度及其平方的运行平均值 ( 亚当优化器 )。阿姆斯格勒
,一个布尔值,指示是否使用 Adam 的 AMSGrad 变体,来自论文: 关于亚当和超越的收敛 .
使用的搜索算法是 BOHB(贝叶斯优化 HyperBand) :
BOHB(贝叶斯优化 HyperBand)是一种既可以终止不良试验,又可以使用贝叶斯优化来改进超参数搜索的算法。
https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#bohb-tune-search-bohb-tunebohb
它必须与特定的调度程序类配对: HyperBandForBOHB .
Code 9 • Hyperparameters tuning using RayTune
请注意 使用这些设置执行此代码可能需要几个小时,具体取决于您机器的硬件。为了使其更快,您可以删除一些配置参数或减少 num_samples
.
最佳试验结果
您最终可以从最佳试验中打印出最佳参数和其他统计信息,如下所示:
Code 10 • Best trial results from hyperparameters optimization
训练网络
在定义函数并通过一些手动或自动调整找到最佳超参数后,我们终于可以训练网络了。此示例中使用的超参数是通过自动调整和一些手动调整找到的。
Code 11 • Network training
在这个例子中,我们使用了 50 个 epoch。输出是这样的:
时期:0 训练损失:1.225798 验证损失:0.855928 验证准确度:0.697900
验证损失减少(inf --> 0.855928)。保存模型...
时期:1 训练损失:0.810272 验证损失:0.698423 验证准确度:0.760500
验证损失减少(0.855928 --> 0.698423)。保存模型...
时期:2 训练损失:0.655569 验证损失:0.618111 验证准确度:0.783800
验证损失减少(0.698423 --> 0.618111)。保存模型...
时期:3 训练损失:0.554775 验证损失:0.501839 验证准确度:0.825100
验证损失减少(0.618111 --> 0.501839)。保存模型...
时期:4 训练损失:0.489532 验证损失:0.481625 验证准确度:0.835100
验证损失减少(0.501839 --> 0.481625)。保存模型...
时期:5 训练损失:0.430423 验证损失:0.412588 验证准确度:0.861600
验证损失减少(0.481625 --> 0.412588)。保存模型...
时期:6 训练损失:0.387058 验证损失:0.397691 验证准确度:0.861200
验证损失减少(0.412588 --> 0.397691)。保存模型...
时期:7 训练损失:0.350921 验证损失:0.372295 验证准确度:0.872300
验证损失减少(0.397691 --> 0.372295)。保存模型...
时期:8 训练损失:0.323667 验证损失:0.370949 验证准确度:0.871800
验证损失减少(0.372295 --> 0.370949)。保存模型...
时期:9 训练损失:0.298934 验证损失:0.349325 验证准确度:0.883800
验证损失减少(0.370949 --> 0.349325)。保存模型...
时期:10 训练损失:0.271154 验证损失:0.336918 验证准确度:0.888200
验证损失减少(0.349325 --> 0.336918)。保存模型... [...] 时期:40 训练损失:0.034805 验证损失:0.266108 验证准确度:0.921400
时期:41 训练损失:0.033066 验证损失:0.272571 验证准确度:0.921100
Epoch:42 训练损失:0.032195 验证损失:0.289000 验证准确度:0.914500
Epoch:43 训练损失:0.029787 验证损失:0.261842 验证准确度:0.924200
验证损失减少(0.262288 --> 0.261842)。保存模型...
时期:44 训练损失:0.029429 验证损失:0.279766 验证准确度:0.918700
Epoch:45 训练损失:0.028022 验证损失:0.269373 验证准确度:0.924700
Epoch:46 训练损失:0.026465 验证损失:0.280967 验证准确度:0.921400
时代:47 训练损失:0.024350 验证损失:0.280006 验证准确度:0.921100
时代:48 训练损失:0.025161 验证损失:0.271155 验证准确度:0.923000
Epoch:49 训练损失:0.021592 验证损失:0.286109 验证准确度:0.920000 **------------ 完成培训 ------------**
在 44 个 epoch 后达到了最佳模型:
- 0.261842 验证损失
- 0.924200 验证准确性
我们还可以显示训练进度:
Code 12 • Displaying training results
Image 4 • Training progress charts
有趣的是,模型如何达到大约 80% 的准确率,之后 只有 3 个 epoch .
评估
最后,让我们检查一些评估指标。
Code 13 • Network evaluation
**网络的整体准确性** 91.66 %
在 10000 张测试图像上 **每类准确度** 飞机 93.40 %
汽车 96.30 %
鸟 89.50 %
猫 83.30 %
鹿 92.10 %
狗 85.90 %
青蛙 93.40 %
马 92.40 %
出货 94.70 %
卡车 95.60 %
这 实现的整体准确率超过 91% , 有的班级超过 95%。
伸出援手 领英 征求意见和建议。
看看笔记本
[
GitHub — mattiolato98/CIFAR10-image-classification: CIFAR10 图像分类使用…
您目前无法执行该操作。您使用另一个选项卡或窗口登录。您在另一个选项卡中退出或...
github.com
](https://github.com/mattiolato98/CIFAR10-image-classification)
参考
- CIFAR10
- 从小图像中学习多层特征 , Alex Krizhevsky, 2009 [第 3 章]
- PyTorch
- 使用 PyTorch 进行 RayTune
- 火炬视觉
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明