Typesetting math: 100%

分布式多任务学习:代理损失函数进行任务分解

1 代理损失函数——一种并行化拆解技巧

我们在《基于正则表示的多任务学习》中提到,实现多任务学习的一种传统(非神经网络)的方法为增加一个正则项[1][2][3]

minWKk=1[1nknki=1L(yki,f(xki;wk))]+λg(W)=Kk=1lk(wk)+λg(W)=f(W)+λg(W)minWKk=1[1nknki=1L(yki,f(xki;wk))]+λg(W)=Kk=1lk(wk)+λg(W)=f(W)+λg(W)(1)

目标函数中的f(W)f(W)很容易并行化拆解,但是一般g(W)g(W)就很难并行化了,那么如何解决这个问题呢?答案是运用一个可以分解的代理损失函数来替换掉原始的目标函数。我们接下来就以论文《Parallel Multi-Task Learning》[4](zhang 2015c等人)为例来介绍该思想。该论文MTLR模型[5](zhang 2015a)的基础上利用FISTA算法设计代理损失函数,该代理函数可以依据学习任务进行分解,从而并行计算。

2 基于正则化的多任务学习(MTLR)算法回顾

给定KK个任务{Tk}Kk=1{Tk}Kk=1,每个任务都有一个训练集Dk={(xki,yki)nki=1}Dk={(xki,yki)nki=1}。我们现在考虑以下形式的目标函数:

minW,bKk=1[1nknki=1L(yki,wk,ϕ(xki)+b)]+λ2tr(WΩWT)minW,bKk=1[1nknki=1L(yki,wk,ϕ(xki)+b)]+λ2tr(WΩWT)(2)

