Score-Based Generative Modeling through Stochastic Differential Equations

Song Y., Sohl-Dickstein J., Kingma D. P., Kumar A., Ermon S. and Poole B. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations (ICLR), 2021

从 stochastic differential equation (SDE) 角度看 diffusion models.

符号说明

  • x(t),t[0,T]x 在时间 t 的一个状态;
  • pt(x)=p(x(t)), x 在时间 t 所服从的分布;
  • pst(x(t)|x(s)),0s<tT, 从 x(s)x(t) 的转移核 (transition kernel);
  • sθ(x,t), 为 score xlogpt(x) 的一个近似, 通常用神经网络拟合.

Wiener process

Wiener process X(t,w) 是这样的一个随机过程:

  1. X(0)=0;
  2. X(t+Δt)X(t)X(s) 是独立的 (感觉就是马氏性);
  3. X(t+Δt)X(t)N(0,Δt), 服从方差为 Δt 的正态分布;
  4. limΔ0X(t+Δt)=X(t), 关于 t 是连续的.

本文所关注的是带 drift μ 的 Wiener 随机过程:

X(t,w)=μt+σWt,

其中 Wt 服从一般的 Wiener process.

我们可以用下列的 SDE 来描述该随机过程中的增量 (一般形式):

(SDE+)dx=f(x,t)dt+G(x,t)dw,

其中

f(,t):RdRd,G(,t):RdRd×d.

其中 dw 特指一般 Wiener process 中的增量, 即 w(t+Δt)w(t)N(0,Δt).

它的逆过程可以描述为:

(SDE-)dx={f(x,t)[G(x,t)G(x,t)T]G(x,t)G(x,t)Txlogpt(x)}dt+G(x,t)dw.

主要内容

SMLDDDPM 采用了:

  1. x(0)x(T), 逐渐加噪的过程;
  2. x(T)x(0), 逐步采样的过程.

而这两个方程可以看成是两个(正反) SDE 的离散过程.

反向采样

我们首先讲反向采样, 这样会更容易理解前向中的一些设计. 我们知道, 一旦有了 (SDE-) 和 score function xlogpt(x), 就可以通过一些离散求解方法去逐步'生成'解 x(0) 了.

Numerical SDE solvers

有很多数值解法可以用于反向采样: Euler-Maruyama, stochastic Runge-Kutta methods, Ancestral sampling.

本文提出了一种 reverse diffusion sampling (Ancestral sampling 是这个的一特例):

  1. 对于

    dx=f(x,t)dt+G(x,t)dw,

    采用

    xi+1=xi+fi(xi)+Gizi,i=0,1,,N1

    的更新方式;
  2. 类似地, 对于(简化)

    dx={f(x,t)G(x,t)G(x,t)Txlogpt(x)}dt+G(t)dw,

    采用 (注意, 符号是的)

    xi=xi+1fi+1(xi+1)+Gi+1Gi+1Txlogpi+1(xi+1)+Gi+1zi+1.

Predictor-corrector samplers

假设我们知道 xlogpt(x) 或者它的一个近似 sθ(x,t). 我们就可以通过 score-based MCMC 来采样了, 比如 Langevin MCMC 和 HMC (here).

利用 Langevin MCMC, 步骤如下:

xx+ϵxlogp(x)+2ϵz,zi.i.d.N(0,I),

其中 ϵ 为步长.

注: MCMC 采样的过程是保证连续采样的点最终趋向于分布 p(x), 而不是说整个流程产生点符合 inverse 随机过程 !

整体的 PC samplers 框架如下:

其中 Predictor 可以是任意的 numeric solvers, Corrector 是 MCMC. 这相当于, 通过数值求解随机过程, 但是由于存在误差, 可能导致实际的 xi 偏离它的分布, 故再通过 MCMC 进行纠正.

Probability Flow

这部分, 作者将 SDE 转换成了一个 ODE, 从而能够确定性地采样, 但是这部分内容没怎么看懂, 就只在这里记一笔. 需要注意的是, 和 SDE 不一样, 因为 ODE 不含随即项, 故我们可以通过现成的 black-box ODE solver 来求解方程, 并且通过给定不同的 x(T)pT, 便能有不同的解.

其大致流程如下:

xi=xi+1fi+1(xi+1)+12Gi+1Gi+1Tsθ(xi+1,i+1),i=0,1,,N1.

