Typesetting math: 100%

学习理论:预测器-拒绝器多分类弃权学习

目前确定去京大读博了,预计方向是学习理论(Learning Theory)。熟悉我的朋友可能知道,虽然我读研期间的方向主要是联邦学习和推荐系统,但是我也会更新一些理论相关的博客,因为我确实对理论方向比较感兴趣。目前准备10月份左右入学,在这之前接受导师的线上指导开始科研。现在就以以这篇博客做为我PhD科研的开始吧(#^.^#)。

1 导引

弃权学习(learning with abstention) [1]主要是为了使分类器在学习过程中可能出现的误导性或者不正确的信息时(这常被称为“幻觉”),能够对做出预测进行弃权。目前,弃权学习的方法主要可以分为以下几种:

  • 基于置信度的方法(confidence-based methods)。这种方法在预训练模型返回的分数低于某个阈值θ时弃权。
  • 选择性分类(selective classification)。设置一个预测器和一个选择器,并定义被期望的选择或收敛度归一化的选择风险或损失。
  • 预测器-拒绝器公式(predictor-rejector formulation)。同时学习一个预测器和一个拒绝器,它们来自不同的函数族,这种方法显式地考虑了弃权花费c,当学习器弃权时将导致大小为c的损失;
  • 基于分数的公式(score-based formulation)。对多分类类别进行增广(多一个拒绝标签类型),当分配给拒绝标签的分数最高时进行弃权。

本文关注预测器-拒绝器公式,也即显式地建模弃权花费的一种方法。那么该如何对多分类弃权问题进行形式化,什么时候适合弃权呢?

我们先来考虑有监督二分类弃权学习场景。在这种场景中标签为Y={1,+1},样本独立同分布地采样自X×Y空间上的固定未知分布D。给定实例xX,学习器若选择对预测x的标签进行弃权,则产生一个损失c(x)[0,1]做为代价;否则,使用预测器h做出预测h(x)并产生一个标准的0-1损失Iyh(x)0(其中y为真实标签)。由于随机猜测可以达到12的期望代价,拒绝操作只有在c(x)<12时是合理的。

我们使用(h,r)来建模学习器,其中函数r:XR使得点xXr(x)0时被拒绝,假设h:XR预测未被拒绝的点的标签(体现为h(x)的正负)。对任意(x,y)X×Y(h,r)的弃权损失[2]定义如下:

Labst(h,r,x,y)=Iyh(x)0Ir(x)>0+c(x)Ir(x)0

假定对于学习器来说弃权花费c(x)是已知的。在接下来的分析中,我们假设c是一个常值函数,但我们的部分分析可以应用于更普遍的情况。

HR为两个从XR的函数构成的函数族。此外,我们假设带标签样本S=((x1,y1),,(xm,ym))独立同分布地采自Dm。则学习问题即为确定一个(h,r)H×R以使得下列期望弃权损失R(h,r)尽可能小:

RLabst(h,r)=E(x,y)DLabst(h,r,x,y)

对于大多数假设集而言,优化上述期望弃权损失RLabst(h,r)是难以处理的(intractable)。因此,在这种情况下学习算法需要依赖于代理损失(surrogate loss )。那么,一个重要问题就是什么种类的代理损失能够被用于替代目标弃权损失。直觉上,一个代理损失需要易于优化,且其最小化会导向目标损失的最小化。术语校准(calibration) 就用于定义这样一种损失函数,这种损失函数能够确保风险最小化的预测器能够成为贝叶斯最优(Bayes-optimal) 分类器,这种性质被称为贝叶斯一致性。下图直观地展现了校准的代理损失的性质[4](其中目标损失ϕ01是二分类中常见的0-1损失,ϕ是其代理损失):

出于理论分析目的,直接定义预测器和拒绝器的校准更为方便(基于它们是否是贝叶斯最优的)。因此,我们定义如下关于校准的符号:

定义 1 预测器-拒绝器的校准 我们称(h,r):XR×R是校准的,如果RLabst(h,r)=RLabst(h,r)

在本文中,我们分别考虑预测器和拒绝器的校准,这使得我们更好地理解带拒绝分类的难度来自于何处。

定义 2 预测器的分类校准 我们称h:XY是预测-校准的,如果h(x)=h(x)X上几乎处处成立。

定义 3 拒绝器的拒绝校准 我们称r:XR是拒绝-校准的,如果sign[r(x)]=sign[r(x)]对所有满足r(x)0xX成立。

通过这些定义与损失函数Labst的形式可以看到,如果h是预测-校准的且r是拒绝-校准的,则(h,r)是校准的。

如下列的代理损失LPB(h,r,x,y)

LPB(h,r,x,y)=˜ϕ(α[yh(x)r(x)])+cϕ(βr(x))

这里˜ϕϕI[z0]的凸上界。通过选择适当的参数α,β>0,Cortes等人[2]基于指数损失˜ϕ(z)=ϕ(z)=exp(z)导出了一个校准的结果。然而,Ni等人[3]指出,这个代理损失只在二分类的情况下可行,想要将这个代理损失扩展到多分类的情况下是有挑战性的,于是转而采用基于置信度分数的方法来处理多分类的情况。

本文作者尝试在多分类的情况下,为预测器-拒绝器框架下的弃权学习定义贝叶斯一致的代理损失[1]。具体地,本文作者引入了一些新的代理损失族并为其证明了强的非渐近和假设集特定的一致性保障。这些保障为弃权损失函数的估计误差提供了代理损失形式的凸上界。

在本文中,我们将在两种不同的设置下讨论预测器-拒绝器弃权代理损失,分别是单阶段两阶段。在单阶段的设置下,预测器和拒绝器同时学习;而在两阶段的设置下(在实际应用中很重要),在第一阶段中预测器使用标准的代理损失(例如交叉熵损失)来学习(例如大的预训练模型),然后在第二阶段预测器被固定,只需要学习拒绝器。

我们会为一些预测器-拒绝器框架中的弃权代理损失L证明 (H,R) - 一致性界((H,R) -consistency bound)。这些不等式给出了关于假设hH和拒绝器rR的预测器-拒绝器弃权损失Labst的上界(以它们的弃权代理损失L形式)。它们满足下列形式:

(H,R) - 一致性界

RLabst(h,r)RLabst(H,R)f(RL(h,r)RL(H,R))

这里f是非递减函数。因此,当代理估计误差RL(h,r)RL(H,R)减少到ϵ时,估计误差(RLabst(h,r)RLabst(H,R))会被f(ϵ)所界定。在这些界中会出现的一个重要的项为最小化能力差距(minimizability gap),其定义为ML(H,R)=RL(H,R)Ex[infhH,rREy[L(h,r,X,y)X=x]]。当损失函数L只依赖于h(x)r(x)(对在大多数应用中使用的损失函数都成立),且当HR包括了所有可测函数时,最小化能力差距为0。然而,它对于受限的假设集HR一般是非0的。最小化能力差距能够被近似误差 (approximation error)AL(H,R)=RL(H,R)Ex[infh,rEy[L(h,r,X,y)X=x]]所界定,这里下界取遍所有可测函数。但是,最小化能力差距是个更好的量并导出更好的理论保障。

2 单阶段预测器-拒绝器代理损失

在多分类情形下,标签Y={1,,n}n2)。我们取h(x)=arg maxyYh(x)y。则类比二分类情形,对于多分类问题,我们同样可以定义如下的预测器-拒绝器弃权损失:

