[1] Duchi J. and Singer Y. Efficient Learning using Forward-Backward Splitting. NeurIPS, 2009.
[2] Xiao L. Dual Averaging Method for Regularized Stochastic Learning and Online Optimization. NeurIPS, 2009.
[3] McMahan H. B. and Streeter M. Adaptive bound optimization for online convex optimization. Colt, 2010.
[4] McMahan H. B. Follow-the-regularized-leader and mirror descent: Equivalence theorems and l1 regularization. AISTATS, 2011.
[5] McMahan H. B., et al. Ad click prediction: a view from the trenches. KDD, 2013.
概
这里介绍一些结合正则化项的随机优化方法. 需要注意的是, 上面的部分链接导向的是仅正文部分, 对应的完整证明需要在 JMLR 上所发表的完整版上找到. 不过我还只是给了简短的链接, 因为阅读起来更加容易.
符号说明
- w, 参数向量;
- ft(w) 损失函数;
- r(w) 正则化项 (e.g., ℓ1,ℓ2);
- gt=∇wft(w), 梯度;
- ∂r(w), 若 r 在 w 可导, 则表示梯度, 否则表示次梯度
- ⊙, 哈达玛积;
- [x]+=max(x,0).
Motivation
-
一般来说, 模型的训练目标是
minwf(w),
或者随机优化/在线学习的情况:
minw1,…,wtT∑t=1ft(wt).
-
不过, 我们通常会对参数 w 加以限制, 从而优化如下的损失
f(w)+λr(w),
这里 r(w) 是正则化项, 可以是比如 ∥w∥p=(|w1|p+⋯+|wd|p)1/p.
-
这里, 我们介绍的引入这些正则化项的方式, 并不是仅仅把它作为损失然后计算梯度, 而是通过修改优化器的优化规则来进行. 正则化项实际上一种人为的先验, 优化的目标自然是每一次更新尽可能地满足这些先验, 我们可以以两阶段的方式来理解这里介绍的方法 (虽然, 有些方式的形式是一步的):
- 通过传统的梯度方法得到 t+1 步的参数估计 ^wt+1;
- 将 ^wt+1 投影到由 r 所刻画的区域中去, 得到真正的参数 wt+1.
-
说实话, 我不清楚这种方式和简单把 r 作为损失的一部分孰优孰劣, 但是比较容易理解的是, 比如 r 为 ℓ1 时, 我们其实时希望这个正则化项带来参数的稀疏化. 但是由于每次更新的数值误差, 其实很难指望简单的随机梯度下降是能够满足这个性质的. 但是从下面我们可以看到, 通过投影的方式引入, 实际上会导致显式的截断操作, 从而保证稀疏化.
FOBOS (Forward-Backward Splitting)
-
这个是由 [1] 所提出的, 同时请不要在意缩写.
-
FOBOS 的更新规则如下:
^wt+1=wt−ηtgt,wt+1=argminw{12∥w−^wt+1∥2+λtr(w)}.
-
容易证明上述问题等价于找到一个 wt+1 满足:
wt+1+λt∂r(wt+1)=^wt+1.
-
ℓ1 regularization:
∂rℓ1(w)=⎧⎪⎨⎪⎩1w>0,−1w<1,[−1,1]w=0
wt+1=^wt+1−λt⇔∂r(wt+1)=1⇔wt+1>0⇔^wt+1>λt,wt+1=^wt+1+λt⇔∂r(wt+1)=−1⇔wt+1<0⇔^wt+1<−λt,wt+1=0⇔∂r(wt+1)=^wt+1λt⇔|^wt+1|<λt.
wt+1=sgn(^wt+1)⊙[|wt+1|−λt]+
-
ℓ22 regularization: r(w)=12∥w∥22
-
ℓ2 regularization: r(w)=∥w∥2
-
ℓ∞ regularization: r(w)=∥w∥∞.
- 这个问题不好通过次梯度分析, 一般的做法是利用 ℓ∞ 的对偶范数为 ℓ1 范数来求解 (see here).
-
当然了, 类似的方法可以拓展到混合范数.
RDA (Regularized Dual Averaging)
-
这篇文中 [2] 指出, FOBOS 通常会采取一种衰减的 λt, 所以理论上越到后面, 稀疏性会越差 (采取 ℓ1 正则), 这篇文章就是探讨如何不采用衰减的 λt 从而尽可能地保持稀疏性.
-
RDA 的更新方式如下:
wt+1=argminw{⟨¯gt,w⟩+r(w)+βtth(w)},
其中 h(w) 是特别引入保证收敛的强凸函数 (在 r(w) 不是强凸的时候有用), {βt} 是一串非降的序列.
需要注意的是, 不同于 FOBOS, r(t) 内部包含了系数, 如 λ∥w∥1, 不过系数是固定的, 因此我们可以希望所得的参数是满足某些性质的.
-
另外,
¯gt=t−1t¯gt−1+1tgt,gt∈∂ft(wt),
是梯度的滑动平均, 我们在最后面会用到它的另一种表示:
¯gt=tg1:t:=t∑s=1gs.
-
关于范数 ∥⋅∥ 强凸: h 关于范数 ∥⋅∥ 是强凸的, 若存在 σ>0 使得
h(αw+(1−α)u)≤αh(w)+(1−α)(u)−σ2α(1−α)∥w−u∥2.
-
当 r(w) 不满足强凸的时候, 此时需要强凸的 h 进行调和, 同时需要令
βt=γ√t,
以获得 O(1/√T) 的收敛率.
一个例子是 ℓ1-regularizaton: r(w)=λ∥w∥1, 此时我们可以令
h(w)=12∥w∥22+ρ∥w∥1,
此时更新方程的解为
wt+1=−√tγ(¯gt−λRDAtsgn(¯gt))⊙[|¯gt|−λRDAt]+,
其中
λRDAt=λ+ρ/√t.
我们可以注意到, λRDAt≥λ, 所以这保证了非常一致的稀疏性.
-
而当 r(w) 本身是强凸的时候, 比如 r(w)=λ∥w∥1+(σ/2)∥w∥22, 此时我们不必添加额外的 h, 此时可以得到:
wt+1=−1σ(¯gt−λsgn(¯g))⊙[|¯gt|−λ]+.
FTRL-Proximal (Follow The Regularized Leader)
-
[3, 4] 都对这部分进行的阐述, 不过这里主要根据 [4] 进行总结. 我认为最重要的部分就是将 FOBOS 和 RDA 统一起来了. 至于如何统一起来, 让我们一点点看.
-
让我们首先写出, RDA 的一个等价形式:
wt+1=argminw{⟨¯gt,w⟩+r(w)+βtth(w)}=argminw{⟨t¯gt,w⟩+t⋅r(w)+h(w)}=argminw{⟨g1:t,w⟩(1)+t⋅r(w)(2)+h(w)(3)},
其中 g1:t:=∑ts=1gs.
-
(1) 保证 wt+1 和整体的(负)梯度是一致的, (2) 保证正则化项, 当其为 ℓ1 是就是保证足够的稀疏性, 而且可以发现, 对于 RDA 而言, 它的强度随着迭代次数的增加而增加. 实际上容易理解, 当我们去掉 t, 那么 r(w) 在整个损失中的地位是为逐步降低的 (因为 g1:t 正常来说是逐步增加的).
-
更有趣的是, 我们可以将 FOBOS 也写成这种形式, 从而更好的理解为什么 RDA 在引入稀疏化的能力上比 FOBOS 强.
-
我们有:
^wt+1=wt−ηtgt,wt+1=argminw{12∥w−^wt+1∥2+λtr(w)}=argminw{12∥w−wt+ηtgt∥2+λtr(w)}=argminw{12∥w−wt∥22+ηt⟨gt,w⟩+λtr(w)}=argminw{⟨gt,w⟩+λtηtr(w)+12ηt∥w−wt∥22}.
-
当我们把 λt/ηt 融入 r(w) 的时候, FOBOS 的更新公式实际上就和 RDA 差在了
⟨g∗,w⟩.
-
接下来我们先统一格式, 再证明一个更加酷的结论. 定义如下的更新方式:
^wt+1=argminw⟨gt,w⟩+r(w)+ζ1:t(w,^wt),(A.1)
其中
ζ1:t(x,y)=t∑s=1ζs(x,y)=t∑s=1[Ψs(x)−Ψs(y)+⟨∇Ψs(y),x−y⟩],
其中 Ψs(x):=ψs(x−^ws), ψs 是强凸的且最小值在 0 处取到.
这个项为 Bergman divergence.
-
我们证明 (A.1) 和下列的 (A.2) 所得到的参数更新是一致的 (如果起点都是 0): 存在 ϕk∈∂r(wk+1),k=1,2…,t 使得
g1:k+ϕ1:k−1+∇Ψ1:k(wk+1)+ϕk=0,
从而
wt+1=argminw⟨g1:t+ϕ1:t,w⟩+r(w)+Ψ1:t(w).(A.2)
proof:
-
首先 w1=^w1 是必然的. 我们用归纳法证明. 假设 wt=^wt.
-
在此之前我们先证明一个引理:
-
引理: 当 h 为强凸函数 (且存在一阶导数), g 为凸函数的时候, 存在唯一的配对 (x∗,ϕ) 满足 ϕ∈∂g(x∗) 且使得:
x∗=argminxh(x)+⟨ϕ,x⟩
成立. 同时 x∗ 是 h(x)+g(x) 的唯一最小值点.
-
实际上, 由于 h 为强凸, g 为凸函数, 所以 h+g 为强凸, 故最小值点 x∗ 是唯一的. 同时满足存在 ϕ∈∂g(x∗) 使得
∇h(x∗)+ϕ=0.
因此容易发现 x∗ 是 h(x)+⟨ϕ,x⟩ 的最小值点.
-
在 wt=^wt 的基础上, 我们可以一步一步推得
^wt+1=argminw⟨gt,w⟩+r(w)+ζ1:t(w,^wt)=argminw⟨gt,w⟩+r(w)+Ψ1:t(w)−Ψ1:t(^wt)−⟨∇Ψ1:t(^wt),w−^wt⟩=argminw⟨gt,w⟩+r(w)+Ψ1:t(w)−⟨∇Ψ1:t(^wt),w−^wt⟩=argminw⟨gt,w⟩+r(w)+Ψ1:t(w)−⟨∇Ψ1:t−1(^wt),w−^wt⟩=argminw⟨gt,w⟩+r(w)+Ψ1:t(w)−⟨∇Ψ1:t−1(wt),w−wt⟩=argminw⟨gt,w⟩+r(w)+Ψ1:t(w)+⟨g1:t−1+ϕ1:t−1,w−wt⟩.
-
再次运用上面的结论, 我们可以找到 ϕ′t∈∂Ψ(^wt+1), 满足:
^wt+1=argminw⟨gt,w⟩+⟨ϕ′t,w⟩+Ψ1:t(w)+⟨g1:t−1+ϕ1:t−1,w⟩.
-
由于 (^wt+1,ϕ′t) 的唯一性, 便有:
(wt+1,ϕt)=(^wt+1,ϕ′t).
-
于是, 我们证明了两个过程的等价性. 但是这里需要声明的是, (A.2) 由于 ϕ 的存在, 并不是一个很有用的更新方式, 但是它可以给我们带来一些启发.
-
比如, 当取 ψs(w)=σ2s2∥w∥22
ζ1:t(w,^wt)⇔12∥σ1:t(w−^wt)∥22
可以看出是 FOBOS 的更新方式 (设置合适的 σ), 此时等价于:
wt+1=argminw⟨g1:t+ϕ1:t,w⟩+r(w)+12t∑s=1∥σs(w−ws)∥22.
-
所以, FOBOS 对比 RDA 实际上就是
- 多了
⟨ϕ1:t,w⟩
这一项. 这一项是对正则化的一个估计. RDA 稀疏化的效果好的原因, 实际上是因为它采取的不是估计, 而是通过增强 r(w).
- RDA 的 h(w) 通常是 12∥w∥22, 是 origin-centered 的
FOBOS, RDA, FTRL-Proximal 的统一表示
-
FOBOS:
wt+1=argminw⟨g1:t+ϕ1:t,w⟩+r(w)+12t∑s=1∥σs(w−ws)∥22.
-
RDA:
wt+1=argminw⟨g1:t+0,w⟩+t⋅r(w)+12t∑s=1∥σs(w−0)∥22.
-
FTRL-Proximal:
wt+1=argminw⟨g1:t+0,w⟩+t⋅r(w)+12t∑s=1∥σs(w−ws)∥22.
-
注意到, FTRL-Proximal 实际上是 FOBOS 和 RDA 的结合, 既有 FOBOS 的 xs 中心化, 又有 RDA 的稀疏性.
-
我们是不需要记忆所以的 ws 的, 注意到: FTRL-Proximal 等价于
wt+1=argminw⟨g1:t−t∑s=1σ2sws,w⟩+t⋅r(w)+12t∑s=1σ2s∥w∥22.
-
令 ηt=1/(∑ts=1σ2s), 以及
zt:=g1:t−t∑s=1σ2sws.
-
我们有
wt+1=argminw⟨zt,w⟩+t⋅r(w)+12ηt∥w∥22.
且
zt+1=zt+gt+1−(1ηt+1−1ηt)wt+1.
所以, 实际上我们所需要保存的量很少.
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
2022-07-16 Recommendations as Treatments: Debiasing Learning and Evaluation