Song Y., Sohl-Dickstein J., Kingma D. P., Kumar A., Ermon S. and Poole B. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations (ICLR), 2021
概
从 stochastic differential equation (SDE) 角度看 diffusion models.
符号说明
- x(t),t∈[0,T] 为 x 在时间 t 的一个状态;
- pt(x)=p(x(t)), x 在时间 t 所服从的分布;
- pst(x(t)|x(s)),0≤s<t≤T, 从 x(s) 到 x(t) 的转移核 (transition kernel);
- sθ(x,t), 为 score ∇xlogpt(x) 的一个近似, 通常用神经网络拟合.
Wiener process
Wiener process X(t,w) 是这样的一个随机过程:
- X(0)=0;
- X(t+Δt)−X(t) 和 X(s) 是独立的 (感觉就是马氏性);
- X(t+Δt)−X(t)∼N(0,Δt), 服从方差为 Δt 的正态分布;
- limΔ→0X(t+Δt)=X(t), 关于 t 是连续的.
本文所关注的是带 drift μ 的 Wiener 随机过程:
X(t,w)=μt+σWt,
其中 Wt 服从一般的 Wiener process.
我们可以用下列的 SDE 来描述该随机过程中的增量 (一般形式):
dx=f(x,t)dt+G(x,t)dw,(SDE+)
其中
f(⋅,t):Rd→Rd,G(⋅,t):Rd→Rd×d.
其中 dw 特指一般 Wiener process 中的增量, 即 w(t+Δt)−w(t)∼N(0,Δt).
它的逆过程可以描述为:
dx={f(x,t)−∇⋅[G(x,t)G(x,t)T]−G(x,t)G(x,t)T∇xlogpt(x)}dt+G(x,t)dw.(SDE-)
主要内容
SMLD 和 DDPM 采用了:
- x(0)→x(T), 逐渐加噪的过程;
- x(T)→x(0), 逐步采样的过程.
而这两个方程可以看成是两个(正反) SDE 的离散过程.
反向采样
我们首先讲反向采样, 这样会更容易理解前向中的一些设计. 我们知道, 一旦有了 (SDE-) 和 score function ∇xlogpt(x), 就可以通过一些离散求解方法去逐步'生成'解 x(0) 了.

Numerical SDE solvers
有很多数值解法可以用于反向采样: Euler-Maruyama, stochastic Runge-Kutta methods, Ancestral sampling.
本文提出了一种 reverse diffusion sampling (Ancestral sampling 是这个的一特例):
- 对于
dx=f(x,t)dt+G(x,t)dw,
采用xi+1=xi+fi(xi)+Gizi,i=0,1,⋯,N−1
的更新方式;
- 类似地, 对于(简化)
dx={f(x,t)−G(x,t)G(x,t)T∇xlogpt(x)}dt+G(t)dw,
采用 (注意, 符号是反的)xi=xi+1−fi+1(xi+1)+Gi+1GTi+1∇xlogpi+1(xi+1)+Gi+1zi+1.
Predictor-corrector samplers
假设我们知道 ∇xlogpt(x) 或者它的一个近似 sθ(x,t). 我们就可以通过 score-based MCMC 来采样了, 比如 Langevin MCMC 和 HMC (here).
利用 Langevin MCMC, 步骤如下:
x←x+ϵ∇xlogp(x)+√2ϵz,zi.i.d.∼N(0,I),
其中 ϵ 为步长.
注: MCMC 采样的过程是保证连续采样的点最终趋向于分布 p(x), 而不是说整个流程产生点符合 inverse 随机过程 !
整体的 PC samplers 框架如下:

