Paper Reading: A Re-Balancing Strategy for Class-Imbalanced Classification Based on Instance Difficulty
Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。
论文概况 | 详细 |
---|---|
标题 | 《A Re-Balancing Strategy for Class-Imbalanced Classification Based on Instance Difficulty》 |
作者 | Sihao Yu, Jiafeng Guo, Ruqing Zhang, Yixing Fan, Zizhen Wang, Xueqi Cheng |
发表会议 | IEEE Conference on Computer Vision and Pattern Recognition(CVPR) |
发表年份 | 2022 |
期刊等级 | CCF-A |
论文代码 | 未公开 |
作者单位:
- University of Chinese Academy of Sciences, Beijing
- China CAS Key Lab of Network Data Science and Technology, Institute of Computing Technology
- Chinese Academy of Sciences, Beijing, China
研究动机#
各种算法都能解决高质量的合成数据集的分类问题,然而现实应用中的数据集往往是不平衡的。在不平衡数据集上训练时,神经网络模型通常在少数类上表现不佳。为了解决这个问题,很多研究者在训练模型时引入了各种策略来重新平衡数据分布,主流的解决方案是重采样和重加权。重采样方法通过重复少数类的实例或去除部分多数类的实例,实现对训练数据的直接调整。重加权方法侧重于对成本的定义,实现对少数类的成本给予更多的关注。
然而这些类级别的重新平衡策略过于粗糙,无法区分实例的差异。在多数类中也有一些难分样本,如果通过减少多数类的权重,这些难分将被进一步忽略,变得更加难以学习。因为每个实例通常只出现一次,所以在样本级的平衡问题上不能像现有的类级方法那样分配权重。受人类学习过程的启发可知,学习的速度和难度通常是密切相关的,样本学习过程也会直接受到数据分布的影响。因此将学习速度慢的样本识别为更困难的样本,并在学习中增加其权重,可以有效地平衡数据分布。
文章贡献#
受人类学习过程的启发,本文根据学习速度设计了样本难度模型,并提出了一种新的实例级再平衡策略。具体来说模型在每个训练周期记录每个实例的预测,并根据预测的变化来测量该样本的难度难度。然后对困难实例赋予更高的权重,对数据进行重新采样。本文从理论上证明了提出的重采样策略的正确性和收敛性,并进行一些实证实验来展示本文算法的能力。
本文方法#
任务定义#
在不损失一般性的情况下,本文将问题定义为一个包含 k 个类的分类任务。设 S:={zi=(xi, yi):1≤i≤N} 为包含 N 个实例的训练数据,其中 zi 表示第 i 个实例,xi 表示其特征,yi∈{1,…,k} 表示其类别标签。
采用神经网络拟合特征到类标签的映射,假设 Net 的最后一层是 softmax,将 pi=Net(xi) 表示为实例 zi 的预测分布,argmax(pi)=yi 表示实例 zi 被 Net 正确推断。采用合适的损失函数 L 来学习 Net 的参数 θ,设 L(θ,S)=∑Ni=1L(θ,zi)是二次可微的,学习目标是通过改变网络的 θ 来最小化总损失 L(θ,S)。
重采样框架#
数据分布会极大地影响模型的学习,本文方法只调整训练中使用的数据分布来优化模型。使用的重采样方法的整体训练框架如伪代码所示:
重抽样方法的核心是计算所有实例的抽样权重,与现有的类级方法不同,即使类别相同本文方法中每个实例的采样概率可以不同。该模型的灵感基于应该更多关注困难样本的观点,抽样权重的计算公式如下,其中 wi,t 决定实例 zi 在第 t 次迭代后的采样概率。
采样方法将根据每个实例当前的难度动态调整采样权值,因此确定难度的模型直接决定了每个实例的抽样概率,是本文方法的核心。
样本难度模型#
理论分析#
当使用梯度下降更新 θ 时,更新的目标是使 L(θ,S) 变小。根据泰勒展开式,当 θ→θ0 时,L(θ,S) 可近似为如下公式,其中 L'(θ0, S)=∑Ni=1L'(θ,zi)。为了得到最快的下降速度,令 △θ =(θ−θ0)=−ηL′(θ0, S),其中 η 为学习率。
设参数更新后由 θ0 变为 θ1=θ0-ηL'(θ0, S),对于样本 z 的 loss=L(θ1,z)−L(θ0,z) 的变化可以估计为如下公式,其中 ⟨·,·⟩ 表示内积。
如果 ⟨L'(θ0,z),L'(θ0,S⟩<0,则 z 的损失将增加,此时称 z 为未学习的。可以据此构造 S 的两个子集,分别是辅助集 Az:={a:a∈S,⟨L'(θ0,z), L'(θ0,a)⟩>0} 和阻碍集 Hz:={r:r∈S,⟨L'(θ0,z),L'(θ0,r)⟩<0}。实例损失的减小表示实例被模型学习的程度,由此可见 Az 中的样本提供了帮助,而 Hz 中的样本则产生了阻碍。当 Az 中任意实例的权重降低或 Hz 中任意实例的权重增加时,z 的损失将变得更加难以减少,说明实例的难易程度受数据分布的影响。
此时的一种朴素的想法是,学习的样本难度可以通过所有梯度两两求内积来评估,但是这样的计算非常复杂,速度太慢。由于二阶可微假设,当模型参数受到轻微扰动时,梯度变化不大。在相邻的两次迭代中,当 △θ 较小时模型的预测变化趋势相似,因此可以利用上一次迭代中的预测变化来估计当前迭代中的变化。本文将 Nett 表示为已经训练了 t 次迭代的模型,pi,t=Nett(xi) 表示为 Nett 的预测分布。则在第 t+1 次训练迭代中,取 pi,t-1 和 i,t 之间的变化来估计 zi 是倾向于被学习还是未被学习的学习结果。如果一个实例经常是未学习的,那么模型将很难利用它进行训练,显然当它的权重增加时,这样的样本将更容易学习。
样本难度可以在向量空间中,难度向量被定义为 D->i,T=c->+∑Tt=1d->i,t,其中 c=(c, c)、d->i,t=(dui,t, dli,t)。更新采样权值时每个样本会计算一个新的难度向量,该向量的难度空间由遗忘趋势和学习趋势组成。如果模型倾向于忘记一个实例,则其难度向量的方向将紧密地指向忘记趋势。如果模型倾向于学习一个实例,其难度向量的方向将紧密指向学习趋势。该难度最终的结果会随着模型的收敛而收敛,在理想情况下,本文作者证明了当t→∞ 时,||D->i,t-1||=||D->i,t||,该结果作者展示在论文的附录中。
模型设计#
确定样本的难度考虑了学习方向和遗忘方向的预测变化,具体来说对于给定实例 zi,其经过 t 次迭代后的难度估计为如下公式。其中 c 为实例难度的先验参数,dui,t为 t 次迭代后在负学习方向上的预测变化,dli,t 为在学习方向上的预测变化。所有样本都有相同的 c,该参数调节了难度对预测变化的敏感性。
上述公式中的分子记录负学习趋势的积累,分母记录了学习趋势的积累。在初始化阶段,所有样本的困难被初始化为 Di,0=1。对于任何样本 zi 和 zj,Di,t>Dj,t的实例说明在更新了 t 次迭代之后,zi 比 zj 更困难。
本文中根据 PSI(Population Stability Index, 稳定度指标)来定义 dui,t 和 dli,t。PSI 是一个衡量分布之间距离的指标,无论学习方向如何,预测值在 pi,t-1 和 pi,t 之间的变化如下公式定义,其中 pji,t-1 表示 Nett 预测 zi为类别 j 的概率。
其中 pyii,t−pyii,t-1>0 表示学习,pyii,t−pyii,t-1<0 表示负学习。由于数据集往往有多个类别,因此将 dui,t 和 dli,t 定义为如下公式,使其满足 di,t = dui,t+dli,t。
模型每次迭代时都需要记录当前模型的预测,通过计算实例难度来调整实例 zi 的抽样权值。
实验结果#
实验设置#
本文和以下 3 种 baseline 方法进行对比,使用的 baseline 都是用于优化给定基模型学习的调节方法,实验采用 3 种基础模型,分别是:ResNet、Multilayer Perceptron、logistic regression。
baseline | 描述 |
---|---|
Class-Balance Loss | 类级别的再平衡方法,根据类的有效数量调整类的权重 |
Focal loss | 实例级难度敏感方法,根据损失为实例分配更高的权重 |
TDE | 根据推理中的因果效应去除多数类的累积偏好,且不会在训练过程中调整数据分布 |
长尾分类实验#
此部分对具有不同失衡比率的长尾 CIFAR 数据集进行实验,总体失衡比记为 nmax/nmin,实验数据如下所示。在大多数情况下,本文方法优于其他方法,说明本文的策略对重新平衡分布是有效的。由于 TDE 不改变数据分布,因此本文方法可以与 TDE 集成进一步提高性能。
在不平衡比为 100 的长尾 CIFAR-100 上观察0多数类和少数类的性能,实验结果如下表所示。从结果可见,本文的方法更关注那些很难学习的实例,因此能有效地改进多数类和少数类。虽然少数类的改进不如 Class-Balance Loss,但总体上有更好的性能。
仿真实验#
原始 MNIST 是手写体数字识别数据集,此处通过欠采样构建了 MNIST 的长尾版本。修改后的数据集有偶数、奇数两个类别,其中偶数类包括多数子类 0、2 和少数子类 4、6、8,奇数类的多数子类 1、3 和少数子类 5、7、9,每个类中的子类分布也服从长尾分布。在训练过程中,模型能获取类标签,但对于子类标签是不可知的。
实验结果如下表所示,可见本文方法对少数子类有更显著的改进,这表明了样本级调整具有优越性,同时本文方法也有效地降低了标记类和未标记子类的不平衡比率。
由于本文方法可以重新平衡类甚至未标记子类的分布,因此作者记录了不同类和子类的遗忘频率,并可视化了遗忘频率和难度之间的关系。从下图可以看到,较少的类或子类通常具有较高的遗忘频率,表明数据分布与遗忘频率之间存在很强的相关性。
接着对具有不同遗忘概率的实例进行模拟,下图给出了三种不同概率的负学习样本的损失和难度变化。可以看到本文的样本难度与遗忘频率是一致的,具有较高遗忘频率的实例将获得更高的难度和更高的权重,因此本文方法可以重新平衡类或子类的分布。
一般性实验#
本文在 10 个 UCI 数据集上也做了实验,其结果如下表所示。结果表明本文方法可以稳定地提高基础模型在微小数据集上的性能,说明了本文的方法具有通用性。
优点和创新点#
个人认为,本文有如下一些优点和创新点可供参考学习:
- 和一些类级别的不平衡方法不同,本文聚焦于实例级别的方法研究,具有更细的粒度;
- 对于每个实例,本文通过利用 PSI 指标确定了样本的学习趋势和负学习趋势,进而度量了某个样本的分类难度,这种对于每个实例的分析方法可以借鉴;
- 对于本文提出的优化方式,作者提供了数学证明和可解释性说明,说服力更强。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
2021-07-24 《剑指 Offer》学习记录:题 68:二叉树的最近公共祖先
2020-07-24 运输层:TCP 拥塞控制