深入分析:近端梯度下降法、交替方向乘子法、牛顿法

写在前面

本文主要围绕近端梯度下降法(Proximal Gradient Descent)、交替方向乘子法(Alternating Direction Method of Multipliers)、牛顿法来结合实际的案例进行推导分析,主打一个面向对象。

近端梯度下降法

**PGD (Proximal Gradient Descent) **,称为近端梯度优化法,近端指的是局部区域,在损失函数曲线上的一个泰勒展开点的近端或附近。近端梯度优化则是损失函数曲线上的一个点附近进行泰勒展开,通过执行梯度优化寻找局部最优解。

​ 为什么要提出PGD?与L1范数相关的稀疏问题求解中,L1范数不是处处可导(在零点不可导),无法使用梯度下降法。因此不难发现,其主要用于解决目标函数中存在可微和不可微函数的情况,如sgn函数。

​ 在近端梯度求解时,会遇到绝对值求导的问题,绝对值求导结果为符号函数Sgn(x),这个过程需要分情况讨论,因此会形成软阈值算子。在下面的例子中,x即为关于b的软阈值。

minx12xb22+λx1(xb)+λsgn(x)=0x={b+λ,b<λ0,|b|λbλ,b>λx=Sλ(b)

​ 因此在含L1范数的稀疏编码关于近端梯度下降算法的求解问题中,面临迭代软阈值优化分析,故这类问题也称为迭代软阈值算法(ISTA,Iterative Shrinkage Thresholding Algorithm)。

​ 在这里,将近端梯度算法有关的算法做一个归类,针对问题:x=argminxg(x)+h(x),如果函数g(x)是可微的凸函数,h(x)是不可微的凸函数,那么可以根据h(x)将近端梯度算法表示为以下几种:

  • 如果h(x)=0,则近端梯度算法退化为一般的梯度下降算法
  • 如果h(x)=IC(x),则近端梯度算法称为投影梯度下降算法,其中示性函数IC(x)={0,xC,xC
  • 如果h(x)=λx1,则近端梯度算法称为迭代软阈值算法。

标准Lasso问题(PGD)

​ 针对问题:minxg(x)+h(x)=minx12||Axb||22+λ||x||1,我们需要令其能够转化为(xb)2的形式,因此,我们可以选择在x0处泰勒展开(令2g(x0)=1t),则有:

g(x)g(x0)+g(x0)(xx0)+122g(x0)(xx0)2=g(x0)+g(x0)(xx0)+12t(xx0)2

​ 那么,Lasso问题等价为:

minx g(x)+h(x)min g(x0)+g(x0)(xx0)+12t(xx0)2+h(x)=min12t[xx0+tg(x0)]2+h(x)=minx12t||x(x0tg(x0)||22+h(x)=min12t||xz||22+h(x)

​ 至此,我们可以得到z=x0tg(x0),即g(x)梯度下降的形式,此时如果代入h(x)=λx1,我们就不难发现这个式子和开篇的类似,因此,我们可以得到Lasso问题的解为x=Sλ(z)

​ 近端算子则可以表示为:proxt,h()(z)=argmin12||xz||22+th(x)

​ 因此,近端梯度下降的迭代过程可以表示为如下:先对g(x)进行梯度下降求解z(k+1)=x(k)tg(x(k)),再代入x(k+1)=proxt,h()(z(k+1))=Sλ(z(k+1))

标准Lasso问题(ISTA)

Lasso (Least Absolute Shrinkage and Selection Operatior),最小绝对收缩选择算子,本质是给解向量增加L1范数约束,使向量的元素尽可能稀疏。

​ 给定目标函数如下:

minβ,α12yXβ22+λα1,s.t.βα=0

​ 引入中间变量w,如下:

L(α,β,ρ)=12yXβ22+λα1++ρ2βα+w22ρ2w22

​ 下面分别对L关于α,βρ项求极值点分析。

​ 1、首先,对式中与β有关项进行偏导分析,详细过程如下(懒得描绘,直接看推导过程吧):

minβ12yXβ22+ρ2βα+w22=minβ12βTXTXβyTXβ+ρ2βTβρ(αw)Tβl1=yTXβl1β=XTyl2=12βTXTXβl2β=XXTβl3=ρ2βTβl3β=ρβl4=ρ(αw)Tβl4β=ρ(αw)XXTβXTy+ρβρ(αw)=0(XXT+ρI)βXTyρ(αw)=0β(l+1)=(XXT+ρI)1[XTy+ρ(α(l)w(l))]

​ 2、其次,对式中与α有关项进行偏导分析,详细过程如下:

minαλα1+ρ2βα+w22=minαλα1+ρ2(2αTβ+αTα2αTw)λα1ρβ+ραρw=0λρα1+α=β+w{α+λρ=β+w,α>0α[β+wλρ,β+w+λρ]αλρ=β+w,α<0α(l+1)=Sλρ(β(l+1)+w(l))

​ 3、最后,更新w项:w(l+1)=w(l)+β(l)α(l)

混合Lasso问题(ISTA)

​ 这个案例选自国防科大ISAR高分辨成像的1篇文章ADMM-Net,其主要引入了卷积算子来解决传统LASSO-成像问题中忽略了弱散射中心与强散射中心的关系导致的弱散射点成像不显著问题。其给定的目标函数如下:

minX12YAXF2+λ1CX+ϵX1

​ 上式中,C为卷积核,为二维卷积,ϵ为任意极小值,为矩阵哈达玛积。下面通过引入中间变量Z=X来解耦合卷积过程的两项表达式,考虑中间变量后的目标函数如下:

minX,Z,B12YAXF2+λ1CX+ϵZ1,s.t.XZ=0

​ 下面,我们将上式改写为增广拉格朗日方程的形式:

L(X,Z,B)=12YAXF2+λ1CX+ϵZ1+B,XZ+μ2XZF2

我们对上述目标函数L(X,Z,B)关于变量​X,​Z和​B分别求偏导,可以得到如下表达式:

​ 1、首先,对关于X的项更新:

AH(YAX)+B+μ(XZ)=0(AHA+μI)X=AHYB+μZX=(AHA+μI)1(AHYB+μZ)

​ 2、再次,对关于Z的项更新:

λZ1CX+ϵ1μ(XZ)B=0λ1CX+ϵ1μZ+Z=XBμZ=Sλμ(CX+ϵ)1(XBμ)

​ 3、最后,对关于B项的更新:

B(l+1)=B(l)+μ(Xl+1Zl+1)

交替方向乘子法

交替方向乘子法的主要思想为将大问题拆解为若干子问题进行迭代求解。

原子范数软阈值AST推导

单快拍

​ 在范数对偶问题证明中,有噪声版本下的单快拍原子范数软阈值问题可以表示为:

minimizet,u,x,Z12xy22+τ2(t+u1)subject toZ=[T(u)xxt]Z0.

​ 下面给出具体的变量迭代过程:

​ 1、首先需要将上述有约束条件的原问题表述为增广拉格朗日方程形式,如下所示:

Lρ(t,u,x,Z,Λ)=12xy22+τ2(t+u1)+Λ,Z[T(u)xxt]+ρ2Z[T(u)xxt]F2

​ 其中,Λl=[Λ0lλ1lλ1lΛn+1,n+1l]Zl=[Z0lz1lz1lZn+1,n+1l]

​ 2、下面依次对变量x,t,u依次迭代更新:

​ 2.1 首先提取关于x项的表达式,12xy22+Λ,Z[T(u)xxt]+ρ2Z[T(u)xxt]F2

​ 其偏导为2λ1l+2ρ(xz1l)+xy=0那么有xl+1=y+2λ1l+2ρz1l1+2ρ.

​ 2.2 其次提取关于t项的表达式,τ2(t+u1)+Λ,Z[T(u)xxt]+ρ2Z[T(u)xxt]F2

​ 其偏导为τ2Λn+1,n+1l+ρtρZn+1,n+1l=0,那么有tl+1=1ρ(ρZn+1,n+1l+Λn+1,n+1lτ/2).

​ 2.3 其次提取关于u项的表达式,τ2(t+u1)+Λ,Z[T(u)xxt]+ρ2Z[T(u)xxt]F2

​ 其偏导为τ2e1Λ0l+ρ(T(u)Z0l)=0,那么有ul+1=W(T(Z0l+Λ0l/ρ)τ2ρe1),对角矩阵W满足关系Wii={1ni=112(ni+1)i>1T()表示生成共轭转置向量.

​ 2.4 其次提取关于Z项的表达式,Λ,Z[T(u)xxt]+ρ2Z[T(u)xxt]F2

​ 其可进步表示为ρ2Z[T(u)xxt]+ρ1ΛF2+Const,当且仅当Z=[T(u)xxt]+ρ1Λ时有最小值.

​ 因此Zl+1=[T(ul+1)xl+1(xl+1)tl+1]+ρ1Λl

​ 2.5 最后,更新拉格朗日乘子项Λl+1=Λl+ρ(Zl+1[T(ul+1)xl+1(xl+1)tl+1])

多快拍

​ 在范数对偶问题证明中,有噪声版本下的多快拍原子范数软阈值问题可以表示为:

[X,u]=argminX,W,u,Θ[Tr(W)+Tr(T(u))]+12||YX||F2,s.t.Θ=[T(u)XXHW]0

下面给出具体的变量迭代过程:

​ 1、首先需要将上述有约束条件的原问题表述为增广拉格朗日方程形式,如下所示:

L=argminτ2[Tr(W)+Tr(T(u))]+12||YX||F2+Λ,Θ[T(u)XXHW]+ρ2Θ[T(u)XXHW]F2

​ 2、下面需要依次对变量X,W,u,Θ,Λ等参量分别求极值点来更新每个子问题的最优解;在正式更新前,需要展开以下几个参量表示,以更好地帮助证明推导。(下面中M,L分别表示阵元数目和快拍数目)

Θ=[ΘT(u)ΘX(ΘX)HΘW],Λ=[ΛT(u)ΛX(ΛX)HΛW]C(M+L)×(M+L)

​ 上式中,ΘW,ΛWCL×LΘT(u),ΛT(u)CM×MΘX,ΛXCL×M

​ 对于L1=Λ,Θ[T(u)XXHW],我们有L1=trace(ΛT{Θ[T(u)XXHW]}),令B=Θ[T(u)XXHW],对于trace(ΛTB)关于B的偏导为trace(Λ)B关于X的导数为[OM×MIM×LIL×MOL×L],那么对应L1关于X的偏导为trace([ΛXΛT(u)ΛWΛXH])

​ 对于L2=ρ2Θ[T(u)XXHW]F2,我们有L2=trace((ΘB)(ΘB)H)=trace(ΘΘHΘBHBΘH+BBH),我们对L2关于X求偏导可以得到其偏导数为
2trace({ΘB}[OM×MIM×LIL×MOL×L])=2trace([XΘXTuΘTuWΘWXHΘXH])

​ 那么,我们可以得到L关于X的偏导为ρ(2X2ΘX)2ΛX+XY=0,因而在第一步迭代可以更新X如下:

Xk+1=Y+2ΛX(k)+2ΘX(k)1+2ρ

3、下面,我们继续对L中关于W的项求偏导,可以得到以下形式:

trace(τ2IL×L+[ΛT(u)ΛXΛXHΛW][OOOIW]+ρ[ΘT(u)TuΘXXΘXHXHΘWW][OOOIW])

取迹后,我们可以得到关于W的更新式如下:

τ2IL×LΛW+ρWρΘW=0W=ρ1ΛW+ΘWτ2ρIL×L

​ 4、下面,我们继续对L中关于T(u)的项求偏导,可以得到其更新式如下:

τ2IM×M+[ΛT(u)ΛXΛXHΛW][IM×MOOO]+ρ[ΘT(u)TuΘXXΘXHXHΘWW][IM×MOOO]=0τ2IM×MΛTu+ρ(TuΘT(u))=0Tu+=τ2ρIM×M+1ρΛT(u)+ΘT(u)

​ 5、下面,我们继续对L中关于Θ的项求偏导,这项比较特殊,因为我们可以将含Θ的项转化为以下形式:

<Λ,Θ[TuXXHW]>+ρ2Θ[TuXXHW]F2=ρ2Θ[TuXXHW]+ρ1ΛF2+const

​ 那么,对应我们可以得到Θ[TuXXHW]ρ1Λ时取到极值点。

​ 6、对应乘子项的更新,同单快拍中的表述。我们在上面的表述中,没有显式地写出具体的l+1l次迭代的关系,这并不影响,可以参考单快拍算法中的步骤,这里只是为了码公式而进行了简化。

牛顿法

牛顿法是求解无约束优化问题的经典方法。

参考文献

[1] ADMM算法简介
[2] 近端梯度下降

posted @   信海  阅读(1066)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 提示词工程——AI应用必不可少的技术
· 地球OL攻略 —— 某应届生求职总结
· 字符编码:从基础到乱码解决
· SpringCloud带你走进微服务的世界
点击右上角即可分享
微信分享提示