Typesetting math: 74%

【论文笔记 - DDPM】Deep Unsupervised Learning using Nonequilibrium Thermodynamics

Read-through

Abstract

We present high quality image synthesis results using diffusion probabilistic models, a class of latent variable models inspired by considerations from nonequilibrium thermodynamics. Our best results are obtained by training on a weighted variational bound designed according to a novel connection between diffusion probabilistic models and denoising score matching with Langevin dynamics, and our models naturally admit a progressive lossy decompression scheme that can be interpreted as a generalization of autoregressive decoding. On the unconditional CIFAR10 dataset, we obtain an Inception score of 9.46 and a state-of-the-art FID score of 3.17. On 256x256 LSUN, we obtain sample quality similar to ProgressiveGAN. Our implementation is available at https://github.com/hojonathanho/diffusion.

数学推导【转载】

数学推导过程来自苏剑林大神的《生成扩散模型漫谈》系列,感谢苏神的无私奉献,深入浅出的讲解让我这样数学功底不好的人也能领略这个当下最为火爆的模型的精髓。

系列中有部分步骤,一眼看过去可能有些费解,所以这里稍微做了展开,作为自己的笔记用。

通俗解释:DDPM=拆楼+建楼

生成模型实际上就是:随机噪声 z z  样本数据 xx

我们把“拆楼”分为 TT 步:

x=x0x1x2xT1xT=zx=x0x1x2xT1xT=z(1)

如果能学会 xt1=μ(xt)xt1=μ(xt),那么反复执行 xT1=μ(xT),xT2=μ(xT1),,x1=μ(x0)xT1=μ(xT),xT2=μ(xT1),,x1=μ(x0) 即可还原 x0x0

该怎么拆

DDPM将“拆楼”建模为:

xt=αtxt1+βtεt,εtN(0,I)xt=αtxt1+βtεt,εtN(0,I)(2)

其中 αt,βt>0αt,βt>0α2t+β2t=1α2t+β2t=1,通常 βt0βt0εtεt 为噪声。

反复执行这个拆楼的步骤,可以得到:

xt=αtxt1+βtεt=αt(αt1xt2+βt1εt1)+βtεt==(αtα1)x0+(αtα2)β1ε1+(αtα3)β2ε2++αtβt1εt1+βtεt多个相互独立的正态噪声之和xt=αtxt1+βtεt=αt(αt1xt2+βt1εt1)+βtεt==(αtα1)x0+(αtα2)β1ε1+(αtα3)β2ε2++αtβt1εt1+βtεt(3)

式中花括号指出的部分可以看成一个整体的噪声。利用正态分布的叠加性:

X1N(μ1,σ21),X2N(μ2,σ22)X1+X2N(μ1+μ2,σ21+σ22)X1N(μ1,σ21),X2N(μ2,σ22)X1+X2N(μ1+μ2,σ21+σ22)(4)

显然这些噪声的均值为 00,我们来算它们的方差之和:

(αtα1)2+(αtα2)2β21++(αtαt1)2β2t2+α2tβ2t1+β2t=(αtα1)2+(αtα2)2β21++(αtαt1)2β2t2+α2tβ2t1α2t+1=(αtα1)2+(αtα2)2β21++(αtαt1)2β2t2(αtαt1)2+1==(αtα1)2+(αtα2)2β21(αtα2)2+1=(αtα1)2(αtα1)2+1=1(αtα1)2+(αtα2)2β21++(αtαt1)2β2t2+α2tβ2t1+β2t=(αtα1)2+(αtα2)2β21++(αtαt1)2β2t2+α2tβ2t1α2t+1=(αtα1)2+(αtα2)2β21++(αtαt1)2β2t2(αtαt1)2+1==(αtα1)2+(αtα2)2β21(αtα2)2+1=(αtα1)2(αtα1)2+1=1(5)

所以实际上相当于有:

xt=(αtα1)记为ˉαtx0+1(αtα1)2记为ˉβtˉεt,ˉεtN(0,I)=ˉαtx0+ˉβtˉεt,ˉεtN(0,I)xt=(αtα1)¯αtx0+1(αtα1)2¯βt¯εt,¯εtN(0,I)=¯αtx0+¯βt¯εt,¯εtN(0,I)(6)