Labst(h,r,x,y)=Ih(x)yIr(x)>0+c(x)Ir(x)0

注意,和之前二分类情况的不同之处在于Iyh(x)0变为了Ih(x)y,这里h(x)直接输出分类标签。设l为在标签Y上定义的0-1多分类损失的代理损失,则我们可以在此基础上进一步定义弃权代理损失L

L(h,r,x,y)=l(h,x,y)ϕ(αr(x))+ψ(c)ϕ(βr(x))

其中(x,y)X×Yψ是非递减函数,ϕ是非递增辅助函数(做为zIz0的上界),αβ为正常量。上述的L可视为Cortes等人提出的二分类弃权代理损失LPB的多分类推广版本。LPB可视为将l损失设置为基于间隔的二分类损失˜ϕ(yh(x)),并设置ψ(z)=z

Lbin(h,r,x,y)=˜ϕ(yh(x))ϕ(αr(x))+cϕ(βr(x))

最小化带正则项的Lbin,并使用基于间隔的损失˜ϕ(例如指数损失˜ϕexp(z)=exp(z)以及合页损失˜ϕhinge(z)=max{1z,0}(合页损失可参见博客《统计学习:线性支持向量机(Pytorch实现) 》)),可以在二分类情形下达到SOTA的结果。然而,我们下面会看到推广到多分类的弃权代理损失L对代理损失l的选择施加了更加严格的条件,这将诸如多分类指数损失的代理损失给排除掉了。不过,我们也会看到一些其它的损失函数满足该条件,例如多分类合页损失。下面,为了简便起见,我们主要对ϕ(z)=exp(z)进行分析,尽管相似的分析也可以应用于其它函数ϕ。我们先展示负面的结果,排除掉一些弃权代理损失L,这些弃权代理损失基于不满足特定条件的损失l

