SHARPNESS-AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION论文阅读笔记

Intro

在训练集上最小化损失很可能导致泛化性低,因为当今模型的过参数化会导致training loss的landscape异常复杂且非凸,包含很多local/global minima,因此优化器的选择至关重要。loss landscape的几何性质(特别是minima的flatness)与泛化性有着紧密的联系,为此作者提出了SAM(Sharpness-Aware Minimization),通过寻找位于具有一致低损失值的邻域中的参数(而不是仅本身具有低损失值的参数)以提升模型的泛化性。

SHARPNESS-AWARE MINIMIZATION (SAM)

令标量为α,向量为α,矩阵为A,集合为A,“定义为”表示为,给定来自分布D的训练集S{(xi,yi)},训练集的损失表示为LS(ω)1ni=1nl(ω,xi,yi),泛化误差表示为LD(ω)E(x,y)D[l(ω,x,y)]

由于模型只能看到训练集,因此通常的做法是让训练损失近可能小,然而这可能导致测试时的性能不佳。为此作者提出了SAM,不去寻找带来最小训练损失的参数,而是寻找整个邻域都具有一致低训练损失的参数值(邻域具有低损失和低曲率)。

Theorem (stated informally) 1.

对于任意ρ>0,生成的训练集大概率满足:

LD(ω)max||ϵ||2ρLS(ω+ϵ)+h(||ω||22/ρ2)

其中h:R+R+是严格单调递增函数。证明位于附录A。

因此,为了使泛化损失近可能小,我们可以近可能减小其上界,而右边的项带有一个max,所以这构成了一个min-max问题。为了明确和sharpness有关的项,可以将不等式右边写为:

[max||ϵ||2ρLS(ω+ϵ)LS(ω)]+LS(ω)+h(||ω||22/ρ2)

中括号中的部分表示的就是LS的锐度。鉴于右边的h函数很大程度上受到证明细节的影响,这里作者将其写为标准的正则化项λ||ω||22,通过超参数λ加以控制。由此,作者提出通过求解SharpnessAware Minimization问题来进行参数的选择:

minωLSSAM(ω)+λ||ω||22

其中LSSAM(ω)max||ϵ||pρLS(ω+ϵ)ρ0为超参数,p[1,]p的值取2是最优的)。

为了最小化LSSAM,作者通过对inner maximization求微分来得到ωLSSAM(ω)的近似,这让我们能够通过SGD实现SAM的优化目标。为此,作者首先对LS(ω+ϵ)ϵ0进行一阶泰勒展开:

ϵ(ω)argmax||ϵ||pρLS(ω+ϵ)argmax||ϵ||pρLS(ω)+ϵϵLS(ω)=argmax||ϵ||pρϵϵLS(ω)

优化问题的解可以通过求解经典的对偶范数问题得到:

ϵ^(ω)=ρsign(ωLS(ω))|ωLS(ω)|q1/(||ωLS(ω)||qq)1/p

其中1/p+1/q=1。代入p=2这个最优的值(q=2) 计算ϵ^(ω),之后将其回代到前面的公式,可以得到:
截屏2024-01-13 17.35.30

其中第二个等号通过复合微分的运算法则得到。为了加速计算,将二阶项丢掉,就可以得到最后的梯度近似:

截屏2024-01-13 17.41.47

伪代码和示意图:

截屏2024-01-13 17.25.44

实验

截屏2024-01-13 17.46.04

等等

参考:https://blog.csdn.net/qq_40744423/article/details/121570423

posted @   脂环  阅读(386)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
历史上的今天:
2023-01-13 学术规范与论文写作1&2
点击右上角即可分享
微信分享提示
主题色彩