此外DDPM还会选择适当的 αtαt,使得 ˉαT0¯αT0,这意味着经过 TT 步的拆楼后,所剩的楼体几乎可以忽略了,已经全部转化为原材料 εε

又如何建

xt1xtxt1xt 有了,现在我们要学习 xtxt1xtxt1。设该模型为 μ(xt)μ(xt),那么学习方案就是最小化欧氏距离:

xt1μ(xt)2xt1μ(xt)2(7)

首先根据 (2)(2),反解 xt1xt1 就是 xt1=1αt(xtβtεt)xt1=1αt(xtβtεt)。所以我们就可以将 μ(xt)μ(xt) 设成:

μ(xt)=1αt(xtβtϵθ(xt,t))μ(xt)=1αt(xtβtϵθ(xt,t))(8)

其中 θθ 是训练参数。代入到 (7)(7) 中,损失函数即:

xt1μ(xt)2=1αt(xtβtεt)1αt(xtβtϵθ(xt,t))2=β2tα2tεtϵθ(xt,t)2xt1μ(xt)2=1αt(xtβtεt)1αt(xtβtϵθ(xt,t))2=β2tα2tεtϵθ(xt,t)2(9)

忽略掉权重 β2tα2tβ2tα2t,另外结合 (2) (6)(2) (6) 可以将 xtxt 化为:

xt=αtxt1+βtεt=αt(ˉαt1x0+ˉβt1ˉεt1)+βtεt=ˉαtx0+αtˉβt1ˉεt1+βtεtxt=αtxt1+βtεt=αt(¯αt1x0+¯βt1¯εt1)+βtεt=¯αtx0+αt¯βt1¯εt1+βtεt(10)

最终损失函数的形式为:

εtϵθ(xt,t)2=εtϵθ(ˉαtx0+αtˉβt1ˉεt1+βtεt,t)2εtϵθ(xt,t)2=εtϵθ(¯αtx0+αt¯βt1¯εt1+βtεt,t)2(11)

为什么需要 xt=αtxt1+βtεtxt=αtxt1+βtεt 这一步呢?这是因为 ˉεt¯εtεtεt 不是相互独立的,所以只能用 ˉεt1¯εt1εtεt

降低方差

原则上 (11)(11) 就可以完成DDPM的训练,但由于这个式子中需要对 x0,ˉεt1,εt,tx0,¯εt1,εt,t 四个随机变量分别采样,在实践中可能有方差过大的风险,从而导致收敛过慢等问题。我们可以将 ˉεt1,εt¯εt1,εt 合并成单个随机变量,从而缓解方差大的问题。

首先推一下 ˉβ2t1¯β2t1β2t,ˉβ2tβ2t,¯β2t 的关系:

ˉβ2t1=1ˉα2t1=1ˉα2tα2t=11ˉβ2t1β2t=ˉβ2tβ2t1β2t¯β2t1=1¯α2t1=1¯α2tα2t=11¯β2t1β2t=¯β2tβ2t1β2t(12)

然后和上面做过的事情一样,利用一下正态分布的叠加性:

αtˉβt1ˉεt1+βtεtαt¯βt1¯εt1+βtεt 均值为 00,方差为 α2tˉβ2t1+β2t=α2tˉβ2tβ2t1β2t+β2t=β2tα2t¯β2t1+β2t=α2t¯β2tβ2t1β2t+β2t=β2t,实际相当于 ˉβtε|εN(0,I)¯βtε|εN(0,I)

βtˉεt1αtˉβt1εtβt¯εt1αt¯βt1εt 均值为 00,方差为 α2tˉβ2t1+β2t=α2tˉβ2tβ2t1β2t+β2t=β2tα2t¯β2t1+β2t=α2t¯β2tβ2t1β2t+β2t=β2t,实际相当于 ˉβtω|ωN(0,I)¯βtω|ωN(0,I)

然后我们来验证一下 εεωω 是两个相互独立的正态随机变量。这可以用协方差为零证明,不过我们也可以通过 E[εω]=0E[εω]=0 来说明。我们先算 E[(ˉβtε)(ˉβtω)]E[(¯βtε)(¯βtω)]