下面,我们假定假设集H对称的(symmetric)完备的(complete)。我们称一个假设集H是对称的,如果存在一个从XR的函数f的族F使得对任意xX,有{(h(x)1,,h(x)2):hH}={(f1(x),,fn(x)):f1,,fnF}。我们称一个假设集H是完备的,如果其产生的分数集合能够张成R,也即对任意(x,y)X×Y{h(x)y:hH}=R

定理 1 单阶段代理损失的负面结果 假设H是对称的与完备的,且R是完备的。若存在xX使得infhHEy[l(h,X,y)X=x]βψ(1maxyYp(y|x))α,则不存在满足属性limt0+Γ(t)=0的非递减函数Γ:R+R+使得下列(H,R)-一致性界成立:对所有hH,rR以及任意分布,有

RLabst(h,r)RLabst(H,R)+MLabst(H,R)Γ(RL(h,r)RL(H,R)+ML(H,R))

证明可以采用反证法来完成。若假设此处的(H,R) - 一致性界是有效的,则蕴含着采用单阶段代理损失学习的pointwise假设类最优预测器(best-in-class predictor)和假设类最优拒绝器(best-in-class rejector)会与采用弃权损失所学习的版本对齐。将这些显式的公式纳入代理损失的条件风险分析会导致导数检验的矛盾。

考虑定理 1,为了找到满足(H,R) - 一致性界的代理损失L,我们需要考虑满足以下条件的多分类代理损失l:对任意xX,对某些ψ(α,β)R2+

infhHEy[l(h,X,y)X=x]=βψ(1maxyYp(y|x))α

在二分类的情形下,找到满足这个条件的l较为容易,因为maxyYp(y|x)也直接地决定了其它的概率。然而,在多分类的情形下,即使maxyYp(y|x)固定,在表示infhHEy[l(h,X,y)X=x]时仍然需要考虑其它概率的不同取值。这将导致将二分类框架扩展到多分类的困难。

然而,我们会展示这个必要的条件会被三个常见的多分类代理损失l所满足。进一步地,我们将证明基于这三种l中任意一种的预测器-拒绝器代理损失L(H,R) - 一致性界。这三种损失l的定义如下(对所有hH(x,y)):

  • 平均绝对误差损失(mean absolute error loss)lmae(h,x,y)=1eh(x)yyYeh(x)y
  • 约束ρ-合页损失(constrained ρ-hinge loss)lρhinge(h,x,y)=yyϕρhinge(h(x)y),ρ>0,其中ϕρhinge(z)=max{0,1zρ}ρ-合页损失,且约束条件yYh(x)y=0
  • ρ-间隔损失(ρ-Margin loss)lρ(h,x,y)=ϕρ(ρh(x,y)),其中ρh(x,y)=h(x)ymaxyyh(x)y是置信度间隔,ϕρ(z)=min{max{0,1zρ},1},ρ>0ρ-间隔损失。