其中 Predictor 可以是任意的 numeric solvers, Corrector 是 MCMC. 这相当于, 通过数值求解随机过程, 但是由于存在误差, 可能导致实际的 xi 偏离它的分布, 故再通过 MCMC 进行纠正.
Probability Flow
这部分, 作者将 SDE 转换成了一个 ODE, 从而能够确定性地采样, 但是这部分内容没怎么看懂, 就只在这里记一笔. 需要注意的是, 和 SDE 不一样, 因为 ODE 不含随即项, 故我们可以通过现成的 black-box ODE solver 来求解方程, 并且通过给定不同的 x(T)∼pT, 便能有不同的解.
其大致流程如下:
xi=xi+1−fi+1(xi+1)+12Gi+1GTi+1sθ(xi+1,i+1),i=0,1,⋯,N−1.
条件采样
条件采样, 即给定 y(0), 我们希望从条件分布
p(x(0)|y(0))
中采样. 一般来说, 我们会通过贝叶斯公式得到
p(x(0)|y(0))=p(y(0)|x(0))p(x(0))p(y(0)),
但是我们通常难以估计先验 p(x(0)) 和 p(y(0)).
我们可以通过下列的 inverse-time SDE 来从 pt(x(t)|y) 中采样:
dx={f(x,t)−∇⋅[G(x,t)G(x,t)T]−G(x,t)G(x,t)T∇xlogpt(x(t)|y(0))}dt+G(x,t)dw.
又
∇xlogpt(x(t)|y(0))=∇xlogpt(x(t))≈sθ(x,t)+∇xlogpt(y(0)|x(t)),
故当 ∇xlogpt(y(0)|x(t)) 可知时, 我们就可以采样了.
接下来, 我们讨论 pt(y(0)|x(t)) 可估计和难以直接估计的情况
可估计的情况
- y(0) 为分类任务中的标签;
- 采样 x(t);
- 利用交叉熵损失 训练一个 time-dependent 分类器:
pt(y(0)|x(t)).
难以估计的情况
此时我们注意到:
∇xlogpt(x(t)|y)=∇xlog∫pt(x(t)|y(t),y(0))p(y(t)|y(0))dy(t).
我们给出下面两个合理的假设:
- p(y(t)|y(0)) 是可求的;
- pt(x(t)|y(t),y(0))≈pt(x(t)|y(t)), 这是因为对于 t 比较小的情况, y(t)≈y(0), 而对于 t 比较大的情况, x(t) 受 y(t) 影响最大.
此时有
∇xlogpt(x(t)|y(0))≈∇xlog∫pt(x(t)|y(t))p(y(t)|y(0))dy(t)≈logpt(x(t)|^y(t))←^y(t)∼p(y(t)|y(0))=∇logxpt(x(t))+∇xlogpt(^y(t)|x(t))≈sθ(x(t),t)+∇xlogpt(^y(t)|x(t)).
此时只要 ∇xlogpt(^y(t)|x(t)) 可知便可代入求解了.
下面以 Imputation 为例进行讲解. 假设 Ω(x),¯Ω(x) 分别表示 观测的 和 缺失的 部分. 我们的目的是从
p(x(0)|Ω(x(0))=y)
中采样. 按照上面的步骤, 我们只需要估计
∇xlogpt(x(t)|^Ω(x(t)))
即可. 实际上, 注意到由于本文的建模都是 element-wise 的, 所以
pt(x(t)|^Ω(x(t)))=pt(x^Ω(t)),
即仅 ^Ω 区域需要采样.
注: 这里的内容和原文 Appendix I.2 的推导有较大出入, 我是按照我自己的理解来的, 也没有实验过, 准确性存疑 !
前向扰动
根据前面的流程, 我们知道, 倘若我们能够估计出
sθ(x,t)≈∇xlogpt(x),
那么我们就可以跟着随机过程一步一步地采样了, 而这需要用到 (denosing) score matching 作为训练目标:
θ∗=argminθEt{λ(t)Ex(0)Ex(t)|x(0)[∥sθ(x(t),t)−∇x(t)logp0t(x(t)|x(0))∥22]},
其中 λ(⋅) 为正的权重, 通常选择 λ∝1/E[∥∇x(t)logp0t(x(t)|x(0))∥22], t∼U[0,T].
从上面目标函数的定义可知, 一般来说, 只有 p0t 是显式可求的上面的才有意义, 对于更加一般的随机过程, 可以用 slice score matching 来绕开其中复杂的计算 (不过需要以更多的计算量为代价). 下面所介绍的, 都是可求的高斯分布.
SMLD
SMLD 定义了 {xi}Ni=1, 可以看成是 t=iN∈[0,T=1] 的离散的随机过程:
xi=xi−1+√σ2i−σ2i−1zi−1,zii.i.d.∼N(0,I).(1)
且满足
σmin=σ1<σ2<⋯<σN=σmax.
此时有:
xi|x0∼N(x0,σ2iI).
我们进一步将其改写成 SDE 的形式 (即令 N→∞ ):
Δx(t)=x(t+Δ)−x(t)=√Δσ2(t)z(t)=√Δσ2(t)ΔtΔtz(t),
当 Δt→0 时 (即 N→∞ ) 有:
Δx(t)→dx(t),Δσ2(t)Δt→d[σ2(t)]dt.
最后, 我们容易发现增量 √Δtz(t)∼N(0,Δt), 所构成的随机过程自然满足 Wiener process, 故
dx=0dt+√dσ2(t)dtdw.(2)
即不存在 drift 量.
DDPM
DDPM 定义了 {xi}Ni=1, 可以看成是 t=iN∈[0,T=1] 的离散的随机过程:
xi=√1−βixi−1+√βizi−1,zii.i.d.∼N(0,I).(3)
令 ¯βi:=Nβi, 并定义
β(t),t∈[0,1],β(iN)=¯βi.
则 (3) 可以改写为
x(t+Δt)−x(t)=(√1−β(t+Δt)Δt−1)x(t)+√β(t+Δt)Δtz(t),(3+)
当 Δ→0, 有
x(t+Δt)−x(t)=Δx(t)→dx(t)√1−β(t+Δt)Δt−1→−12β(t)dt√β(t+Δt)Δtz(t)→√β(t)dw.
其中第二项由一阶泰勒近似可以得到, 第二项和 SMLD 中的推理是类似的.
最后, 可以总结为如下的 Wiener process:
dx=−12β(t)xdt+√β(t)dw.(4)
接下来我们推导一下 DDPM 的 x(t) 的条件分布. (3+) 两边取期望可知
e(t+Δt)−e(t)=(√1−β(t+Δt)Δt−1)e(t)+0,
其中 e(t)=E[x(t)], 则
de=−12β(t)edt,
加上初值条件 e(0)=e0, 可得:
e(t)=e(0)e−12∫t0β(s)ds.
而 x(t) 的协方差矩阵 ΣVP(t) 满足
dΣVP(t)=β(t)(I−ΣVP(t))dt,
加上初始值 ΣVP(0)可得
ΣVP(t)=I+e−∫t0β(s)ds(ΣVP(0)−I).
故服从
x(t)|e(0)∼N(e(0)e−12∫t0β(s)ds;I+e−∫t0β(s)ds(ΣVP(0)−I))
在已知 x(0) 的条件下, e(0)=x(0),ΣVP(0)=0, 故
x(t)|x(0)∼N(x(0)e−12∫t0β(s)ds;I−e−∫t0β(s)dsI)
注: 方差的公式的推导在另一篇论文中, 这里的方差求解是一般的基础的.
拓展
通过 SMLD 和 DDPM 两个例子可以发现, 我们只需要个性化定制 f(x,t) 和 G(x,t), 即可构造不同的前向扰动过程. 实际上, SMLD 和 DDPM 代表了两种不同的 SDE: Variance Exploding (VE) SDE 和 Variance Preserving (VP) SDE. 这是因为 SMLD 要求 σmax→∞ 而由上面的推导可得, 倘若 ΣVP(0)=I 或者 ∫t0β(s)ds→+∞时, 方差都是收敛的.
sub-VP SDE
受 DDPM VP SDE 性质的启发, 作者设计了一种新的前向扰动过程:
dx=−12β(t)xdt+√β(t)(1−e−2∫t0β(s)ds)dw.
和 DDPM 一样, x(t) 的期望
E[x(t)]=E[x(0)]e−12∫t0β(s)ds.
而协方差为
Σsub−VP(t):=Cov[x(t)]=I+e−2∫t0β(s)dsI+e−∫t0β(s)ds(Σsub−VP(0)−2I).
它有两个性质:
- 当 ΣVP(0)=Σsub−VP(0)时, Σsub−VP⪯ΣVP, 即拥有更小的方差;
- limt→Σsub−VP(t)=I 当 ∫+∞0β(s)ds=+∞.
此外它的条件分布为:
x(t)|x(0)∼N(x(0)e−12∫t0β(s)ds;(1−e−∫t0β(s)ds)2I).
具体的采样算法
PC sampling

Corrector
这里, 作者直接构造步长, 需要注意的是, 这里的 r 代表信噪比.

其它细节
- 网络结构: 和 DDPM 中的一致;
- 训练采用 N=1000 scales;
- 采样的时候, 最后得到的 x(0) 会带有人眼无法察觉但是影响 FID 指标的噪声, 故需要在结束的时候和 DDPM 一样接入去噪环节 (Tweedies' formula);
- 虽然训练的时候采取 N=1000, 但是采样的时候可以 N=2000 甚至更多, 这个时候需要插值, 比如
s′θ(x,i)→s′θ(x,i/2),s′θ(x,i)→s′θ(x,⌊i/2⌋);
- 最优的 信噪比 (singal-to-noise) r 如下图所示:

代码
[official]
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
2021-06-21 Local Relation Networks for Image Recognition