这里的ϕ()ϕ()是一个和核函数k(,)k(,)相关的特征映射,这里ϕ(x1)Tϕ(x2)=k(x1,x2)ϕ(x1)Tϕ(x2)=k(x1,x2)L(,)L(,)是损失函数(比如对于分类问题的hinge losshinge loss 和对于回归问题的 ϵ-insentive lossϵ-insentive loss。式(2)(2)的第一项是所有任务的经验损失函数,第二项基于WΩWΩ来建模任务间的关系。根据论文[5]ΩΩ是一个正定(Positive definite, PD)矩阵,它用来描述任务两两之间关系的精度矩阵(协方差矩阵ΣΣ的逆)。如果损失函数是凸的且ΩΩ正定,那么目标函数(2)(2)关于WbWb是联合凸(jointly convex)的。为了体现目标函数(2)(2)和单任务核方法的关系,我们这里只考虑ΩΩ是对角矩阵的情况。在此情况下,任务两两之间没有关系,问题(2)(2)也退化为了多个单任务模型(每个模型对应一个任务)。因此,问题(2)(2)可以被视为单任务正则化模型的多任务扩展。在问题(2)(2)中,λ2tr(WΩWT)λ2tr(WΩWT)不影响我们的并行算法设计,这是非常好的。而问题(2)(2)总是能够加速问题的学习,当使用特定的优化程序如论文[5]和论文[6]一样,根据过去的研究这些方法有很快的收敛率,不管正则项是什么。

在问题(2)(2)中有许多损失函数可供使用,比如hinge losshinge lossϵinsensitive lossϵinsensitive losssquare losssquare loss,下面我们主要就采用这三种损失函数,后面我们会分别给出问题(2)(2)关于这三个损失函数的对偶形式。

3 并行多任务学习算法

3.1 FISTA迭代算法

下面我们就给出当使用不同的损失函数时问题(2)(2)的并行求解算法。因为我们的求解算法是基于FISTA迭代的,我们先来看FISTA迭代算法。
FISTA迭代算法[7]是一个加速梯度下降方法,用于求解一个类似于下面这种形式的复合凸目标函数(compositely convex objective function):

minθCθF(θ)=f(θ)+g(θ)minθCθF(θ)=f(θ)+g(θ)(3)

这里θθ是指模型的参数集合,f(θ)f(θ)是凸的且它的梯度有LipschitzLipschitz连续性,凸函数g(θ)g(θ)有着简单的且易分解(并行)的结构,CθCθ是指θθ的定义域。FISTA算法最新构建代理损失函数Ql(θ,ˆθ)Ql(θ,^θ)如下:

QL(θ,ˆθ)=g(θ)+f(ˆθ)+(θˆθ)Tθf(ˆθ)+L2||θˆθ||22QL(θ,^θ)=g(θ)+f(^θ)+(θ^θ)Tθf(^θ)+L2||θ^θ||22(4)

这里θf(ˆθ)θf(^θ)表示f(θ)f(θ)θ=ˆθθ=^θ点的梯度,LLf()f()梯度的LipschitzLipschitz常量,接着我们优化关于θθ的函数QL(θ,ˆθ)QL(θ,^θ),约束为θCθθCθ。函数QL(θ,ˆθ)QL(θ,^θ)关于θθ的优化器由qL(ˆθ)qL(^θ)表示。

FISTA算法伪代码如下图所示:

FISTA算法伪代码

可以看到第1717步和1818步在θθ能够被划分为许多部分的情况下可以轻易并行。但目前的问题是如何并行化算法步骤11111313

3.2 将目标函数转换为对偶问题

当使用hingeϵ-intensivehingeϵ-intensivesquresqure损失函数时,我们需要用FISTAFISTA算法优化其对偶问题。下面我们分别说明得到这三个损失函数对应目标函数的对偶问题,后面我们会在此基础上进行并行化。

3.2.1 Hinge Loss

(1)转为对偶形式 我们将Hinge Loss函数Lh(y,y)=max(1yy,0)Lh(y,y)=max(1yy,0)代入式(2)的优化问题,并将无约束优化转为有约束优化可得到:

minW,b,ηλ2tr(WΩWT)+Kk=11nknki=1ηkis.t.yki(wk,ϕ(xki)+bk)1ηki,ηki0minW,b,ηλ2tr(WΩWT)+Kk=11nknki=1ηkis.t.yki(wk,ϕ(xki)+bk)1ηki,ηki0(5)

这里η=(η11,...,ηKnK)Tη=(η11,...,ηKnK)T。引入非负的Lagrange乘子{αki}{αki}{βki}{βki},我们可以得到问题(5)(5)的对偶形式如下:

minα12λαTPαKk=1nki=1αkis.t.nki=1αkiyki=0  k=1,2,...,K,i=1,2,...,nk,0αki1nkminα12λαTPαKk=1nki=1αkis.t.nki=1αkiyki=0  k=1,2,...,K,i=1,2,...,nk,0αki1nk(6)

这里我们说明一下矩阵PP的含义,设σijσij是任务关系协方差矩阵ΣΣ的第(i,j)(i,j)个元素,KK是一个n×nn×n的矩阵,它的第(Iab,Icd)(Iab,Icd)个元素是σack(xab,xcd)σack(xab,xcd),这里Iki=i+k1l=1nlIki=i+k1l=1nl计算在所有任务的训练数据中的xkixki的下标。指逐元素乘积操作,这里有P=K(yyT)P=K(yyT)。这里我们定义函数kMT(,)kMT(,)kMT(xqi,xrj)=σqrk(xqi,xrj)kMT(xqi,xrj)=σqrk(xqi,xrj)

用来构造矩阵KK。很容易证明这是一个核函数。所以我们称kMT(,)kMT(,)是一个多任务核函数,将KK称为多任务核矩阵。

3.2.2 ϵϵ - Insensitive Loss

接下来我们讨论将ϵϵ insensitive loss函数

Lϵ(y,y)={0|yy|ϵ|yy|ϵ其他Lϵ(y,y)={0|yy|ϵ|yy|ϵ

代入问题(2)(2)进行优化。我们再引入一些松弛变量,问题(2)(2)可被转化为:

minW,b,η,τKk=11nknki=1(ηki+τki)+λ2 tr(WΩWT)s.t.nki0,wTkϕ(xki)+bkykiϵ+ηki  τki0,ykiwTkϕ(xki)bkϵ+τkiminW,b,η,τKk=11nknki=1(ηki+τki)+λ2 tr(WΩWT)s.t.nki0,wTkϕ(xki)+bkykiϵ+ηki  τki0,ykiwTkϕ(xki)bkϵ+τki(7)

这里η=(η11,...,ηKnK)Tη=(η11,...,ηKnK)Tτ=(τ11,...,τKnK)Tτ=(τ11,...,τKnK)T
我们接下来引入Lagrange乘子α=(α11,...,αKnK)Tα=(α11,...,αKnK)Tβ=(β11,...,βKnK)Tβ=(β11,...,βKnK)T,进一步得到问题(7)(7)的对偶问题:

minα,β12λ(αβ)TK(αβ)+ϵ(α+β)T1+yT(αβ)s.t.nki=1(αkiβki)=0k=1,2,...,K,i=1,2,...,nk,0αki,βki1nkminα,β12λ(αβ)TK(αβ)+ϵ(α+β)T1+yT(αβ)s.t.nki=1(αkiβki)=0k=1,2,...,K,i=1,2,...,nk,0αki,βki1nk(8)

这里11表示一个元素全为1的合适大小的向量或者矩阵,KK表示由等式(8)(8)的多任务核函数kMT(,)kMT(,)构成的矩阵。这里y=(y11,...,yKnK)Ty=(y11,...,yKnK)T

3.2.3 Square loss:

我们将square loss代入问题(2)(2),得到以下优化问题:

minW,b,{ηki}Kk=11nknki=1(ηki)2+λ2 tr(WΩWT)s.t.nki=ykiwTkϕ(xki)bkminW,b,{ηki}Kk=11nknki=1(ηki)2+λ2 tr(WΩWT)s.t.nki=ykiwTkϕ(xki)bk(9)

引入Lagrange乘子{αki}{αki},我们就可以得到问题(9)(9)的对偶形式:

minα12λαTQαKk=1nki=1αkiykis.t.nki=1αki=0 k=1,2,...,K,i=1,2,...,nk,0αki1nkminα12λαTQαKk=1nki=1αkiykis.t.nki=1αki=0 k=1,2,...,K,i=1,2,...,nk,0αki1nk(10)

这里αk=(α11,...,αKnK)Tαk=(α11,...,αKnK)T。这里Q=K+λ2ΛQ=K+λ2ΛΛΛ是一个对角矩阵,相应的数据点属于第kk个任务时其对角元素为nknk

注意,后面我们会发现三个损失函数对应的对偶形式都有着相似的形式而且和单任务对偶形式的主要不同点都在于线性不等式约束。也就是说,在单任务对偶形式中,只有一个涉及Lagrange乘子的线性不等式约束;但是在多任务环境下,有KK个线性不等式约束,每个不等式都由一个任务的Lagrange乘子组成。有趣的是,这种差别决定了我们后面设计的并行算法。

3.3 将对偶问题的求解并行化

接下来我们需要展示应用FISTA算法并行化求解(6)(6),其他损失函数同理。我们定义θ=αθ=αϕ=ˆαϕ=^αf(α)=12λαTPαf(α)=12λαTPαg(α)=Kk=1nki=1αkig(α)=Kk=1nki=1αki,定义域Cα={α|nki=1αkiyki=0(k=1,2,..,K,i=1,2,...,nk,0αki1nk)}Cα={α|nki=1αkiyki=0(k=1,2,..,K,i=1,2,...,nk,0αki1nk)}。下面我们来看如何并行化算法步骤11111313

f(α)f(α)关于αα的二阶导数2f(α)2f(α)是我们这里的1λP1λP。我们用||||||||表示矩阵的l2l2范数,易得||P||2InP||P||2InP是一个半正定矩阵。所以f(α)f(α)的最小LipschitzLipschitz常量是1λ||P||21λ||P||2(L1λ||P||2L1λ||P||2)。当nn非常大时,计算||P||2||P||2非常耗时,我们下面会展示如何并行地计算它。

给定LL,我们能够优化关于αα的函数QL(α,ˆα)QL(α,^α),这也是步骤11或13要求解的(并行地)。特别地,步骤11或13要求解的优化问题可以被描述为:

minαL2||αˆα||22+1λαTPˆαKk=1nki=1αkis.t.nki=1αkiyki=0  k=1,2,...,K,i=1,2,...,nk,0αki1nkminαL2||α^α||22+1λαTP^αKk=1nki=1αkis.t.nki=1αkiyki=0  k=1,2,...,K,i=1,2,...,nk,0αki1nk(11)

该问题可以被分解为TT个独立的子问题,第tt个子问题为:

minαknki=1(L2(αki)2αkiαkj)s.t.nki=1αkiyki=ck  i=1,2,...,nk,ραkidkminαknki=1(L2(αki)2αkiαkj)s.t.nki=1αkiyki=ck  i=1,2,...,nk,ραkidk(12)

这里αk=(αk1,...,αknk)Tαk=(αk1,...,αknk)Taki=Lˆαkj+11λˆpkiaki=L^αkj+11λ^pkiˆpki^pkiPˆαP^α中与xkixki对应的元素。ck=0,ρ=0,dk=1nkck=0,ρ=0,dk=1nk

问题(12)(12)是一个二次规划(quadratic programming, QP)问题,我们能够不用任何QP求解器,在O(nk)O(nk)的时间内用拉格朗日乘子法求解。正如问题(11)(11)所示,FISTA算法的每一轮迭代都需要计算PˆαP^α以决定QL(α,ˆα)QL(α,^α)。如果我们直接解任务(12)(12)αα会完全和之前的估计不同,且计算PˆαP^α会花费O(n2)O(n2),当nn很大时计算量太大。所以这里我们希望采取SMO算法的思想,只更新部分的αα元素,这样计算PˆαP^α的时间复杂度减少到O(n)O(n)。(因为我们只需要关心变化的元素)

参考

  • [1] Evgeniou T, Pontil M. Regularized multi--task learning[C]//Proceedings of the tenth ACM SIGKDD international conference on Knowledge discovery and data mining. 2004: 109-117.
  • [2] Zhou J, Chen J, Ye J. Malsar: Multi-task learning via structural regularization[J]. Arizona State University, 2011, 21.
  • [3] Zhou J, Chen J, Ye J. Clustered multi-task learning via alternating structure optimization[J]. Advances in neural information processing systems, 2011, 2011: 702.
  • [4] Zhang Y. Parallel multi-task learning[C]//2015 IEEE International Conference on Data Mining. IEEE, 2015: 629-638.
  • [5] Zhang Y, Yeung D Y. A convex formulation for learning task relationships in multi-task learning[J]. arXiv preprint arXiv:1203.3536, 2012.
  • [6] Zhang Y, Yeung D Y. A regularization approach to learning task relationships in multitask learning[J]. ACM Transactions on Knowledge Discovery from Data (TKDD), 2014, 8(3): 1-31.
  • [7] A. Beck and M. Teboulle, “A fast iterative shrinkagethresholding algorithm for linear inverse problems,” SIAM Journal on Imaging Sciences, 2009
  • [8] 杨强等. 迁移学习[M].机械工业出版社, 2020.
posted @   orion-orion  阅读(571)  评论(1编辑  收藏  举报
编辑推荐:
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· .NET Core 托管堆内存泄露/CPU异常的常见思路
· PostgreSQL 和 SQL Server 在统计信息维护中的关键差异
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
点击右上角即可分享
微信分享提示