Navon A., Achituve I., Maron H., Chechik G. and Fetaya E. Auxiliary learning by implicit differentiation. ICLR, 2021.
概
通过 implicit differentiation 优化一些敏感的参数.
AuxiLearn
-
在实际的训练中, 我们常常会通过一些额外的任务来帮助更好的训练:
ℓmain+∑kϕkℓk,
其中 ϕk≥0 是第 k 个额外任务 ℓk 的系数.
-
比较常见的做法是通过 grid search 来选择合适 ϕk. 当额外任务的数量有效的时候尚可, 但是始终缺乏扩展性. 一种理想的方式通过某种可学习的方式设定.
-
但是很显然, 如果利用梯度下降学习 ϕk 并通过 clip 保证 ϕk≥0, 一定会导致 ϕk≡0 这一平凡解.
问题设定
-
现在让我们来设定一个更加一般的问题:
LT(W,ϕ)=ℓmain(W;Dtrain)+h(W;ϕ,Dtrain),LA(W)=ℓmain(W;Daux).
其中 W∈Rn 是模型中的基本参数, ϕ∈Rm 是一些其它的超参数, 然后 Dtrain,Daux 表示训练集和额外的集合 (比如验证集).
-
不考虑 mini-batch, 合理的训练流程应该是:
Wt+1←argminWLT(W;ϕt)ϕt+1←argminϕLA(Wt+1(ϕ))
如此重复. 就能够避免 ϕ 的平凡解.
-
当然, 如果每一次都严格按照两阶段计算, 计算量是相当庞大的 (比 grid search 也是不遑多让). 本文所提出来的 AuxiLearn 的改进就是提出了一种近似方法. 它的理论基础是 Implicit Function Theorem (IFT).
-
为了能够通过梯度下降的方式更新 ϕ, 我们首先需要推导出它的梯度:
∇ϕLA=∇WLA1×n⋅∇ϕW∗n×m.
显然, 其中 ∇WLA 是好计算的, 问题在于 ∇ϕW∗ 的估计.
-
为了推导 ∇ϕW∗, 我们需要用到 IFT. IFT 告诉我们, 对于一个连续可微映射 F(x,y):Rm×Rn→Rn. 在一定条件下, 如果存在 p∈Rm,q∈Rn 使得
F(p,q)=0,
则存在一个映射 Φ:Rm→Rn 使得
F(x,Φ(x))=0
在某个集合上均成立.
-
现在, 让我们来推导. 首先 W∗ 是 LT 的最优点, 当
∇WLT(W∗,ϕ)=0,
令 F 为 ∇WLT, x 为 ϕ, y 为 W, 套用 IFT, 可知, 存在 W∗(ϕ) 使得
∇WLT(W∗(ϕ),ϕ)=0,
在包含 ϕ 的某个子集上都成立. 于是, 我们有
∇ϕ∇WLT(W∗(ϕ),ϕ)=0⇒∇2WLT⋅∇ϕW∗+∇ϕ∇WLT=0⇒∇ϕW∗=−(∇2WLT)−1⋅∇ϕ∇WLT.
-
因此, ϕ 处的梯度为:
∇ϕLA=−∇WLA1×n⋅(∇2WLT)−1n×n)⋅∇ϕ∇WLTn×m.
-
现在的问题是, 怎么估计 (∇2WLT)−1, 作者采用 Neumann series. Neumann series 告诉我们:
(I−X)−1=∑tXt⇒X−1=(I−(I−X))−1=∑t(I−X)t.
于是便得到了本文 AuxiLearn 算法 (算法 2 其实就是 Neumann series 的前 J 项):

理解两阶段的训练
-
让我们通过一个最简单的例子来理解:
LT(W,ϕ)=ℓmain(W;Dtrain)+ϕ⋅ℓaux(W;Dtrain).
-
容易发现:
dLAdϕ=−∇WLA⋅(∇2WLT)−1⋅∇ϕ∇WLT=−∇WLA⋅(∇2WLT)−1⋅∇ϕ(∇Wℓmain(Dtrain)+ϕ∇Wℓaux)=−∇WLA⋅(∇2WLT)−1⋅∇TWℓaux(Dtrain)=−∇WLmain(Daux)⋅(∇2WLT)−1⋅∇TWℓaux(Dtrain).
-
可以发现, ϕ 逐渐增大的前提是:
∇WLmain(Daux)⋅(∇2WLT)−1⋅∇TWℓaux(Dtrain)>0,
即当主任务在 aux 集合上的更新方向和辅任务在训练集上在 ∇2WL−1T 意义上方向一致.
代码
[official-code]
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
2023-10-11 Graph Laplacian for Semi-Supervised Learning
2023-10-11 Weighted Nonlocal Laplacian on Interpolation from Sparse Data
2022-10-11 Visualizing Deep Neural Network Decisions: Prediction Difference Analysis
2022-10-11 GNNExplainer: Generating Explanations for Graph Neural Networks