E[(ˉβtε)(ˉβtω)]=E[(αtˉβt1ˉεt1+βtεt)(βtˉεt1αtˉβt1εt)]=E[αtβtˉβt1ˉεt1ˉεt1α2tˉβ2t1ˉεt1εt+β2tεtˉεt1αtβtˉβt1εtεt]=αtβtˉβt1I0+0αtβtˉβt1I=0E[(¯βtε)(¯βtω)]=E[(αt¯βt1¯εt1+βtεt)(βt¯εt1αt¯βt1εt)]=E[αtβt¯βt1¯εt1¯εt1α2t¯β2t1¯εt1εt+β2tεt¯εt1αtβt¯βt1εtεt]=αtβt¯βt1I0+0αtβt¯βt1I=0(13)

于是我们也就证明了 E[εω]=0E[εω]=0。这里用到了结论:ε1,ε2N(0,I)ε1,ε2N(0,I) ,且 ε1,ε2ε1,ε2 独立,则有 E[ε1ε1]=E[ε2ε2]=IE[ε1ε1]=E[ε2ε2]=IE[ε1ε2]=E[ε1ε2]=0E[ε1ε2]=E[ε1ε2]=0

接下来我们反过来解 εtεt

{αtˉβt1ˉεt1+βtεt=ˉβtεβtˉεt1αtˉβt1εt=ˉβtω{αt¯βt1¯εt1+βtεt=¯βtεβt¯εt1αt¯βt1εt=¯βtω(14)

解得:

εt=αtˉβt1ˉβtωβtˉβtεα2tˉβ2t1β2t=(αtˉβt1ωβtε)ˉβtα2tˉβ2tβ2t1β2tβ2t=(βtεαtˉβt1ω)ˉβtˉβ2t=βtεαtˉβt1ωˉβtεt=αt¯βt1¯βtωβt¯βtεα2t¯β2t1β2t=(αt¯βt1ωβtε)¯βtα2t¯β2tβ2t1β2tβ2t=(βtεαt¯βt1ω)¯βt¯β2t=βtεαt¯βt1ω¯βt(15)

代入 (11)(11) 式得到:

Eˉεt1,εtN(0,I)[εtϵθ(ˉαtx0+αtˉβt1ˉεt1+βtεt,t)2]=Eω,εN(0,I)[βtεαtˉβt1ωˉβtϵθ(ˉαtx0+ˉβtε,t)2]E¯εt1,εtN(0,I)[εtϵθ(¯αtx0+αt¯βt1¯εt1+βtεt,t)2]=Eω,εN(0,I)[βtεαt¯βt1ω¯βtϵθ(¯αtx0+¯βtε,t)2](16)

我们先来处理 ωω

Eω,εN(0,I)[βtεαtˉβt1ωˉβtϵθ(ˉαtx0+ˉβtε,t)2]=Eω,εN(0,I)[αtˉβt1ˉβtω+βtˉβtεϵθ(ˉαtx0+ˉβtε,t)2]=Eω,εN(0,I)[Aω+B2]=Eω,εN(0,I)[Aω2+2ABω+B2]Eω,εN(0,I)[βtεαt¯βt1ω¯βtϵθ(¯αtx0+¯βtε,t)2]=Eω,εN(0,I)[αt¯βt1¯βtω+βt¯βtεϵθ(¯αtx0+¯βtε,t)2]=Eω,εN(0,I)[Aω+B2]=Eω,εN(0,I)[Aω2+2ABω+B2](17)

直接打开,注意到 EωEωEω2Eω2 都是常数,所以损失函数就相当于:

EεN(0,I)[B2]+常数=EεN(0,I)[βtˉβtεϵθ(ˉαtx0+ˉβtε,t)2]+常数=β2tˉβ2tEεN(0,I)[εˉβtβtϵθ(ˉαtx0+ˉβtε,t)2]+常数EεN(0,I)[B2]+=EεN(0,I)[βt¯βtεϵθ(¯αtx0+¯βtε,t)2]+=β2t¯β2tEεN(0,I)[ε¯βtβtϵθ(¯αtx0+¯βtε,t)2]+(18)

再次忽略常数和权重,我们得到DDPM最终所用的损失函数:

EεN(0,I)εˉβtβtϵθ(ˉαtx0+ˉβtε,t)2EεN(0,I)ε¯βtβtϵθ(¯αtx0+¯βtε,t)2(19)

这个形式和DDPM原论文中的 Lsimple(θ)Lsimple(θ) 是完全一致的:

Lsimple(θ):=Et,x0,ϵϵϵθ(ˉαtx0+1ˉαtϵ,t)2Lsimple(θ):=Et,x0,ϵϵϵθ(¯αtx0+1¯αtϵ,t)2(20)

递归生成

训练完之后,我们就可以从一个随机噪声 xTN(0,I)xTN(0,I) 出发执行 TT(8)(8) 来进行生成:

xt1=1αt(xtβtϵθ(xt,t))xt1=1αt(xtβtϵθ(xt,t))(21)

这对应于自回归解码中的Greedy Search。如果要进行Random Sample,那么需要补上噪声项:

xt1=1αt(xtβtϵθ(xt,t))+σtz,zN(0,I)xt1=1αt(xtβtϵθ(xt,t))+σtz,zN(0,I)(22)

一般来说,我们可以让 σt=βtσt=βt,即正向和反向的方差保持同步。

超参设置

在DDPM中,T=1000T=1000αt=10.02tTαt=10.02tT

在重构的时候我们用了欧氏距离 (7)(7) 作为损失函数,而一般我们用DDPM做图片生成,以往做过图片生成的读者都知道,欧氏距离并不是图片真实程度的一个好的度量,VAE用欧氏距离来重构时,往往会得到模糊的结果,除非是输入输出的两张图片非常接近,用欧氏距离才能得到比较清晰的结果,所以选择尽可能大的 TT,正是为了使得输入输出尽可能相近,减少欧氏距离带来的模糊问题。

为什么要选择单调递减的 αtαt 呢?当 tt 比较小时,xtxt 还比较接近真实图片,所以我们要缩小 xt1xt1xtxt 的差距,以便更适用欧氏距离 (7)(7),因此要用较大的 αtαt;当 tt 比较大时,xtxt 已经比较接近纯噪声了,噪声用欧式距离无妨,所以可以稍微增大 xt1xt1xtxt 的差距,即可以用较小的 αtαt。那么可不可以一直用较大的 αtαt 呢?可以是可以,但是要增大 TT

我们之前说过,应该有 ˉαT0¯αT0,我们利用 αtαt 的表达式来计算 ˉαT¯αT

logˉαT=Tt=1logαt=12Tt=1log(10.02tT)<12Tt=1(0.02tT)=0.005(T+1)log¯αT=Tt=1logαt=12Tt=1log(10.02tT)<12Tt=1(0.02tT)=0.005(T+1)(23)

由此可以看出 TT 要足够大,才能达到 00 的标准。当 T=1000T=1000 时,ˉαTe5¯αTe5

最后我们留意到,“建楼”模型中的 ϵθ(ˉαtx0+ˉβtε,t)ϵθ(¯αtx0+¯βtε,t) 中,我们在输入中显式地写出了 tt,这是因为原则上不同的 tt 处理的是不同层次的对象,所以应该用不同的重构模型,即应该有 TT 个不同的重构模型才对,于是我们共享了所有重构模型的参数,将 tt 作为条件传入。按照论文附录的说法,tt 是转换成位置编码后,直接加到残差模块上去的。

VAE角度

多步突破

在传统VAE中,编码过程和生成过程都是一步到位的:

编码:xz,生成:zx:xz,:zx(24)

DDPM将编码过程和生成过程分解为 TT 步:

编码:x=x0x1x2xT1xT=z生成:z=xTxT1xT2x1x0=x:x=x0x1x2xT1xT=z:z=xTxT1xT2x1x0=x(25)

联合散度

每一步编码过程被建模成 q(xt|xt1)q(xt|xt1),每一步生成过程被建模成 p(xt1|xt)p(xt1|xt),相应的联合分布就是:

q(x0,x1,x2,,xT)=q(xT|xT1)q(x2|x1)q(x1|x0)˜q(x0)p(x0,x1,x2,,xT)=p(x0|x1)p(xT2|xT1)p(xT1|xT)p(xT)q(x0,x1,x2,,xT)=q(xT|xT1)q(x2|x1)q(x1|x0)~q(x0)p(x0,x1,x2,,xT)=p(x0|x1)p(xT2|xT1)p(xT1|xT)p(xT)(26)

其中 x0x0 代表真实样本,所以 ˜q(x0)~q(x0) 就是数据分布;而 xTxT 代表着最终的编码,所以 p(xT)p(xT) 就是先验分布;剩下的 q(xt|xt1)q(xt|xt1)p(xt1|xt)p(xt1|xt) 就代表着编码、生成的一小步。

VAE可以理解为在最小化联合分布的KL散度,对于DDPM也是如此,上面我们已经写出了两个联合分布,所以DDPM的目的就是最小化

KL(qp)=q(xT|xT1)q(x1|x0)˜q(x0)logq(xT|xT1)q(x1|x0)˜q(x0)p(x0|x1)p(xT1|xT)p(xT)dx0dx1dxTKL(qp)=q(xT|xT1)q(x1|x0)~q(x0)logq(xT|xT1)q(x1|x0)~q(x0)p(x0|x1)p(xT1|xT)p(xT)dx0dx1dxT(27)

接下来,我们要将 q(xt|xt1)q(xt|xt1)p(xt1|xt)p(xt1|xt) 的具体形式定下来,然后简化DDPM的优化目标。

分而治之

DDPM将每一步的编码建立为正态分布:q(xt|xt1)=N(xt;αtxt1,β2tI)q(xt|xt1)=N(xt;αtxt1,β2tI),其主要的特点是均值向量仅仅由输入 xt1xt1 乘以一个标量 αtαt 得到,相比之下传统VAE的均值方差都是用神经网络学习出来的,因此DDPM是放弃了模型的编码能力,最终只得到一个纯粹的生成模型;至于 p(xt1|xt)p(xt1|xt),则被建模成均值向量可学习的正态分布 N(xt1;μ(xt),σ2tI)N(xt1;μ(xt),σ2tI)。其中 αt,βt,σtαt,βt,σt 都不是可训练参数,而是事先设定好的值,整个模型拥有可训练参数的就只有μ(xt)μ(xt)

由于目前分布 qq 不含任何可训练参数,因此目标 (27)(27) 中关于 qq 的积分就只是贡献一个可以忽略的常数,目标 (27)(27) 等价于:

q(xT|xT1)q(x1|x0)˜q(x0)logp(x0|x1)p(xT1|xT)p(xT)dx0dx1dxT=q(xT|xT1)q(x1|x0)˜q(x0)[logp(xT)+Tt=1logp(xt1|xt)]dx0dx1dxTq(xT|xT1)q(x1|x0)~q(x0)logp(x0|x1)p(xT1|xT)p(xT)dx0dx1dxT=q(xT|xT1)q(x1|x0)~q(x0)[logp(xT)+Tt=1logp(xt1|xt)]dx0dx1dxT(28)

由于先验分布 p(xT)p(xT) 一般都取标准正态分布,也是没有参数的,所以这一项也只是贡献一个常数。因此需要计算的就是每一项

q(xT|xT1)q(x1|x0)˜q(x0)logp(xt1|xt)dx0dx1dxT=q(xt|xt1)q(x1|x0)˜q(x0)logp(xt1|xt)dx0dx1dxt=q(xt|xt1)q(xt1|x0)˜q(x0)logp(xt1|xt)dx0dxt1dxtq(xT|xT1)q(x1|x0)~q(x0)logp(xt1|xt)dx0dx1dxT=q(xt|xt1)q(x1|x0)~q(x0)logp(xt1|xt)dx0dx1dxt=q(xt|xt1)q(xt1|x0)~q(x0)logp(xt1|xt)dx0dxt1dxt(29)

这两个等号分别是因为:

q(xT|xT1)q(xt+1|xt)dxt+1dxT=1q(xT|xT1)q(xt+1|xt)dxt+1dxT=1(30)

q(xt1|xt2)q(x1|x0)dx1dxt2=q(xt1|x0)q(xt1|xt2)q(x1|x0)dx1dxt2=q(xt1|x0)(31)

场景再现

除去优化无关的常数:

logp(xt1|xt)=log[12πσtext1μ(xt)22σ2t]=12σ2txt1μ(xt)2+常数logp(xt1|xt)=log12πσtext1μ(xt)22σ2t=12σ2txt1μ(xt)2+(32)

这和第一种推导过程中的 (7)(7) 式是一样的,同样的处理方法可以得到与 (19)(19) 式相同的目标函数:

EεN(0,I),x0˜p(x0)[εˉβtβtϵθ(ˉαtx0+ˉβtε,t)2]EεN(0,I),x0~p(x0)[ε¯βtβtϵθ(¯αtx0+¯βtε,t)2](33)

当然这里的系数已经去掉了(原论文中通过实验发现,去掉这个系数后的实际效果更好些)。

超参设置

对于 q(xt|xt1)q(xt|xt1) 来说,习惯上约定 α2t+β2t=1α2t+β2t=1。在第一种推导 (6)(6) 式中已经证明,由于正态分布的叠加性,在此约束之下我们有 xt=ˉαtx0+ˉβtˉεt,ˉεtN(0,I)xt=¯αtx0+¯βt¯εt,¯εtN(0,I),对应在这里的表示就是:

q(xt|x0)=q(xt|xt1)q(x1|x0)dx1dxt1=N(xt;ˉαtx0,ˉβ2tI)q(xt|x0)=q(xt|xt1)q(x1|x0)dx1dxt1=N(xt;¯αtx0,¯β2tI)(34)

另一方面,p(xT)p(xT) 一般都取标准正态分布 N(xT;0,I)N(xT;0,I)。而我们的学习目标是最小化两个联合分布的KL散度,即希望 p=qp=q,那么它们的边缘分布自然也相等,所以我们也希望

p(xT)=q(xT|xT1)q(x1|x0)˜q(x0)dx0dx1dxT1=q(xT|x0)˜q(x0)dx0p(xT)=q(xT|xT1)q(x1|x0)~q(x0)dx0dx1dxT1=q(xT|x0)~q(x0)dx0(35)

由于数据分布 ˜q(x0)~q(x0) 是任意的,所以要使上式恒成立,只能让 q(xT|x0)=p(xT)q(xT|x0)=p(xT),即退化为与 x0x0 无关的标准正态分布,这意味着我们要设计适当的 αtαt,使得 ˉαT0¯αT0。同时这再次告诉我们,DDPM是没有编码能力了,最终的 p(xT|x0)p(xT|x0) 可以说跟输入 x0x0 无关的,生成出来的图像也无法回到原图。

至于 σtσt,理论上不同的数据分布 ˜q(x0)~q(x0) 来说对应不同的最优 σtσt,但我们又不想将 σtσt 设为可训练参数,所以只好选一些特殊的 ˜q(x0)~q(x0) 来推导相应的最优 σtσt,并认为由特例推导出来的 σtσt 可以泛化到一般的数据分布。我们可以考虑两个简单的例子:

  1. 假设训练集只有一个样本xx,即 ˜q(x0)~q(x0) 是狄拉克分布 δ(x0x)δ(x0x),可以推出最优的 σt=ˉβt1ˉβtβtσt=¯βt1¯βtβt
  2. 假设数据分布 ˜q(x0)~q(x0) 服从标准正态分布,这时候可以推出最优的 σt=βtσt=βt

具体的推导在下面一节的“遗留问题”部分给出。实验结果显示两个选择的表现是相似的,因此可以选择任意一个进行采样。

贝叶斯角度

请贝叶斯

根据贝叶斯定理:

p(xt1|xt)=p(xt|xt1)p(xt1)p(xt)p(xt1|xt)=p(xt|xt1)p(xt1)p(xt)(36)

然而,我们并不知道 p(xt1),p(xt)p(xt1),p(xt) 的表达式,所以此路不通。但我们可以退而求其次,在给定 x0x0 的条件下使用贝叶斯定理:

p(xt1|xt,x0)=p(xt|xt1)p(xt1|x0)p(xt|x0)p(xt1|xt,x0)=p(xt|xt1)p(xt1|x0)p(xt|x0)(37)

这样修改是因为 p(xt|xt1),p(xt1|x0),p(xt|x0)p(xt|xt1),p(xt1|x0),p(xt|x0) 都是已知的(再次复习一下 (2)(6)(2)(6) 两式):

p(xt|xt1)=N(xt;αtxt1,β2tI)p(xt|x0)=N(xt;ˉαtx0,ˉβ2tI)p(xt1|x0)=N(xt1;ˉαt1x0,ˉβ2t1I)p(xt|xt1)=N(xt;αtxt1,β2tI)p(xt|x0)=N(xt;¯αtx0,¯β2tI)p(xt1|x0)=N(xt1;¯αt1x0,¯β2t1I)(38)

所以上式是可计算的,代入各自的表达式得到:

p(xt1|xt,x0)=12πβtextαtxt122β2t12πˉβt1ext1ˉαt1x022ˉβ2t112πˉβtextˉαtx022ˉβ2tp(xt1|xt,x0)=12πβtextαtxt122β2t12π¯βt1ext1¯αt1x022¯β2t112π¯βtext¯αtx022¯β2t(39)

系数部分 ˉβt2πˉβt1βt¯βt2π¯βt1βt 可知协方差矩阵是 ˉβ2t1β2tˉβ2tI¯β2t1β2t¯β2tI,在此基础上化简指数部分求出均值 ˜μ(xt,x0)~μ(xt,x0),可以得到:

p(xt1|xt,x0)=N(xt1;αtˉβ2t1ˉβ2txt+ˉαt1β2tˉβ2tx0,ˉβ2t1β2tˉβ2tI)p(xt1|xt,x0)=N(xt1;αt¯β2t1¯β2txt+¯αt1β2t¯β2tx0,¯β2t1β2t¯β2tI)(40)

去噪过程

现在我们得到了 p(xt1|xt,x0),它有显式的解,但并非我们想要的最终答案,因为我们只想通过 xt 来预测 xt1,而不能依赖 x0x0 是我们最终想要生成的结果。采取的解决方案是通过 xt 来预测 x0,从而消去 p(xt1|xt,x0) 中的 x0,使得它只依赖于 xt

我们用 ˉμ(xt) 来预估 x0,损失函数为 x0ˉμ(xt)2。训练完成后,我们就认为

p(xt1|xt)p(xt1|xt,x0=ˉμ(xt))=N(xt1;αtˉβ2t1ˉβ2txt+ˉαt1β2tˉβ2tˉμ(xt),ˉβ2t1β2tˉβ2tI)

x0ˉμ(xt)2 中,x0 代表原始数据,xt 代表带噪数据,所以这实际上在训练一个去噪模型,这也就是DDPM的第一个“D”的含义(Denoising)。具体来说,p(xt|x0)=N(xt;ˉαtx0,ˉβ2tI) 意味着 xt=ˉαtx0+ˉβtε,εN(0,I),或者写成x0=1ˉαt(xtˉβtε),这启发我们将 ˉμ(xt)参数化为

ˉμ(xt)=1ˉαt(xtˉβtϵθ(xt,t))

此时损失函数变为

x0ˉμ(xt)2=ˉβ2tˉα2tεϵθ(ˉαtx0+ˉβtε,t)2

省去前面的系数,就得到DDPM原论文所用的损失函数了。可以发现,本文是直接得出了从 xtx0 的去噪过程,而不是像之前两篇文章那样,通过 xtxt1 的去噪过程再加上积分变换来推导,相比之下本文的推导可谓更加一步到位了。

另一边,我们将式 (42) 代入到式 (41) 中,均值部分:

αtˉβ2t1ˉβ2txt+ˉαt1β2tˉβ2tˉμ(xt)=αtˉβ2t1ˉβ2txt+ˉαt1β2tˉβ2t1ˉαt(xtˉβtϵθ(xt,t))=αtˉβ2t1ˉβ2txt+β2tˉβ2tαt(xtˉβtϵθ(xt,t))=α2tˉβ2t1+β2tˉβ2tαtxtβ2tˉβtαtϵθ(xt,t)=1αtxtβ2tˉβtαtϵθ(xt,t)

最后一步是因为 α2tˉβ2t1+β2tˉβ2tαt=α2t(1ˉα2t1)+β2tˉβ2tαt=1ˉα2tˉβ2tαt=1αt,化简得到:

p(xt1|xt)p(xt1|xt,x0=ˉμ(xt))=N(xt1;1αt(xtβ2tˉβtϵθ(xt,t)),ˉβ2t1β2tˉβ2tI)

这就是反向的采样过程所用的分布,连同采样过程所用的方差也一并确定下来了,即上一节中由狄拉克分布推导出来的 σt=ˉβt1ˉβtβt。这里的 ϵθ 和前两种推导方式不同,反而跟DDPM原论文一致。

预估修正

不知道读者有没有留意到一个有趣的地方:我们要做的事情,就是想将 xT 慢慢地变为 x0,而我们在借用 p(xt1|xt,x0) 近似 p(xt1|xt) 时,却包含了“用 ˉμ(xt) 来预估 x0这一步,要是能预估准的话,那就直接一步到位了,还需要逐步采样吗?

真实情况是,“用 ˉμ(xt) 来预估 x0”当然不会太准的,至少开始的相当多步内不会太准。它仅仅起到了一个前瞻性的预估作用,然后我们只用 p(xt1|xt) 来推进一小步,这就是很多数值算法中的“预估-修正”思想,即我们用一个粗糙的解往前推很多步,然后利用这个粗糙的结果将最终结果推进一小步,以此来逐步获得更为精细的解。

由此我们还可以联想到Hinton在2019年提出的《Lookahead Optimizer: k steps forward, 1 step back》,它同样也包含了预估(k steps forward)和修正(1 step back)两部分,原论文将其诠释为“快(Fast)-慢(Slow)”权重的相互结合,快权重就是预估得到的结果,慢权重则是基于预估所做的修正结果。如果愿意,我们也可以用同样的方式去诠释DDPM的“预估-修正”过程~

遗留问题

(36) 式没法直接算的原因是 p(xt1),p(xt) 未知。根据定义:

p(xt)=p(xt|x0)˜p(x0)dx0

其中 p(xt|x0) 是知道的,而数据分布 ˜p(x0) 无法提前预知,所以不能进行计算。不过,有两个特殊的例子,是可以直接将两者算出来的:

  1. 整个数据集只有一个样本,不失一般性,假设该样本为 0,此时 ˜p(x0) 为狄拉克分布 δ(x0),可以直接算出 p(xt)=p(xt|0)。继而代入 (39) 式,可以发现结果正好是 p(xt1|xt,x0)x0=0 的特例,直接利用 (40) 式的结论,即

    p(xt1|xt)=p(xt1|xt,x0=0)=N(xt1;αtˉβ2t1ˉβ2txt,ˉβ2t1β2tˉβ2tI)

    我们主要关心其方差为 ˉβ2t1β2tˉβ2t,这便是采样方差的选择之一。

  2. 数据集服从标准正态分布,即 ˜p(x0)=N(x0;0,I)。前面我们说了 p(xt|x0)=N(xt;ˉαtx0,ˉβ2tI) 意味着 xt=ˉαtx0+ˉβtε,εN(0,I),而此时根据假设还有 x0N(0,I),所以由正态分布的叠加性,xt 正好也服从标准正态分布。将标准正态分布的概率密度代入 (36) 式:

    p(xt1|xt)=12πβtextαtxt122β2t12πext12212πext22=12πβtext1αtxt22β2t=N(xt1;αtxt,β2tI)

    我们同样主要关心其方差为 β2t,这便是采样方差的另一个选择。

参考文献

苏剑林. (Jun. 13, 2022). 《生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9119

苏剑林. (Jul. 06, 2022). 《生成扩散模型漫谈(二):DDPM = 自回归式VAE 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9152

苏剑林. (Jul. 19, 2022). 《生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9164

posted @   Be(CN₃H₃)₂  阅读(1057)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· .NET10 - 预览版1新功能体验(一)
点击右上角即可分享
微信分享提示