关于这里的间隔损失,可以理解为合页损失的多分类扩展(参见Crammer-Singer损失[4][5]),它旨在最大化下列预测间隔(以3个类别为例):

定理 2 单阶段代理损失的(H,R) - 一致性界 假设H是对称与完备的。则对α=βl=lmae,或者l=lρψ(z)=z,或者l=lρhingeψ(z)=z,有下列(H,R) - 一致性界对hH,rR和任意分布成立:

RLabst(h,r)RLabst(H,R)+MLabst(H,R)Γ(RL(h,r)RL(H,R)+ML(H,R))

其中对l=lmaeΓ(z)=max{2nz,nz};对l=lρΓ(z)=max{2z,z};对l=lρhingeΓ(z)=max{2nz,z}

该理论为我们在单阶段设置下描述的预测器-拒绝器代理损失提供了有力的保障。该定理证明中使用的技术是新颖的且需要对涉及pointwise假设类最优预测器和拒绝器的多种情况的仔细分析。这一分析是具有挑战性的且需要考虑具体损失函数的条件风险与校准差距。该方法由于同时在弃权场景下最小化预测器和拒绝器,整体上不同于Awasthi等人描述的标准场景[6]。下面是当HR包括所有可测函数时定理2的一个直接推论(在下面的情况下最小化能力差距MLabstML都会变为0)。

推论 3 单阶段代理损失函数的额外误差界α=βl=lmae或者l=lpψ(z)=z,或者l=lρhingeψ(z)=nz,下列额外误差界(excess error bound) 对所有hHall,rRall(这里HallRall为所有可测函数构成的集合)以及任意分布成立:

RLabst(h,r)RLabst(Hall,Rall)Γ(RL(h,r)RL(Hall,Rall))

其中Γ拥有与定理 2中相同的形式。

该推论以一个积极的方式为预测器-拒绝器框架下的多分类弃权学习提供了贝叶斯一致的代理损失。事实上,它提供了一个更强的结果,因为它为之前描述过的三种弃权代理损失给出了额外误差界。这些是比这些损失函数的贝叶斯一致性更强的保障(通过取极限操作即可得到贝叶斯一致性,也即RL(h,r)RL(Hall,Rall)0RLabst(h,r)RLabst(Hall,Rall)0)。

需要指出的是,该新颖的单阶段预测器-拒绝器代理损失可能导致一些优化的挑战。这是下列因素所导致的:优化平均绝对误差损失的困难,约束合页损失施加的限制(与在神经网络假设中做为标准使用的Softmax函数不兼容),以及ρ-间隔损失的非凸性。然而,我们的原始目标是理论分析,而且这些代理损失的意义体现在它们的创新性和强理论保障。正如推论 3所展示的,它们是首个用于多分类弃权问题的预测器-拒绝器贝叶斯一致的代理损失。

3 两阶段预测器-拒绝器代理损失

接下来,我们展示两阶段的计算方法,在这一方法中我们引入l选择更灵活的代理损失,这些代理损失具有更好的优化属性。和前面类似,我们会为它们构建(H,R) - 一致性界。两阶段场景是一个重要的场景,因为在实践中常常大的预训练的预测模型已经可利用(第一阶段),而重新训练它会产生不可接受的昂贵代价。接下来问题变为了保持第一阶段的预测模型保持不变,而随后学习一个有用的拒绝模型(第二阶段)。

两阶段的预测器-拒绝器弃权损失和我们在第2部分中提到的单阶段预测器-拒绝器弃权损失Labst不同的是,h被固定,只需要学习r,而不同于Labst中的hr同时被学习。我们设Labst,hLabst的固定预测器h的两阶段版本,定义如下:对任意rRxXyY

