分布式多任务学习:代理损失函数进行任务分解
1 代理损失函数——一种并行化拆解技巧
我们在《基于正则表示的多任务学习》中提到,实现多任务学习的一种传统(非神经网络)的方法为增加一个正则项[1][2][3]:
目标函数中的\(f(\mathbf{W})\)很容易并行化拆解,但是一般\(g(\mathbf{W})\)就很难并行化了,那么如何解决这个问题呢?答案是运用一个可以分解的代理损失函数来替换掉原始的目标函数。我们接下来就以论文《Parallel Multi-Task Learning》[4](zhang 2015c等人)为例来介绍该思想。该论文MTLR模型[5](zhang 2015a)的基础上利用FISTA算法设计代理损失函数,该代理函数可以依据学习任务进行分解,从而并行计算。
2 基于正则化的多任务学习(MTLR)算法回顾
给定\(K\)个任务\({\{\mathcal{T}_k\}}_{k=1}^K\),每个任务都有一个训练集\(\mathcal{D}_k = {\{(\bm{x}_{ki}, y_{ki})}_{i=1}^{n_k}\}\)。我们现在考虑以下形式的目标函数:
这里的\(\phi(\cdot)\)是一个和核函数\(k(\cdot, \cdot)\)相关的特征映射,这里\(\phi(\bm{x}_1)^T\phi(\bm{x}_2)=k(\bm{x}_1, \bm{x}_2)\)。\(L(\cdot, \cdot)\)是损失函数(比如对于分类问题的\(\text{hinge loss}\) 和对于回归问题的 \(\epsilon \text{-insentive loss}\)。式\((2)\)的第一项是所有任务的经验损失函数,第二项基于\(\mathbf{W}和\Omega\)来建模任务间的关系。根据论文[5],\(\Omega\)是一个正定(Positive definite, PD)矩阵,它用来描述任务两两之间关系的精度矩阵(协方差矩阵\(\Sigma\)的逆)。如果损失函数是凸的且\(\Omega\)正定,那么目标函数\((2)\)关于\(\mathbf{W}和\bm{b}\)是联合凸(jointly convex)的。为了体现目标函数\((2)\)和单任务核方法的关系,我们这里只考虑\(\Omega\)是对角矩阵的情况。在此情况下,任务两两之间没有关系,问题\((2)\)也退化为了多个单任务模型(每个模型对应一个任务)。因此,问题\((2)\)可以被视为单任务正则化模型的多任务扩展。在问题\((2)\)中,\(\frac{\lambda}{2}\text{tr}(\textbf{W}\Omega\mathbf{W}^T)\)不影响我们的并行算法设计,这是非常好的。而问题\((2)\)总是能够加速问题的学习,当使用特定的优化程序如论文[5]和论文[6]一样,根据过去的研究这些方法有很快的收敛率,不管正则项是什么。
在问题\((2)\)中有许多损失函数可供使用,比如\(\text{hinge loss}\)、\(\epsilon-\text{insensitive loss}\)和\(\text{square loss}\),下面我们主要就采用这三种损失函数,后面我们会分别给出问题\((2)\)关于这三个损失函数的对偶形式。
3 并行多任务学习算法
3.1 FISTA迭代算法
下面我们就给出当使用不同的损失函数时问题\((2)\)的并行求解算法。因为我们的求解算法是基于FISTA迭代的,我们先来看FISTA迭代算法。
FISTA迭代算法[7]是一个加速梯度下降方法,用于求解一个类似于下面这种形式的复合凸目标函数(compositely convex objective function):
这里\(\bm{\theta}\)是指模型的参数集合,\(f(\bm{\theta})\)是凸的且它的梯度有\(\text{Lipschitz}\)连续性,凸函数\(g(\bm{\theta})\)有着简单的且易分解(并行)的结构,\(\mathcal{C}_{\bm{\theta}}\)是指\(\bm{\theta}\)的定义域。FISTA算法最新构建代理损失函数\(Q_l(\bm{\theta}, \hat{\bm{\theta}})\)如下:
这里\(\nabla_{\bm{\theta}}f(\hat{\bm{\theta}})\)表示\(f(\bm{\theta})\)在\(\bm{\theta}=\hat{\bm{\theta}}\)点的梯度,\(\mathcal{L}\)是\(f(\cdot)\)梯度的\(\text{Lipschitz}\)常量,接着我们优化关于\(\bm{\theta}\)的函数\(Q_{\mathcal{L}}(\bm{\theta}, \hat{\bm{\theta}})\),约束为\(\bm{\theta} \in \mathcal{C}_{\bm{\theta}}\)。函数\(Q_{\mathcal{L}}(\bm{\theta}, \hat{\bm{\theta}})\)关于\(\bm{\theta}\)的优化器由\(q_{\mathcal{L}}(\hat{\bm{\theta}})\)表示。
FISTA算法伪代码如下图所示:
可以看到第\(17\)步和\(18\)步在\(\bm{\theta}\)能够被划分为许多部分的情况下可以轻易并行。但目前的问题是如何并行化算法步骤\(11\)或\(13\)。
3.2 将目标函数转换为对偶问题
当使用\(\text{hinge},\epsilon\text{-intensive}\)和\(\text{squre}\)损失函数时,我们需要用\(\text{FISTA}\)算法优化其对偶问题。下面我们分别说明得到这三个损失函数对应目标函数的对偶问题,后面我们会在此基础上进行并行化。
3.2.1 Hinge Loss
(1)转为对偶形式 我们将Hinge Loss函数\(L_h(y^{'},y)=\text{max}(1-y^{'}y, 0)\)代入式(2)的优化问题,并将无约束优化转为有约束优化可得到:
这里\(\bm{\eta}=(\eta_{11}, ..., \eta_{Kn_K})^T\)。引入非负的Lagrange乘子\(\{\alpha_{ki}\}\)和\(\{\beta_{ki} \}\),我们可以得到问题\((5)\)的对偶形式如下:
这里我们说明一下矩阵\(\mathbf{P}\)的含义,设\(\sigma_{ij}\)是任务关系协方差矩阵\(\Sigma\)的第\((i, j)\)个元素,\(\mathbf{K}\)是一个\(n \times n\)的矩阵,它的第\((I_{ab}, I_{cd})\)个元素是\(\sigma_{ac}k(\bm{x}_{ab}, \bm{x}_{cd})\),这里\(I_{ki} =i+\sum_{l=1}^{k-1}n_l\)计算在所有任务的训练数据中的\(\bm{x}_{ki}\)的下标。\(\odot\)指逐元素乘积操作,这里有\(\mathbf{P}=\mathbf{K} \odot (\bm{y}\bm{y}^T)\)。这里我们定义函数\(k_{MT}(\cdot, \cdot)\)为\(k_{MT}(\bm{x}_{qi}, \bm{x}_{rj}) = \sigma_{qr}k(\bm{x}_{qi}, \bm{x}_{rj})\)
用来构造矩阵\(\mathbf{K}\)。很容易证明这是一个核函数。所以我们称\(k_{MT}(\cdot, \cdot)\)是一个多任务核函数,将\(\mathbf{K}\)称为多任务核矩阵。
3.2.2 \(\epsilon\) - Insensitive Loss
接下来我们讨论将\(\epsilon-\) insensitive loss函数
\(L_{\epsilon}(y,y^{'}) = \left\{ \begin{aligned} 0 \quad \text{若} |y - y^{'}| \leqslant \epsilon \\ |y - y^{'}| - \epsilon \quad \text{其他} \end{aligned} \right .\)
代入问题\((2)\)进行优化。我们再引入一些松弛变量,问题\((2)\)可被转化为:
这里\(\bm{\eta} = (\eta_{11},..., \eta_{Kn_K})^T\)和\(\bm{\tau} = (\tau_{11},..., \tau_{Kn_K})^T\)。
我们接下来引入Lagrange乘子\(\bm{ \alpha} = (\alpha_{11},...,\alpha_{Kn_K})^T\)和\(\bm{ \beta} = (\beta_{11},...,\beta_{Kn_K})^T\),进一步得到问题\((7)\)的对偶问题:
这里\(\mathbf{1}\)表示一个元素全为1的合适大小的向量或者矩阵,\(\mathbf{K}\)表示由等式\((8)\)的多任务核函数\(k_{MT}(\cdot, \cdot)\)构成的矩阵。这里\(\bm{y}=(y_{11}, ...,y_{Kn_K})^T\)。
3.2.3 Square loss:
我们将square loss代入问题\((2)\),得到以下优化问题:
引入Lagrange乘子\(\{\alpha_{ki}\}\),我们就可以得到问题\((9)\)的对偶形式:
这里\(\bm{\alpha}_k = (\alpha_{11},...,\alpha_{Kn_K})^T\)。这里\(\mathbf{Q} = \mathbf{K} + \frac{\lambda}{2}\mathbf{\Lambda}\),\(\mathbf{\Lambda}\)是一个对角矩阵,相应的数据点属于第\(k\)个任务时其对角元素为\(n_k\)。
注意,后面我们会发现三个损失函数对应的对偶形式都有着相似的形式而且和单任务对偶形式的主要不同点都在于线性不等式约束。也就是说,在单任务对偶形式中,只有一个涉及Lagrange乘子的线性不等式约束;但是在多任务环境下,有\(K\)个线性不等式约束,每个不等式都由一个任务的Lagrange乘子组成。有趣的是,这种差别决定了我们后面设计的并行算法。
3.3 将对偶问题的求解并行化
接下来我们需要展示应用FISTA算法并行化求解\((6)\),其他损失函数同理。我们定义\(\bm{\theta}=\bm{\alpha}\),\(\bm{\phi} = \hat{\bm{\alpha}}\),\(f(\bm{\alpha}) = \frac{1}{2\lambda}\bm{\alpha}^T\bm{P}\bm{\alpha}\),\(g(\bm{\alpha}) = \sum_{k=1}^K\sum_{i=1}^{n_k}\alpha_{ki}\),定义域\(\mathcal{C}_{\alpha} = \{\alpha | \sum_{i=1}^{n_k}\alpha_{ki} y_{ki} = 0(k=1,2,..,K, i=1,2,...,n_k, 0 \leqslant \alpha_{ki}\leqslant \frac{1}{n_k})\}\)。下面我们来看如何并行化算法步骤\(11\)或\(13\)。
\(f(\bm{\alpha})\)关于\(\bm{\alpha}\)的二阶导数\(\nabla^2 f(\bm{\alpha})\)是我们这里的\(\frac{1}{\lambda}\bm{P}\)。我们用\(||\cdot||\)表示矩阵的\(l_2\)范数,易得\(||\mathbf{P}||_2 \mathbf{I}_n-\mathbf{P}\)是一个半正定矩阵。所以\(f(\bm{\alpha})\)的最小\(\text{Lipschitz}\)常量是\(\frac{1}{\lambda}||\textbf{P}||_2\)(\(\mathcal{L} \geqslant {\frac{1}{\lambda}||\textbf{P}||_2}\))。当\(n\)非常大时,计算\(||\textbf{P}||_2\)非常耗时,我们下面会展示如何并行地计算它。
给定\(\mathcal{L}\),我们能够优化关于\(\bm{\alpha}\)的函数\(Q_{\mathcal{L}}(\bm{\alpha}, \hat{\bm{\alpha}})\),这也是步骤11或13要求解的(并行地)。特别地,步骤11或13要求解的优化问题可以被描述为:
该问题可以被分解为\(T\)个独立的子问题,第\(t\)个子问题为:
这里\(\alpha_k = (\alpha_{k1}, ..., \alpha_{kn_k} )^T\),\(a_{ki} = \mathcal{L}\hat{\bm{\alpha}}_{kj}+1-\frac{1}{\lambda} \hat{p}_{ki}\),\(\hat{p}_{ki}\)是\(\text{P}\hat{\bm{\alpha}}\)中与\(\bm{x}_{ki}\)对应的元素。\(c_k=0,\rho = 0, d_k = \frac{1}{n_k}\)。
问题\((12)\)是一个二次规划(quadratic programming, QP)问题,我们能够不用任何QP求解器,在\(O(n_k)\)的时间内用拉格朗日乘子法求解。正如问题\((11)\)所示,FISTA算法的每一轮迭代都需要计算\(\mathbf{P}\hat{\bm{\alpha}}\)以决定\(Q_{\mathcal{L}}(\bm{\alpha}, \hat{\bm{\alpha}})\)。如果我们直接解任务\((12)\),\(\alpha\)会完全和之前的估计不同,且计算\(\mathbf{P}\hat{\bm{\alpha}}\)会花费\(O(n^2)\),当\(n\)很大时计算量太大。所以这里我们希望采取SMO算法的思想,只更新部分的\(\alpha\)元素,这样计算\(\mathbf{P}\hat{\bm{\alpha}}\)的时间复杂度减少到\(O(n)\)。(因为我们只需要关心变化的元素)
参考
- [1] Evgeniou T, Pontil M. Regularized multi--task learning[C]//Proceedings of the tenth ACM SIGKDD international conference on Knowledge discovery and data mining. 2004: 109-117.
- [2] Zhou J, Chen J, Ye J. Malsar: Multi-task learning via structural regularization[J]. Arizona State University, 2011, 21.
- [3] Zhou J, Chen J, Ye J. Clustered multi-task learning via alternating structure optimization[J]. Advances in neural information processing systems, 2011, 2011: 702.
- [4] Zhang Y. Parallel multi-task learning[C]//2015 IEEE International Conference on Data Mining. IEEE, 2015: 629-638.
- [5] Zhang Y, Yeung D Y. A convex formulation for learning task relationships in multi-task learning[J]. arXiv preprint arXiv:1203.3536, 2012.
- [6] Zhang Y, Yeung D Y. A regularization approach to learning task relationships in multitask learning[J]. ACM Transactions on Knowledge Discovery from Data (TKDD), 2014, 8(3): 1-31.
- [7] A. Beck and M. Teboulle, “A fast iterative shrinkagethresholding algorithm for linear inverse problems,” SIAM Journal on Imaging Sciences, 2009
- [8] 杨强等. 迁移学习[M].机械工业出版社, 2020.