条件采样

条件采样, 即给定 y(0), 我们希望从条件分布

p(x(0)|y(0))

中采样. 一般来说, 我们会通过贝叶斯公式得到

p(x(0)|y(0))=p(y(0)|x(0))p(x(0))p(y(0)),

但是我们通常难以估计先验 p(x(0))p(y(0)).

我们可以通过下列的 inverse-time SDE 来从 pt(x(t)|y) 中采样:

dx={f(x,t)[G(x,t)G(x,t)T]G(x,t)G(x,t)Txlogpt(x(t)|y(0))}dt+G(x,t)dw.

xlogpt(x(t)|y(0))=xlogpt(x(t))sθ(x,t)+xlogpt(y(0)|x(t)),

故当 xlogpt(y(0)|x(t)) 可知时, 我们就可以采样了.

接下来, 我们讨论 pt(y(0)|x(t)) 可估计和难以直接估计的情况

可估计的情况
  1. y(0) 为分类任务中的标签;
  2. 采样 x(t);
  3. 利用交叉熵损失 训练一个 time-dependent 分类器:

    pt(y(0)|x(t)).

难以估计的情况

此时我们注意到:

xlogpt(x(t)|y)=xlogpt(x(t)|y(t),y(0))p(y(t)|y(0))dy(t).

我们给出下面两个合理的假设:

  1. p(y(t)|y(0)) 是可求的;
  2. pt(x(t)|y(t),y(0))pt(x(t)|y(t)), 这是因为对于 t 比较小的情况, y(t)y(0), 而对于 t 比较大的情况, x(t)y(t) 影响最大.

此时有

xlogpt(x(t)|y(0))xlogpt(x(t)|y(t))p(y(t)|y(0))dy(t)logpt(x(t)|y^(t))y^(t)p(y(t)|y(0))=logxpt(x(t))+xlogpt(y^(t)|x(t))sθ(x(t),t)+xlogpt(y^(t)|x(t)).

此时只要 xlogpt(y^(t)|x(t)) 可知便可代入求解了.

下面以 Imputation 为例进行讲解. 假设 Ω(x),Ω¯(x) 分别表示 观测的 和 缺失的 部分. 我们的目的是从

p(x(0)|Ω(x(0))=y)

中采样. 按照上面的步骤, 我们只需要估计

xlogpt(x(t)|Ω^(x(t)))

即可. 实际上, 注意到由于本文的建模都是 element-wise 的, 所以

pt(x(t)|Ω^(x(t)))=pt(xΩ^(t)),

即仅 Ω^ 区域需要采样.

注: 这里的内容和原文 Appendix I.2 的推导有较大出入, 我是按照我自己的理解来的, 也没有实验过, 准确性存疑 !

前向扰动

根据前面的流程, 我们知道, 倘若我们能够估计出

sθ(x,t)xlogpt(x),

那么我们就可以跟着随机过程一步一步地采样了, 而这需要用到 (denosing) score matching 作为训练目标:

θ=argminθEt{λ(t)Ex(0)Ex(t)|x(0)[sθ(x(t),t)x(t)logp0t(x(t)|x(0))22]},

其中 λ() 为正的权重, 通常选择 λ1/E[x(t)logp0t(x(t)|x(0))22], tU[0,T].

从上面目标函数的定义可知, 一般来说, 只有 p0t 是显式可求的上面的才有意义, 对于更加一般的随机过程, 可以用 slice score matching 来绕开其中复杂的计算 (不过需要以更多的计算量为代价). 下面所介绍的, 都是可求的高斯分布.

SMLD

SMLD 定义了 {xi}i=1N, 可以看成是 t=iN[0,T=1] 的离散的随机过程:

(1)xi=xi1+σi2σi12zi1,zii.i.d.N(0,I).

且满足

σmin=σ1<σ2<<σN=σmax.

此时有:

xi|x0N(x0,σi2I).

我们进一步将其改写成 SDE 的形式 (即令 N ):

Δx(t)=x(t+Δ)x(t)=Δσ2(t)z(t)=Δσ2(t)ΔtΔtz(t),

Δt0 时 (即 N ) 有:

Δx(t)dx(t),Δσ2(t)Δtd[σ2(t)]dt.

最后, 我们容易发现增量 Δtz(t)N(0,Δt), 所构成的随机过程自然满足 Wiener process, 故

