Paper Reading: DeepSMOTE Fusing Deep Learning and SMOTE for Imbalanced Data
Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。
论文概况 | 详细 |
---|---|
标题 | 《DeepSMOTE: Fusing Deep Learning and SMOTE for Imbalanced Data》 |
作者 | Damien Dablain, Bartosz Krawczyk, Nitesh V.Chawla |
发表期刊 | IEEE Transactions on Neural Networks and Learning Systems |
发表年份 | 2022 |
期刊等级 | 中科院 SCI 期刊分区(2022年12月最新升级版)1 区,CCF-B |
论文代码 | https://github.com/dd1github/DeepSMOTE |
作者单位:
- Department of Computer Science and Engineering, and the Interdisciplinary Center for Network Science and Applications (iCeNSA), the University of Notre Dame, Notre Dame
- Department of Computer Science, Virginia Commonwealth University, Richmond
研究动机#
不平衡的类分布将影响分类器的训练过程,可能会导致模型的预测在少数类上具有较高的错误率,甚至完全遗漏少数类。最近的一些研究证实了不成比例的班级不是学习问题的唯一来源,偏斜的类失衡比例通常伴随着其他因素,如难分样本、边缘样本、小样本量、流数据的特殊性质等。深度学习方法具有出色的学习能力,但是同样很容易受到不平衡数据分布的影响。解决该问题的两个主要方向是设计损失函数和重采样,深度学习重采样方法要么基于像素,要么使用 GAN 进行样本生成。这两种方法都有很强的局限性,例如基于像素的方法通常不能捕获图像的复杂数据属性,基于 GAN 的解决方案需要大量的数据、难以调优、可能遭受模式崩溃。
文章贡献#
为了实现既能处理原始图像,又能保留原始图像的属性,并且能够生成既具有高视觉质量又能丰富深度模型判别能力的图像。本文在 SMOTE 方法的基础上提出了一种新的深度学习模型过采样算法 DeepSMOTE,由三个主要部分组成:Encoder/Decoder、SMOTE、用惩罚项增强的 loss 函数。该方法允许在深度学习模型中嵌入有效的人工实例,以实现简化的端到端过程,和 GAN 方法不同在于 DeepSMOTE 在训练中不需要鉴别器。将 DeepSMOTE 与多种现有的算法进行比较,使用五种流行的图像基准和三种专用的评价指标证明 DeepSMOTE 的性能更优。DeepSMOTE 生成的高质量人工图像既适合视觉检查,又具有丰富的信息,可以有效地平衡类别并减轻不平衡分布的影响。
本文方法#
本文提出的 DeepSMOTE 由一个 Encoder/Decoder 框架,该方法基于 SMOTE 的过采样方法和一个带有重建损失和惩罚项的 loss 函数组成,伪代码如下所示。
DeepSMOTE 的骨干网络是基于 Radford 等人建立的DCGAN 架构,以端到端方式进行训练。DeepSMOTE 将不平衡的数据集分批送入网络,在每批数据上计算重构损失。在训练过程中将使用所有类的样本,由于少数类样本很少,所以使用多数类的样本来训练模型以学习数据集整体固有的基本模式。这种基于类共享一些相似特征的假设,例如在手写体数据集中虽然数字 9 是少数类,数字 0 是多数类,但模型可以学习数字的基本轮廓。
模型将使用增强的 loss 函数,该函数除了重建损失外,还包含一个基于嵌入图像的重建的惩罚项。惩罚项的产生方式是在训练是从训练集中抽取一批图像,用 Encoder 将采样图像缩减到较低维的特征空间,再用 Decoder 按照与编码图像不同的顺序重构图像。例如编码顺序为 D0、D1、D2,解码时将顺序改为 D2、D0、D1。改变重构图像的顺序是为了有效地将方差引入到编码/解码过程中,打乱顺序后图像之间的 MSE 差值,就好像图像被 SMOTE 处理过一样。除了惩罚项,在训练 DeepSMOTE 时将选择一个类样本并计算样本与其邻居之间的距离来模拟 SMOTE 的方法。相对于 GAN 中被广泛使用的鉴别器,本文的使用惩罚项的方式更节省内存和计算效率。
DeepSMOTE 完成训练之后,就可以用编码器/解码器结构生成图像。编码器将原始输入减少到一个较低维的特征空间,在特征空间使用 SMOTE 过采样,然后解码器将过采样后的样本解码成图像。
实验结果#
数据集和实验设置#
实验选择了五个流行的数据集,分别是:MNIST、Fashion-MNIST、CIFAR-10、SVHN、CelebA,这些数据集的信息如下表所示。这些数据集是通过在每个类别中随机选择样本产生不平衡,MNIST、Fashion-MNIST 的不平衡比为 100:1,CIFAR-10、SVHN 和 CelebA 约为 56:1。
用于对比的重采样方法共有 6 种,分别是 4 种基于像素的过采样算法:SMOTE、基于自适应马氏距离的过采样(AMDO)、联合清洗和重采样(MC-CCR)、基于径向的过采样(MC-RBO),2 种基于 GAN 的过采样方法:平衡 GAN(BAGAN)、生成对抗少数过采样(GAMO)。所有重采样方法都使用 Resnet18 作为基分类器,性能指标使用 ACSA、GM、F1,并使用 Shaffer 事后检验和贝叶斯Wilcoxon 带符号秩检验。测试时使用五折倍交叉验证进行不平衡测试和平衡测试,不平衡测试的测试集类别的比例和训练集相同,平衡测试集的所有类的样本数量大致相等。
DeepSMOTE 的实现使用了 Radford 等人开发的 DCGAN 架构,编码器由四个卷积层组成,使用批处理归一化和 LeakyReLu 激活函数,隐藏层的维度为 64,最后的线性层根据数据集的维度确定。解码器结构由四个反卷积层组成,除最后一层使用 Tanh 外,其余层使用批归一化和 ReLU 激活函数。使用 Adam 作为优化器,学习率设置为 0.0002。
对比实验结果#
下图是不平衡的 MNIST 的 2D 投影,以及使用 BAGAN、GAMO、DeepSMOTE 过采样后的分布。可以看到 BAGAN 和 GAMO 都专注于独立饱和每个类的分布,以增强类边界和提高在过采样数据上训练的分类器的识别能力。DeepSMOTE 结合了过采样和惩罚函数,这样引入的样本可以降低了少数类的错误概率。
实验结果如下面两张表所示,可见本文的 DeepSMOTE 方法在两种测试集设置下都优于用于对比的算法。
下表是统计检验的结果,DeepSMOTE 以统计显著的方式优于所有方法(RQ1回答)。
生成图像的质量#
下面的多张图展示了 BAGAN、GAMO、DeepSMOTE 在 5 个数据集上生成的图像,可以看到 DeepSMOTE 生成图像的质量很高。
不同的不平衡比设置#
此处分析了 DeepSMOTE 对 [20,400] 范围内不同不平衡比率的鲁棒性,结果如下图所示,可见 DeepSMOTE 即使在最高的不平衡比率下也表现出出色的鲁棒性。
接着分析不同不平衡比下的模型稳定性,在 20 次重复的五折交叉验证下运行算法。下图显示了 3 种重采样方法的性能,阴影区域表示结果的方差。可见基于 GAN 的方法在较高的不平衡比率下显示出越来越大的方差,DeepSMOTE 在这些指标中返回最小的方差,显示了高稳定性。
优点和创新点#
个人认为,本文有如下一些优点和创新点可供参考学习:
- SMOTE 方法通常都是应用于表格型数据,此处将该方法应用于不平衡的图像生成问题,具有创新性;
- 该方法结合了 Encoder/Decoder 框架并设计了一个增强的损失函数,不需要鉴别器,且基于调换顺序实现的惩罚项也非常巧妙;
- 在训练时本文方法没有简单地忽略某一类,而是从数据集本身的特征去考虑问题,考虑问题的视角也有所不同;
- 模型的实现思路简单、清晰,通过实验证明算法的性能优秀,实现的方式非常优雅。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
2021-08-03 操作系统:进程同步
2020-08-03 链路层:以太网