Multi-Task Learning as Multi-Objective Optimization

Sener O. and Koltun V. Multi-task learning as multi-objective optimization. In Advances in Neural Information Processing Systems (NIPS), 2018.

本文提出的 MGDA-UB 用于同时处理\(T\)个任务:

\[\tag{1} \min_{\theta^{sh}, \: \theta^1, \cdots, \theta^T} \mathbf{L}(\theta^{sh}, \theta^1, \cdots, \theta^T) := (\hat{\mathcal{L}}^1(\theta^{sh}, \theta^1), \cdots, \hat{\mathcal{L}}^T(\theta^{sh}, \theta^T))^{\mathbf{T}}. \]

其中

\[\hat{\mathcal{L}}^t(\theta^{sh}, \theta^t) := \frac{1}{N} \sum_i \mathcal{L}(f^t(x_i; \theta^{sh}, \theta^t), y_i^t), \: t =1,2,\cdots, T \]

共享参数\(\theta^{sh}\)同时具备独立的参数\(\theta^t\).

在了解 MGDA-UB 之前, 务必先了解 MGDA.

主要内容

Pareto 最优

一种通常的同时解决多个任务的方法是施加不同的权重:

\[\min_{\theta} \sum_{t=1}^T c^t \hat{\mathcal{L}}^t. \]

但对于\(\theta\)\(\bar{\theta}\), 倘若满足

\[\hat{\mathcal{L}}^{t_1}(\theta) < \hat{\mathcal{L}}^{t_1}(\bar{\theta}) \\ \hat{\mathcal{L}}^{t_2}(\theta) > \hat{\mathcal{L}}^{t_2}(\bar{\theta}), \]

该如何判断\(\theta, \bar{\theta}\)的优劣呢? 这个可能得根据一些其它的指标来判断了. MGDA所希望的是优化得到多个模型的 Pareto 最优解\(\theta^*\), 即不存在 \(\theta\) 满足

\[\hat{\mathcal{L}}^{t}(\theta^*) \ge \hat{\mathcal{L}}^t(\theta) \quad \forall t \\ \hat{\mathcal{L}}^{t}(\theta^*) > \hat{\mathcal{L}}^t(\theta) \quad \exist t. \]

显然在 Pareto 最优的意义下, 我们找不到更好的解完全优于 \(\theta^*\).

KKT 条件

我们可以得到 (1) 的 KKT 条件:

\[\sum_{i=1}^T \alpha_i^t \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t (\theta^{sh}, \theta^t) = 0, \\ \nabla_{\theta^t} \hat{\mathcal{L}}^t (\theta^{sh}, \theta^t) = 0, \: \forall t=1,2\cdots, T, \\ \sum_{i=1}^T \alpha_i = 1, \: \alpha_i \ge 0, \: \forall t = 1,2,\cdots, T. \]

一个不严格的推导:


假设在 Pareto 最优解 \(\theta^*\) 下:

\[\hat{\mathcal{L}}^t(\theta^*) = C_t, \: t=1,2,\cdots, T. \]

\[\begin{array}{rl} \min_{\theta} & \hat{\mathcal{L}}^1 \\ \mathrm{s.t.} & \hat{\mathcal{L}}^t \le C_t, \: t=2,3,\cdots, T. \end{array} \]

则引入拉格朗日乘子 \(\lambda_t \ge 0\)可得

\[L(\theta; \lambda) = \hat{\mathcal{L}}^1 + \sum_{t=2}^T \lambda_t (\hat{\mathcal{L}}^t - C_t), \]

通过

\[\nabla_{\theta} L = 0, \\ \alpha_t = \frac{\lambda_t}{\sum_t \lambda_t}, \: \lambda_1 = 1 \]

便可得到上方的关系.


MDGA-UB

因为\(\theta^{sh}\)\(\theta^t\)之间是独立的, 所以对于后者我们可以利用普通的梯度即可, 对于前者, 本文利用 MGDA-UB 来选择合适的下降方向.

定义

\[V = \Bigg\{ \sum_{t=1}^T \alpha^t \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t : \sum_t \alpha_t = 1, \alpha_t \ge 0 \: \forall t \Bigg\}. \]

MGDA 中指出, 若

\[\tag{2} v := \mathop{\text{argmin}}_{v \in V} \|v\|_2^2, \]

\(-v\)\(\theta^{sh}\)的一个可行的下降方向.

所以, 现在的问题是如何求解\(v\)呢?

特殊的双重任务

