Generative Modeling by Estimating Gradients of the Data Distribution

Song Y. and Ermon S. Generative modeling by estimating gradients of the data distribution. In Advances in Neural Information Processing Systems (NIPS), 2019.

当前生成模型, 要么依赖对抗损失(GAN), 要么依赖替代损失(VAE), 本文提出了基于score matching 训练, 以及利用annealed Langevin dynamics推断的模型, 思想非常有趣.

主要内容

Langevin dynamics

对于分布p(x), 我们可以通过下列方式迭代生成

x~t=x~t1+ϵ2xlogp(x~t1)+ϵzt,

其中x~0π(x)来自一个先验分布, ztN(0,I). 当步长ϵ0并且T+的时候, x~T可以认为是从p(x)中采样的样本.

注: 一般的Langevin, dynamics还需要在每一次迭代后计算一个接受概率然后判断是否接受, 不过在实际中这一步往往可以省略.

Score Matching

通过上述的迭代可以发现, 我们只需要获得xlogp(x)即可采样x, 我们可以期望通过下面的方式, 通过一个网络sθ(x)来逼近xlogpdata(x):

minθ12Epdata(x)[sθ(x)xlogpdata(x)22],

但是在实际中, 先验logpdata(x)也是未知的, 幸运的是上述公式等价于:

minθEpdata(x)[tr(xsθ(x))+12sθ(x)22].

注: 见 score matching

Denoising Score Matching

一个共识是, 所获得的数据往往是一个低维流形, 即其内在的维度实际上很低. 所以Epdata(x)在实际中会出现高密度的区域估计得很好, 但是低密度得区域估计得非常差. Denosing Score Matching提高了一个较为鲁棒的替代方法:

minθ12Eqσ(x~|x)pdata(x)[sθ(x~)xlogqσ(x~|x)22].

当优化得足够好的时候,

sθ(x)=xlogqσ(x),qσ(x~):=qσ(x~|x)pdata(x)dx.

实际中, 通常取qσ(x~|x)=N(x~|x,σ2I), 相当于在真实数据x上加了一个扰动, 当扰动足够小(σ足够小)的时候, qσ(x)pdata(x), 则sθ(x)xlogpdata(x).

注: 为啥期望部分要有pdata? 实际上上述目标和score matching依旧是等价的.

Noise Conditional Score Networks

Slow mixing of Langevin dynamics

假设pdata(x)=πp1(x)+(1π)p2(x), 且p1,p2的支撑集合是互斥的, 那么 xlogpdata(x)要么为xlogp1(x)或者xlogp2(x), 与π没有丝毫关联, 这会导致训练的结果与π也没有关联. 在实际中, 若p1,p2近似互斥, 也会产生类似的情况:

如上图所示, 通过Langevin dynamics采样的点几乎是1:1的, 这与真实的分布便有了出入.

作者的想法是, 设计一个noise conditional score networks:

sθ(x,σ),

给定不同的σ其拟合不同扰动大小的pσ, 在采样中, 首先用大一点的σ, 然后再逐步缩小, 这便是一种退火的思想. 显然, 一开始用大一点的σ能够为后面的采样提供更好更鲁棒的初始点.

损失函数

设定σi,i=1,2,,L, 且满足:

σ1σ2==σL1σL>1,

即一个等比例(缩小)的数列.
对于每个σ采用如下损失:

(θ;σ)=12Epdata(x)EN(x~|x,σI)[sθ(x~,σ)+x~xσ222].

注: x~qσ(x~|x)=x~xσ2.

于是总损失为

L(θ;{σi}i=1L):=1Li=1Lλ(σi)(θ;σi),

λ(σi)为权重系数.

Annealed Langevin dynamics

Input: {σi}i=1L,ϵ,T;

  1. 初始化x0;
  2. For i=1,2,,L do:
    • αiϵσi2/σL2;
    • For t=1,2,,T do:
      • 采样ztN(0,I);
      • xtxt1+αi2sθ(xt1,σ)+αizt;
    • x0xT;

Output: xT.

细节

  1. 关于参数λ(σ)的选择:
    作者推荐选择λ(σ)=σ2, 因为当优化到最优的时候, sθ(x,σ)21/σ, 故σ2(θ;σ)=12E[σsθ(x,σ)+x~xσ22], 其中σsθ(x,σ)1,x~xσN(0,I), 故σ2θ,σσ无关.

  2. 关于αiϵσi2/σL2:

对于一次Langevin dynamic, 其获得的信息为: αi2sθ(xt1,σ), 其噪声为αizt, 故其信噪比(signal-to-noise)为(应该是element-wise的计算?)

αisθ(x,σi)2αiz,

当我们按照算法中的取法时, 我们有

αisθ(x,σi)2αiz22αisθ(x,σi)224σisθ(x,σi)22414.

故采用此策略能够保证SNR保持一个稳定的值.

代码

原文代码

posted @   馒头and花卷  阅读(940)  评论(0编辑  收藏  举报
编辑推荐:
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
点击右上角即可分享
微信分享提示