Labst,h(r,x,y)=Ih(x)yIr(x)>0+cIr(x)0

作者提出了一个两阶段计算方法:

  • 首先,找到一个分类器h以最小化标准多分类代理损失l
  • 其次,固定h,通过最小化代理损失Lϕ,h找到r。关于r的的代理损失函数定义如下(对所有的(x,y)):

Lϕ,h(r,x,y)=Ih(x)yϕ(r(x))+cϕ(r(x))

这里ϕ为做为zIz0上界的非递增辅助函数,也即对应在二分类中为函数r决定间隔损失lϕ(r,x,y)=ϕ(yr(x))的函数(例如指数函数ϕ(z)=exp(z)),其中y{1,+1}。该计算方法是比较直接的,因为第一阶段涉及使用标准代理损失(例如Logistic损失或带Softmax的交叉熵损失)寻找预测器的经典任务;而第二阶段也相对简单,因为h被固定,且Lϕ,h的形式也不复杂,其中ϕ可能是Logistic损失或者指数损失。需要指出的是,严格的选择上式中的示性函数对保障两阶段代理损失获益于(H,R) - 一致性界是很重要的。如果代理损失函数在第一阶段中被使用,这可能不一定满足。

需要指出的是,损失函数Labst,hLϕ,h都是弃权函数r的函数,而Labst(h,r)(H,R)的函数。

定义二值0-1分类损失lbinary0-1(r,x,y)=Iysign(r(x)),其中sign(z)=Iz>0Iz0。正如单阶段代理损失,两阶段代理损失也获益于强一致性保障。我们先展示在第二阶段中,当预测器h固定时,若lϕ满足关于二值0-1损失lbinary0-1R-一致性界,则代理损失函数Lϕ,h获益于关于Labst,hR-一致性界。

定理 4 第二阶段代理损失的R-一致性界 对于固定的预测器h,假设lϕ满足关于lbinary0-1R-一致性界,即存在非递减凹函数Γ使得对所有rR,有

Rlbinary0-1(r)Rlbinary0-1(R)+Mlbinary0-1(R)Γ(Rlϕ(r)Rlϕ(R)+Mlϕ(R))

则对于所有rR和任意分布,下列R-一致性界成立:

RLabst,h(r)RLabst,h(R)+MLabst,h(R)Γ((RLϕ,h(r)RLϕ,h(R)+MLϕ,h(R))/c)

该定理的证明包括了对于固定的预测器h,分析弃权损失和第二阶段代理损失的校准差距。这里的校准差距相较于标准设置下的更复杂,因为它考虑了条件概率、该固定预测器的误差和花费,于是因此需要不同的分析。为了构建第二阶段代理损失的R-一致性界,我们需要使用该代理损失的校准差距来构建弃权损失的校准差距的上界。然而,直接操作它们会由于其复杂形式而较为困难。不过,我们可以观察到这两种形式共享了与标准分类中校准差距的结构相似性。由上述的观察启发,我们构建了一个合适的条件分布来将这两个校准差距转换为标准形式。我们尝试利用lϕ的关于二值0-1损失的R-一致性界,以用代理函数的校准差距构建目标校准差距的上界。

HR为可测函数集的特殊情况下,定理 4中的所有最小化能力差距项消失了。因此,我们获得了如下推论。

推论5 固定预测器h,假设lϕ满足关于lbinary0-1的额外误差界,即存在非递减凹函数Γ使得对所有rRall,有

Rlbinary0-1(r)Rlbinary0-1(Rall)Γ(Rlϕ(r)Rlϕ(Rall))

于是,对所有rRall和任意分布,下列额外误差界成立:

RLabst,h(r)RLabst,h(Rall)Γ((RLϕ,h(r)RLϕ,h(Rall))/c)