对于\(T=2\)的情形, (2) 等价于

\[\min_{\gamma \in [0, 1]} \: \| \gamma \nabla_{\theta^{sh}} \hat{\mathcal{L}}^1(\theta^{sh}, \theta^1) +(1 - \gamma) \nabla_{\theta^{sh}} \hat{\mathcal{L}}^1(\theta^{sh}, \theta^1) \|_2^2, \]

其解可以显示表达为 (证明比较简单, 这里不写了)

\[\hat{\gamma} = \max(\min([\frac{(u_1 - u_2)^T u_2}{\|u_1 - u_2\|_2^2}], 1), 0). \]

其中 \(u_t := \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t(\theta^{sh}, \theta^t)\).

一般的多重任务

\[U = [u_1, u_2, \cdots, u_T], \: u_t := \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t(\theta^{sh}, \theta^t). \]

则 (2) 等价于

\[\begin{array}{rl} \min_{\alpha} & \bm{\alpha}^T U^TU \bm{\alpha} \\ \mathrm{s.t.} & -\bm{\alpha} \preceq 0, \bm{1}^T \bm{\alpha} = 1. \end{array} \]

作者采用 Frank-Wolfe 算法来近似求解:

更高效

上述的算法在时间复杂度上有缺陷, 因为这要求我们对每一个任务 \(\hat{\mathcal{L}}^t\)分别回传梯度, 这会比较耗时. 但对于 Encoder-Decoder 类型的结构, 时间能够进一步减少:

\[f^t(x;\theta^{sh}, \theta^t) = (f^t(\cdot;\theta^t) \circ g(\cdot;\theta^{sh}))(x) = f(g(x; \theta^{sh}); \theta^t). \]

\(Z = (z_1, z_2, \cdots, z_N)\)为encoder \(g\)所输出的隐变量, 则

\[\|\sum_{t=1}^T \alpha^t \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t (\theta^{sh}, \theta^t)\|_2^2 \le \|\frac{\partial Z}{\partial \theta^{sh}}\|_2^2 \cdot \|\sum_{t=1}^T \alpha^t \nabla_{Z} \hat{\mathcal{L}}^t (\theta^{sh}, \theta^t)\|_2^2, \]

故我们通过

\[\|\sum_{t=1}^T \alpha^t \nabla_{Z} \hat{\mathcal{L}}^t (\theta^{sh}, \theta^t)\|_2^2 \]

来找到\(\alpha\), 由于这是一个上界, 所找到的下降方向也不会太差. 此外, 由于隐变量\(Z\)的维度一般远远小于网络的共享参数, 且回传梯度只需到中间结点, 故计算成本会省很多.

注: 作者实验发现, 优化上界比直接优化本身反而效果更好, 这可能和SGD以及Frank-Wolfe本身是近似算法有关系 (\(Z\)的维度小, 解的可能更好).

Frank-Wolfe

  1. 求解线性规划问题

\[\begin{array}{rl} \min_x & f(x) \\ \mathrm{s.t.} & Ax \le b \\ & Mx = m. \end{array} \]

  1. 每一步求解一阶近似

\[\begin{array}{rl} \min_{x_t} & f(x_{t-1}) + \nabla^{\mathbf{T}} f(x_{t-1}) (x_t - x_{t-1}) \Leftrightarrow \nabla^T f(x_{t-1}) x_t \\ \mathrm{s.t.} & Ax_t \le b \\ & Mx_t = m, \end{array} \]

得到该问题的精确解 \(y_{t}\), 以及下降方向 \(\Delta_t := y_{t} - x_{t - 1}\).
3. 计算最速下降步长

\[\gamma_t = \min_{\gamma \in [0, 1]} f(x_{t-1} + \gamma_t \cdot \Delta_t), \]

  1. 更新 \(x_t = x_{t-1} + \gamma_t \Delta_t = \gamma_t y_t + (1 - \gamma_t)x_{t-1}\).

上文中的算法2可以据此一一对照得到.

实现细节

从代码中发现了一些细节:

  1. \(\gamma\) 取值 \([0.001, 0.999]\);
  2. 梯度用于算法二计算\(\alpha\)有可能会通过标准化;
  3. 得到\(\alpha\)之后作者scale每个损失, 再回传梯度, 也就是说即便上面用了标准化, 真正的梯度用的时候还是没标准化的.

代码

原文代码

posted @ 2022-05-08 20:29  馒头and花卷  阅读(1231)  评论(11编辑  收藏  举报