(2)dx=0dt+dσ2(t)dtdw.

即不存在 drift 量.

DDPM

DDPM 定义了 {xi}i=1N, 可以看成是 t=iN[0,T=1] 的离散的随机过程:

(3)xi=1βixi1+βizi1,zii.i.d.N(0,I).

β¯i:=Nβi, 并定义

β(t),t[0,1],β(iN)=βi¯.

则 (3) 可以改写为

(3+)x(t+Δt)x(t)=(1β(t+Δt)Δt1)x(t)+β(t+Δt)Δtz(t),

Δ0, 有

x(t+Δt)x(t)=Δx(t)dx(t)1β(t+Δt)Δt112β(t)dtβ(t+Δt)Δtz(t)β(t)dw.

其中第二项由一阶泰勒近似可以得到, 第二项和 SMLD 中的推理是类似的.

最后, 可以总结为如下的 Wiener process:

(4)dx=12β(t)xdt+β(t)dw.

接下来我们推导一下 DDPM 的 x(t) 的条件分布. (3+) 两边取期望可知

e(t+Δt)e(t)=(1β(t+Δt)Δt1)e(t)+0,

其中 e(t)=E[x(t)], 则

de=12β(t)edt,

加上初值条件 e(0)=e0, 可得:

e(t)=e(0)e120tβ(s)ds.

x(t) 的协方差矩阵 ΣVP(t) 满足

dΣVP(t)=β(t)(IΣVP(t))dt,

加上初始值 ΣVP(0)可得

ΣVP(t)=I+e0tβ(s)ds(ΣVP(0)I).

故服从

x(t)|e(0)N(e(0)e120tβ(s)ds;I+e0tβ(s)ds(ΣVP(0)I))

在已知 x(0) 的条件下, e(0)=x(0),ΣVP(0)=0, 故

x(t)|x(0)N(x(0)e120tβ(s)ds;Ie0tβ(s)dsI)

注: 方差的公式的推导在另一篇论文中, 这里的方差求解是一般的基础的.

拓展

通过 SMLD 和 DDPM 两个例子可以发现, 我们只需要个性化定制 f(x,t)G(x,t), 即可构造不同的前向扰动过程. 实际上, SMLD 和 DDPM 代表了两种不同的 SDE: Variance Exploding (VE) SDE 和 Variance Preserving (VP) SDE. 这是因为 SMLD 要求 σmax 而由上面的推导可得, 倘若 ΣVP(0)=I 或者 0tβ(s)ds+时, 方差都是收敛的.

sub-VP SDE

受 DDPM VP SDE 性质的启发, 作者设计了一种新的前向扰动过程:

dx=12β(t)xdt+β(t)(1e20tβ(s)ds)dw.

和 DDPM 一样, x(t) 的期望

E[x(t)]=E[x(0)]e120tβ(s)ds.

而协方差为

ΣsubVP(t):=Cov[x(t)]=I+e20tβ(s)dsI+e0tβ(s)ds(ΣsubVP(0)2I).

它有两个性质:

  1. ΣVP(0)=ΣsubVP(0)时, ΣsubVPΣVP, 即拥有更小的方差;
  2. limtΣsubVP(t)=I0+β(s)ds=+.

此外它的条件分布为:

x(t)|x(0)N(x(0)e120tβ(s)ds;(1e0tβ(s)ds)2I).

具体的采样算法

PC sampling

Corrector

这里, 作者直接构造步长, 需要注意的是, 这里的 r 代表信噪比.

其它细节

  • 网络结构: 和 DDPM 中的一致;
  • 训练采用 N=1000 scales;
  • 采样的时候, 最后得到的 x(0) 会带有人眼无法察觉但是影响 FID 指标的噪声, 故需要在结束的时候和 DDPM 一样接入去噪环节 (Tweedies' formula);
  • 虽然训练的时候采取 N=1000, 但是采样的时候可以 N=2000 甚至更多, 这个时候需要插值, 比如

sθ(x,i)sθ(x,i/2),sθ(x,i)sθ(x,i/2);

  • 最优的 信噪比 (singal-to-noise) r 如下图所示:

代码

[official]

posted @   馒头and花卷  阅读(4007)  评论(5编辑  收藏  举报
相关博文:
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
历史上的今天:
2021-06-21 Local Relation Networks for Image Recognition
点击右上角即可分享
微信分享提示