分布式多任务学习:同步和异步优化算法
1 分布式多任务学习(Multi-task Learning, MTL)简介
我们在上一篇文章《基于正则表示的多任务学习》中提到,实现多任务学习的一种传统的(非神经网络的)方法为增加一个正则项[1][2][3]:
其中g(W)g(W)编码了任务的相关性(多任务学习的假定)并结合了KK个任务;λλ是一个正则化参数,用于控制有多少知识在任务间共享。在许多论文中,都假设了损失函数f(W)f(W)是凸的,且是L-LipschitzL-Lipschitz可导的(对L>0L>0),然而正则项g(W)g(W)虽然常常不满足凸性(比如采用矩阵的核范数),但是我们认为其实接近凸的,因此对于式(1)(1)可以采用近端梯度算法(proximal gradient methods)[4]来求解(在标准近端梯度法中,默认g(W)g(W)是不可微的凸函数)。
不过现实情况比较复杂。当任务数量很大时,多任务学习的计算复杂度很高,这可能要求我们用多CPU/多GPU对学习算法进行加速;又或者(也是更为常见的情况,尤其在联邦学习中),当数据量很大时,数据经常会分片存储在不同的计算机甚至是不同的计算中心。比如,学习任务经常会涉及到不同的样本集(用于学习不同的任务):D1,...,DKD1,...,DK,这些样本集进场会存储在不同的地方。比如如果我想用不同医院的病例样本集进行多任务学习,那么不同医院的数据肯定各自存储在不同地方。不管是出于网络传输带宽考虑还是数据隐私考虑,想要将所有场所的数据集中在一起然后跑最优化算法显然不太现实(即使数据已经脱敏,大规模转移病人数据仍然是个很有争议的问题)。
以上的两种需求要求我们尽量使KK个任务的梯度的计算分摊到KK个不同的工作节点(worker)上。这样对MTL设计分布式算法就显得非常重要了,分布式算法旨在尽量将耗时的计算放在各分界点本地进行,然后再通过网络传输到中心节点。现实中的训练数据会非常庞大且位于不同的数据中心,我们不能将数据收集起来再训练 ,必须要将数据存放在各节点就地使用。
实际上由于正则项的存在和损失函数的复杂性,我们需要非常仔细地设计分布式多任务学习算法,在保证任务得到划分的同时而尽量不影响优化算法最终的收敛。
2 MTL的同步(synchronized)分布式数值优化算法
我们将会从MTL的单机数值优化方法开始,逐步说明分布式数值优化的必要性并介绍它的一种主要实现手段——同步分布式优化算法。
我们先来看单机数值优化,由于g(W)g(W)正则项的不光滑性,MTL的目标函数常采用基于近端梯度的一阶优化方法进行求解,包括FISTA[5](近端梯度下降法的一个典型变种), SpaRSA[6]以及最近提出的二阶优化方法PNOPT[7]。下面我们简要回顾一下在这些方法中涉及到的两个关键计算步骤:
(1) 梯度计算(gradient computing) 设第tt迭代步的参数矩阵为W(t)W(t),目标函数光滑部分f(W(t))f(W(t))的梯度由每个任务的损失函数单独计算梯度后拼接而得:
(2) 近端映射(proximal mapping) 在梯度更新后,我们会计算
此处ηη是迭代步长。不过请注意,此处只为完成了f(W)+g(W)f(W)+g(W) 中可微部分f(W)f(W)的求导,ˆW^W尚不能做为我们下一步的搜索点,下一步的搜索点会经过近端映射Prox(ˆW;η,λ,g)Prox(^W;η,λ,g)获得,该近端映射等价于求解下列优化问题:
这样,我们就得到了下一步的搜索点W(t+1)W(t+1)。
这里我们不失一般性,我们假设我们的数据集D1,...,DKD1,...,DK分散存储在一个用星形网络连接的计算机集群中。每个单独的计算机系统我们称之为节点(node),工作节点(worker)或智能体(agent)(在多智能体系统中)。第kk个节点对于任务kk的数据DkDk拥有完全访问权,并能够进行数值计算(比如计算第kk个任务的梯度∇lk(wk∇lk(wk)。我们假定有一个中心节点(central server)能够收集所有任务节点(task agents)的数据,并进行近端映射操作。
我们接下来看如何分布式并行。因为KK个任务的独立性,可以让第kk个任务节点存储w(t)kw(t)k,然后负责计算梯度∇lk(w(t)k)∇lk(w(t)k),这样就很容易地并行化了。然后我们收集每个任务的梯度向量∇lk(w(t)k)∇lk(w(t)k)到中心节点并拼接得到∇f(W(t))∇f(W(t)),然后计算ˆW^W,最后经过近端映射操作得到W(t+1)W(t+1)。然后再将W(t+1)W(t+1)拆分为w(t+1)1,...,w(t+1)Kw(t+1)1,...,w(t+1)K分别发送到KK个任务节点,进行下一轮的迭代。整个分布式并行算法如下图所示(图片来自王树森老师的YouTube课程并行计算与机器学习[16]):

因为必须要所有任务节点的梯度计算并收集完毕后,主节点才能进行下一步操作,所以上面这种方法被称为同步的(synchronized)。同步方法的最大弊端就是如果有一个或多个任务节点网络传输带宽过低,或者直接down掉,其他任务节点都会停下来等待(因为拿不到下一轮的数据)。因为多数一阶优化算法都需要经过很多轮迭代才能够收敛到一个特定的精度,在同步数值优化算法中的等待会造成不能容忍的算法运行时间和运算资源的极大浪费。
3 MTL的异步(asynchronized)分布式数值优化算法
上面提到的同步数值优化算法可能让一些读者想到MapReduce计算架构,这种架构很少用于迭代算法。比如我们在深度学习的训练中多采用参数服务器(Parameter Server)架构,这是一种异步数值优化的架构。在多任务学习的领域,也有学者提出了异步数值优化算法,接下来我们以《Asynchronous Multi-Task Learning》[8](IM Baytas等,2016)这篇论文为例,来介绍MTL的异步数值优化算法。
在本篇论文的异步数值优化算法中,中心节点只要收到了来自一个任务节点的已经算好的梯度,就会马上对模型的参数矩阵WW进行更新,而不用等待其他任务节点完成计算。中心节点和任务节点都会在内存中维护一份WW的拷贝,任务节点之间的拷贝可能会各不相同。AMTL(Asynchronized Multi-task learning)的收敛率分析可以参照另外两篇介绍ARock[9]计算框架介绍Tmac[10]计算框架的论文(这两篇论文对Krasnosel’skii-Mann
(KM) 迭代方法进行了改造,增加了异步并行坐标更新(asynchronous parallel coordinate update)的特性)。我们称一个任务节点被激活,当它进行(梯度)计算并与中心节点通信以进行更新。《Asynchronous Multi-Task Learning》这篇论文也提出了一个异步并行的框架,该框架基于以下关于激活率(activation rate)的假设:
假设1: 所有任务节点服从独立的泊松过程并且有相同的激活率。
该假设可以得到一个有用的结论,如果不同的任务节点的激活率不同,我们理论上可以调整迭代步长ηη来调整迭代结果:如果任务节点的激活率很大,那么该任务节点被激活的可能性就会很大,从而我们应该降低
该任务节点对应的迭代步长ηη(注意:因为是异步算法,每个任务节点都有其对应的迭代步长)。该论文提出了的一个动态迭代步长策略,具体细节在此略过。
接下来的推导会用到基于算子做优化的思想,我们这里做一下简要介绍。
很多优化问题可以转换为一个求映射的零点问题,即求一个xx使得映射A(x)=0A(x)=0满足:
比如我们求无约束优化问题,其最优解等价于求解梯度等于0,这里AA就为求梯度;对于约束优化问题,我们可以转化为一个无约束对偶问题,这时的AA就是求对偶问题的梯度。
而求解问题(5)(5)的方法就是不动点迭代,也就是找到一个特定的算子TT,迭代地寻找解:
问题(5)(5)的解是算子TT的稳定点x∗x∗,满足x∗=T(x∗)x∗=T(x∗)。
接下来我们讨论算子TT该怎么定。目前已经提出了一下几种最常见和实用的算子:
(1) 前向算子:T=I−ηAT=I−ηA
考虑凸问题minxf(x)minxf(x),该问题可转化为求解∇f(x)=0∇f(x)=0,令A(x)=∇f(x)A(x)=∇f(x),我们应用前向算子去求解该方程
聪明的你应该发现,这就是梯度迭代。
(2) 后向算子: T=(I+ηA)−1T=(I+ηA)−1
还是该问题,我们运用该算子有:
这对应的就是我们近端迭代步骤:
(3) 前向后向分类分裂:T=(I+ηB)−1(I−ηA)T=(I+ηB)−1(I−ηA)(即前面这两个算子的结合)
考虑优化问题minxf(x)+g(x)minxf(x)+g(x),其中f(x)f(x)光滑,梯度为f(x)f(x),而g(x)g(x)不光滑,次梯度为∂g(x)∂g(x)。该优化问题等价于找到xx满足0∈∇f(x)+∂f(x)0∈∇f(x)+∂f(x)。我们令A(x)=∇(x)A(x)=∇(x),B(x)=∂g(x)B(x)=∂g(x),这要就可以运用前向后向算子:
由前面的讨论知,x(t+1/2)=(I−η∇f)(x(t))x(t+1/2)=(I−η∇f)(x(t))是一个梯度迭代,x(t+1)=(I+η∂g)−1(x(t+1/2))x(t+1)=(I+η∂g)−1(x(t+1/2))是一个近端迭代。所以结合起来就能得到:
这就是近端梯度迭代算法。
关于算子做优化我们就介绍到这里,其他详见算子优化相关书籍[13]和知乎文章[14])。
该论文就使用的是前向-后向算子分裂方法[11][12]来求解目标函数(1)(1)。前向后向迭代如下:
该迭代对η∈(0,2/L)η∈(0,2/L)会收敛到解。前面我们提到∇f(W)∇f(W)是可分的,比如可以写成∇f(W)=(∇l1(w1),...,∇lK(wK))∇f(W)=(∇l1(w1),...,∇lK(wK)),且此处的前向算子I−η∇fI−η∇f也是可分的。不过这里的后向算子(I+ηλ∂g)−1(I+ηλ∂g)−1是不可分的,将导致后面无法并行。因此我们不能直接在前向-后向迭代上应用论文[9]提到的KM坐标更新法。不过,如果我们转换前向和后向的顺序,我们可以得到下列的后向-前向迭代:
这里我们使用一个辅助矩阵V∈Rd×K来在更新中替代W,这是因为前向-后向迭代和后向-前向迭代中的更新变量是不一样的。因此,由V∗得到W∗需要一个额外的后向迭代步骤。之后我们就可以在后向-前向迭代的基础上,按照论文[9]提出的坐标下降策略,从{1,2,...,K}中随机采样一个任务索引k,来对任务块vk(代指和任务k有关的变量)进行坐标更新了。更新步骤如下所示:
这里vk∈Rd是任务k中wk相应的辅助变量。注意想要更新一个任务块vk只需要在一个任务块上进行一个完整的后向步骤和前向步骤。再给出整体的AMTL算法之前,我们先给出算法的迭代步骤。该算法的迭代步遵循的为在论文[9]中讨论的基于坐标更新的KM迭代。KM迭代是一种代替不动点迭代x(t+1)=T(x(t))的方法,它的形式为:x(t+1)=x(t)+ηr(Tx(t)−x(t))。而论文[9]则更进一步,采用其坐标下降形式:
这里η是随机向量。i为从{1,2....n}中随机采样的特征索引(采到1,2,..,n的概率可以不同,不过我们一般取均等概率),n为x的维度。至于该迭代式的推导过程具体可参照论文[9]。
在这个问题中,我们设Proxηλg(ˆv(t))为后向映射,我们有:
这就是本篇论文提到的迭代式。注意,Proxηλg(V(t))需要在主节点计算完毕,然后将(Proxηλg(V(t)))k发送给任务节点k。
值得一提的是,机器学习中一般根据子问题的复杂性来选择前向-后向迭代和后向-前向的迭代。如果数据集(xk,yk)很大,此时后向步骤相比前向步骤更容易计算,我们会使用后向-前向迭代来进行坐标更新。具体到MTL的应用中,后向迭代步骤由式(4)的近端映射给出,而这通常有解析解(比如迹范数的奇异值的软阈值),适合先进行。另一方面,式(2)的梯度计算是典型的耗时瓶颈(尤其是数据集很大时),适合后进行。因此后向-前向迭代为分布式MTL提供了一个更为高效的优化框架。最后,我们注意到后向前向迭代当η∈(0,2/L)时是一个non-expansive 算子,因为前向和后向步骤都是non-expansive的。
整个异步多任务学习框架如下图所示:

在本篇AMTL论文中,任务节点并不共享内存(注意,在文献[9]中,各任务节点可访问一个共享内存),任务节点间不能通信,但它们都各自与主节点连接并能与之通信。任务节点和主节点之间的通信只有向量vk,相较各任务节点上存储的本地数据Dk很小。每个任务节点负责计算前向步骤;而主节点负责计算后向步骤,一旦更新的梯度从任务节点传来,就进行近端映射(近端映射也能够在多轮梯度更新后才进行,取决于梯度更新的速度)。因为每个任务节点只需要和该任务节点相关的任务块,故本篇论文进一步减少了任务节点和主节点之间的通信代价。

上面这幅图进一步描述了AMTL中的异步更新机制,包括中心节点和任务节点分别执行后向和前向步骤的顺序。在t1时刻,任务节点2从中心节点接收了已完成近端映射的参数(Proxηλg(V(t)))2,之后在任务节点2上的前向(梯度)计算步骤就会马上启动。在伪代码步骤第6行所示的任务梯度下降更新完成后,任务2的参数v2会被送回中心节点。当中心节点收到参数后,它会开始对整个(即包括所有任务的)参数矩阵进行近端映射。
然而,这个算法却会有潜在的不一致性(inconsistency)问题。如图所示,当t2与t3之间,即任务节点2在执行计算时,任务节点4已经将其计算好的参数发送到中心节点并触发了近端映射。因此,中心节点的参数矩阵在任务节点2计算梯度时,就因响应任务节点4而更新。之后当任务节点2将算好的参数送到中心节点时,近端映射只能在不一致的数据上进行计算(数据由来自任务节点2的参数和之前已更新的参数混合而成)。 同样,任务节点4在t3时刻收到参数并完成计算后,此时中心节点的参数已经被更新(因为任务节点2在t4和t5之间已触发近端映射),后面也会产生同样的问题。
为什么会有这种不一致性呢?这是因为AMTL中的数据读取是没有加内存锁的。因此,对于异步坐标更新模式,从中心节点读参数向量时会有不一致性的问题。这种由于后向迭代步骤产生的不一致性已经被论文考虑在了收敛率分析中,具体细节大家可以参见论文。
最后,这篇论文在异构网络环境下的版本代码已开源在Github上(链接: https://github.com/illidanlab/AMTL ),大家可前往学习。
参考
- [1] Evgeniou T, Pontil M. Regularized multi--task learning[C]//Proceedings of the tenth ACM SIGKDD international conference on Knowledge discovery and data mining. 2004: 109-117.
- [2] Zhou J, Chen J, Ye J. Malsar: Multi-task learning via structural regularization[J]. Arizona State University, 2011, 21.
- [3] Zhou J, Chen J, Ye J. Clustered multi-task learning via alternating structure optimization[J]. Advances in neural information processing systems, 2011, 2011: 702.
- [4] Ji S, Ye J. An accelerated gradient method for trace norm minimization[C]//Proceedings of the 26th annual international conference on machine learning. 2009: 457-464.
- [5] A. Beck and M. Teboulle, “A fast iterative shrinkage-thresholding algorithm for linear inverse problems,” SIAM Journal on Imaging Sciences, vol. 2, no. 1, pp. 183–202, 2009.
- [6] S. J. Wright, R. D. Nowak, and M. A. Figueiredo, “Sparse reconstruction by separable approximation,” IEEE Transactions on Signal Processing, vol. 57, no. 7, pp. 2479–2493, 2009.
- [7] J. D. Lee, Y. Sun, and M. A. Saunders, “Proximal newton-type methods for minimizing composite functions,” SIAM Journal on Optimization, vol. 24, no. 3, pp. 1420–1443, 2014.
- [8] Baytas I M, Yan M, Jain A K, et al. Asynchronous multi-task learning[C]//2016 IEEE 16th International Conference on Data Mining (ICDM). IEEE, 2016: 11-20.
- [9] Z. Peng, Y. Xu, M. Yan, and W. Yin, “ARock: An algorithmic framework for asynchronous parallel coordinate updates,” SIAM Journal on Scientific Computing, vol. 38, no. 5, pp. A2851–A2879, 2016.
- [10] B. Edmunds, Z. Peng, and W. Yin, “Tmac: A toolbox of modern asyncparallel, coordinate, splitting, and stochastic methods,” UCLA CAM Report 16-38, 2016.
- [11] P. L. Combettes and V. R. Wajs, “Signal recovery by proximal forwardbackward splitting,” Multiscale Modeling & Simulation, vol. 4, no. 4, pp. 1168–1200, 2005.
- [12] Z. Peng, T. Wu, Y. Xu, M. Yan, and W. Yin, “Coordinate-friendly structures, algorithms and applications,” Annals of Mathematical Sciences and Applications, vol. 1, pp. 57–119, 2016.
- [13] Bauschke H H, Combettes P L. Convex analysis and monotone operator theory in Hilbert spaces[M]. New York: Springer, 2011.
- [14] https://zhuanlan.zhihu.com/p/150605754
- [15] 杨强等. 迁移学习[M].机械工业出版社, 2020.
- [16] 王树森YouTube课程:并行计算与机器学习
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· .NET Core 托管堆内存泄露/CPU异常的常见思路
· PostgreSQL 和 SQL Server 在统计信息维护中的关键差异
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~