我们接下来陈述关于弃权损失函数Labst的整个两阶段方法的(H,R) - 一致性界。设l0-1为多分类0-1损失:l0-1(h,x,y)=Ih(x)y。我们接下来考虑是弃权正规(regular for abstention) 的假设集合R,也即使得对任意xX,存在f,gR满足f(x)>0g(x)0。如果R是弃权正规的,则对于任意x,既存在一个可以接受的选择也存在一个可以拒绝的选择。

定理6 两阶段方法的(H,R) - 一致性界 假设R是正规的。假设l满足关于l0-1H - 一致性界,lϕ满足关于lbinary0-1R - 一致性界,即存在非递减凹函数Γ1Γ2使得对于所有的hHrR,有

Rl0-1(h)Rl0-1(H)+Ml0-1(H)Γ1(Rl(h)Rl(H)+Ml(H))Rlbinary0-1(r)Rlbinary0-1(R)+Mlbinary0-1(R)Γ2(Rlϕ(r)Rlϕ(R)+Mlϕ(R))

于是,下列的(H,R) - 一致性界对所有hH,rR和任意分布成立:

RLabst(h,r)RLabst(H,R)+MLabst(H,R)Γ1(Rl(h)Rl(H)+Ml(H))+(1+c)Γ2((RLϕ,h(r)RLϕ,h(R)+Mlϕ,h(R))/c)

其中常数因子(1+c)1cΓ2是线性的时可以被移除。

和前面类似,当HR为可测函数族时,下列关于额外误差界的推论成立。

推论 7 假设l满足关于l0-1的额外误差界,lϕ满足关于lbinary0-1的额外误差界,即存在非递减凹函数Γ1Γ2使得对于所有的hHallrRall,有

Rl0-1(h)Rl0-1(Hall)Γ1(Rl(h)Rl(Hall))Rlbinary0-1(r)Rlbinary0-1(Rall)Γ2(Rlϕ(r)Rlϕ(Rall))

于是,下列额外误差界对于所有hHallrRall和任意分布成立:

RLabst(h,r)RLabst(Hall,Rall)Γ1(Rl(h)Rl(Hall))+(1+c)Γ2((RLϕ,h(r)RLϕ,h(Rall))/c)

其中常数因子(1+c)1cΓ2是线性的时可以被移除。

这些结果为两阶段设置下的代理损失提供了强理论保障。此外,l在单阶段设置下的选择受特定条件的约束,而在两阶段设置下可以被更加灵活地选择。特别地,它可以被选择为Logistic损失(或带Softmax的交叉熵损失),这不但更易于优化,而且能够更好地适配于复杂的神经网络。在第二阶段,公式比较直接,函数ϕ的选择是灵活的,这将导出关于拒绝函数r的简单的光滑凸优化问题。此外,第二阶段将过程进行了简化:h做为常量,只有拒绝器被优化。该方法能够增强优化效率。

参考

  • [1] Mao A, Mohri M, Zhong Y. Predictor-rejector multi-class abstention: Theoretical analysis and algorithms[C]//International Conference on Algorithmic Learning Theory. PMLR, 2024: 822-867.
  • [2] Cortes C, DeSalvo G, Mohri M. Boosting with abstention[J]. Advances in Neural Information Processing Systems, 2016, 29.
  • [3] Ni C, Charoenphakdee N, Honda J, et al. On the calibration of multiclass classification with rejection[J]. Advances in Neural Information Processing Systems, 2019, 32.
  • [4] Han Bao: Learning Theory Bridges Loss Functions
  • [5] Crammer K, Singer Y. On the algorithmic implementation of multiclass kernel-based vector machines[J]. Journal of machine learning research, 2001, 2(Dec): 265-292.
  • [6] Awasthi P, Mao A, Mohri M, et al. Multi-Class H-Consistency Bounds[J]. Advances in neural information processing systems, 2022, 35: 782-795.
posted @   orion-orion  阅读(90)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
点击右上角即可分享
微信分享提示