从DDPM到DDIM(三) DDPM的训练与推理

从DDPM到DDIM(三) DDPM的训练与推理

前情回顾

首先还是回顾一下之前讨论的成果。

扩散模型的结构和各个概率模型的意义。下图展示了DDPM的双向马尔可夫模型。
img

其中xT代表纯高斯噪声,xt,0<t<T 代表中间的隐变量, x0 代表生成的图像。

  • q(xt|xt1) 加噪过程的单步转移概率,服从高斯分布,这很好理解。
  • q(xt1|xt) 是真正的采样过程的单步转移概率,但是求解它比较困难。
  • pθ(xt1|xt) 代表的是神经网络拟合的概率,我们希望神经网络能更好地拟合采样过程的单步转移概率。
  • q(xt1|xt,x0),给定最终生成结果 x0 的条件下,生成过程的单步转移概率。x0 就像有监督学习中的标签,指导着生成的方向。我们采用此概率来替代 q(xt1|xt) 做神经网络的拟合。如果无法理解,把它当作一个无物理意义的数学上的中间变量即可。

pθ(xt1|xt) 来表示。之所以这里增加一个 θ 下标,是因为 pθ(xt1|xt) 是用神经网络来逼近的转移概率, θ 代表神经网络参数。

联合概率表示 扩散模型的联合概率和前向条件联合概率为:

(1)p(x0:T)=p(xT)t=1Tpθ(xt1|xt)

(2)q(x1:T|x0)=t=1Tq(xt|xt1)

概率分布的具体表达式 之前提到的各种条件概率的具体表达式为:

(3)q(xt|xt1)=N(xt;αtxt1,(1αt)I)q(xt|x0)=N(xt;αtx0,(1αt)I)q(xt1|xt,x0)=N(xt1;μ~t(xt,x0),Σ~(t))

其中

μ~t(xt,x0)=(1αt1)αt(1αt)xt+(1αt)αt1(1αt)x0Σ~(t)=(1αt)(1αt1)1αtI=σ2(t)I

另外,p(xT) 服从标准高斯分布,pθ(xt1|xt) 是我们要训练的神经网络。

根据贝叶斯公式,我们要改造的条件概率如下:

(4)q(xt|xt1,x0)=q(xt1|xt,x0)q(xt|x0)q(xt1|x0)

证据下界 我们原本要对生成的图像分布进行极大似然估计,但直接估计无法计算。于是我们改为最大化证据下界,然后对证据下界进行化简,现在,我们采用 q(xt1|xt,x0) 重新优化证据下界:

(5)logp(x0)L=Eq(x1:T|x0)[logp(xT)t=1Tpθ(xt1|xt)t=1Tq(xt|xt1)]

3.5、利用 q(xt1|xt,x0) 重新推导证据下界

  书接上回。我们化简证据下界的一个想法是,我们希望将 pθ(xt1|xt)q(xt|xt1) 的每一项一一对齐;并且将含有 (x0,x1) 的项与其他项分开来。因为 x0 是图像,而其他随机变量是隐变量。还有一种解释是,这次我们采用了 q(xt1|xt,x0),而当 t=1 时,q(x0|x1,x0) 看起来好像是无意义的。所以我们要将含有 (x0,x1) 的项与其他项分开。

L=Eq(x1:T|x0)[logp(xT)t=1Tpθ(xt1|xt)t=1Tq(xt|xt1)]=Eq(x1:T|x0)[logp(xT)pθ(x0|x1)t=2Tpθ(xt1|xt)q(x1|x0)t=2Tq(xt|xt1)]=Eq(x1:T|x0)[logp(xT)pθ(x0|x1)q(x1|x0)]+Eq(x1:T|x0)[logt=2Tpθ(xt1|xt)t=2Tq(xt|xt1,x0)]=Eq(x1:T|x0)[logp(xT)pθ(x0|x1)q(x1|x0)]+Eq(x1:T|x0)[logt=2Tpθ(xt1|xt)t=2Tq(xt1|xt,x0)q(xt|x0)q(xt1|x0)]=Eq(x1:T|x0)[logp(xT)pθ(x0|x1)q(x1|x0)]+Eq(x1:T|x0)[logt=2Tq(xt1|x0)q(xt|x0)]+Eq(x1:T|x0)[logt=2Tpθ(xt1|xt)t=2Tq(xt1|xt,x0)]=Eq(x1:T|x0)[logp(xT)pθ(x0|x1)q(x1|x0)]+Eq(x1:T|x0)[logq(x1|x0)q(xT|x0)]+Eq(x1:T|x0)[logt=2Tpθ(xt1|xt)t=2Tq(xt1|xt,x0)]=Eq(x1:T|x0)[logp(xT)pθ(x0|x1)q(xT|x0)]+t=2TEq(x1:T|x0)[logpθ(xt1|xt)q(xt1|xt,x0)]=Eq(x1:T|x0)[logpθ(x0|x1)]+Eq(x1:T|x0)[logp(xT)q(xT|x0)]+t=2TEq(x1:T|x0)[logpθ(xt1|xt)q(xt1|xt,x0)]

与之前一样,上式三项也分别代表三部分:重建项先验匹配项一致项

  • 重建项。顾名思义,这是对最后构图的预测概率。给定预测的最后一个隐变量 x1,预测生成的图像 x0 的对数概率。
  • 先验匹配项。 这一项描述的是扩散过程的最后一步生成的高斯噪声与纯高斯噪声的相似度,与之前相比,这一项的 q 部分的条件改为了 x0。同样,这一项并没有神经网络参数,所以不需要优化,后续网络训练的时候可以将这一项舍去。
  • 一致项。这一项与之前有两点不同。其一,与之前相比,不再有错位比较。其二,这匹配目标改为了由 pθ(xt|xt+1)q(xt1|xt,x0) 匹配,而之前是和扩散过程的单步转移概率 q(xt|xt1) 匹配。更加合理。

类似地,与之前的操作一样,我们将上式的数学期望下角标中的无关的随机变量约去(积分为1),然后转化成KL散度的形式。我们看 先验匹配项一致项

Eq(x1:T|x0)[logp(xT)q(xT|x0)]=q(x1:T|x0)logp(xT)q(xT|x0)dx1:T=q(xT|x0)logp(xT)q(xT|x0)dxT=DKL(q(xT|x0)p(xT))

Eq(x1:T|x0)[logpθ(xt1|xt)q(xt1|xt,x0)]=q(x1:T|x0)logpθ(xt1|xt)q(xt1|xt,x0)dx1:T=q(xt,xt1|x0)logpθ(xt1|xt)q(xt1|xt,x0)dxtdxt1=q(xt|x0)q(xt1|xt,x0)logpθ(xt1|xt)q(xt1|xt,x0)dxtdxt1=q(xt|x0)DKL(q(xt1|xt,x0)pθ(xt1|xt))dxt=Eq(xt|x0)[DKL(q(xt1|xt,x0)pθ(xt1|xt))]

重建项也类似,期望下角标的概率中,除了随机变量 x1 之外都可以约掉。最后,我们终于得出证据下界的KL散度形式:

(6)L=Eq(x1|x0)[logpθ(x0|x1)]DKL(q(xT|x0)p(xT))t=2TEq(xt|x0)[DKL(q(xt1|xt,x0)pθ(xt1|xt))]

  下面聊聊数学期望的下角标的物理意义是。以重建项为例,下角标为 q(x1|x0),代表用 x0 加噪一步生成 x1,然后用 x1 输入到神经网络中得到估计的 x0 的分布,然后最大化这个对数似然概率。而数学期望代表了多个图片,一个 epoch 之后取平均作为期望。一致项也类似,只是用 x0 生成 xt,然后通过神经网络计算与 q(xt1|xt,x0) 的KL散度。这实际上就是蒙特卡洛估计。

所以,我们需要计算loss的项有两个,一个是重建项中的对数部分,一个是一致项中的KL散度。至于数学期望和下角标,我们并不需要展开计算,而是在训练的时候用多个图片并分别添加不同程度的噪声来替代。

4、训练过程

  下面我们利用 (3) 式对证据下界 (6) 式做进一步展开。从DDPM到DDIM(二) 这篇文章讲过,在 βt 很小的前提下,pθ(xt1|xt) 也服从高斯分布。因为 pθ(xt1|xt) 的训练目标是匹配 q(xt1|xt,x0),我们也写成高斯分布的形式,并与 q(xt1|xt,x0) 的形式做对比。

q(xt1|xt,x0)=N(xt1;μ~t(xt,x0),Σ~(t))=12πσ(t)exp[12(xt1μ~t(xt,x0))TΣ~1(t)(xt1μ~t(xt,x0))]pθ(xt1|xt)=N(xt1;μ~θ(xt,t),Σ~(t))=12πσ(t)exp[12(xt1μ~θ(xt,t))TΣ~1(t)(xt1μ~θ(xt,t))]

这里 pθ(xt1|xt) 的均值 μ~θ(xt,t) 是神经网络输出,方差我们采用和 q(xt1|xt,x0) 一样的方差。神经网络 μ~θ(xt,t) 的输入有两个,第一个是 xt,这是显然的,还有一个输入时时刻 t,因为当然,方差也可以作为神经网络来训练,但是DDPM原文中做过实验,这样效果并不显著。因此,上述两个均值两个方差中,只有蓝色的 μ~θ(xt,t) 是未知的,另外三个量都是已知量。

根据 (6) 式,我们只需要计算 重建项一致项,先验匹配项没有训练参数。下面分别计算:

logpθ(x0|x1)=12σ2(t)x0μ~θ(x1,1)22+const

其中 const 代表某个常数。

  下面计算一致项,即KL散度。高斯分布的KL散度是有公式的,我们不加证明地给出,若需要证明,可以查阅维基百科。两个 d 维随机变量服从高斯分布 Q=N(μ1,Σ1) , P=N(μ2,Σ2),其中 μ1,μ2Rd,Σ1,Σ2Rd×d 二者的Kullback-Leibler 散度(KL散度)可以用以下公式计算:

DKL(QP)=12[logdetΣ2detΣ1d+tr(Σ21Σ1)+(μ2μ1)TΣ21(μ2μ1)]

下面我们将一致项代入上述公式:

(7)DKL(q(xt1|xt,x0)pθ(xt1|xt))=12[log1d+d+μ~θ(xt,t)μ~t(xt,x0)22/σ2(t)]=12σ2(t)μ~θ(xt,t)μ~t(xt,x0)22

从上两个式子可以看出,μ~θ(xt,t)t>0 的时候,目标是匹配 μ~t(xt,x0)。我们的研究哲学是,只要有解析形式,我们就将解析形式展开,直到某个变量没有解析解,这时候才会用神经网络拟合,这样可以最大化地保证拟合的效果。比如我们为了拟合一个二次函数 f(x)=ax2+3x+2,其中 a 是未知量,我们应该设计一个神经网络来估计 a,而不应该用神经网络来估计 f(x),因为前者确保了神经网络估计出来的函数是二次函数,而后者则有更多的不确定性。

  为了更好地匹配,我们展开 μ~t(xt,x0) 中的解析形式。

(8)μ~t(xt,x0)=(1αt1)αt(1αt)xt+(1αt)αt1(1αt)x0μ~θ(xt,t)=(1αt1)αt(1αt)xt+(1αt)αt1(1αt)x~θ(xt,t)

μ~θ(xt,t) 展开的形式与 μ~t(xt,x0) 相同。第一项是与 xt 相关的,因为 xt 是输入,所以保持不变,但是 x0 是未知量,所以我们还是用神经网络来替代,神经网络的输入同样也是 xtt。将 (8) 式代入 (7) 式,有:

DKL(q(xt1|xt,x0)pθ(xt1|xt))=12σ2(t)μ~θ(xt)μ~t(xt,x0)22=12σ2(t)(1αt)2αt1(1αt)2x0x~θ(xt,t)22

重建项也可以继续化简,注意到 β0=0,α0=1,α0=1,α1=α1

logpθ(x0|x1)=12σ2(t)x0μ~θ(x1,1)22+const=12σ2(t)x0(1α0)α1(1α1)xt+(1α1)α0(1α1)x~θ(x1,t)22+const=12σ2(t)x0x~θ(x1,t)22+const=12σ2(t)(1α1)2α0(1α1)2x0x~θ(xt,t)22+const

上式最后一行是为了与KL散度的形式保持一致。经过这么长时间的努力,我们终于将证据下界化为最简形式。我们把我们计算出的重建项和一致项代入到 (6) 式,并舍弃和神经网络参数无关的先验匹配项,有:

(9)L=t=1TEq(xt|x0)[12σ2(t)(1αt)2αt1(1αt)2x0x~θ(xt,t)22]

因为前面有个负号,所以最大化证据下界等价于最小化以下损失函数:

θ=argminθt=1T12σ2(t)(1αt)2αt1(1αt)2Eq(xtx0)[x~θ(xt,t)x022]

理解上式也很简单。首先我们看每一项的权重 12σ2(t)(1αt)2αt1(1αt)2,这表示了马尔可夫链每一个阶段预测损失的权重,DDPM论文的实验证明,忽略此权重影响不大,所以我们继续简化为:

θ=argminθt=1TEq(xtx0)[x~θ(xt,t)x022]

  其实现方式就是给你一张图像 x0,然后分别按照不同的步骤加噪,最多加到 T 步噪声,得到 x1,x2,...,xT 个隐变量。如下图所示,由于多步转移概率的性质,我们可以从 x0 一步加噪到任意一个噪声阶段。

img

  然后将这些隐变量分别送入神经网络,输出与 x0 计算二范数Loss,然后所有的Loss取平均。然而,实际实现的时候,我们不仅仅只有一张图,而是有很多张图。送入神经网络的时候也是以一个 batch 的形式处理的,如果每张图片都加这么多次噪声,那训练的工作量就会非常巨大。所以实际上我们采用这样的方式:假设一个batch中有 N 张图片,对于这 N 张图片分别添加不同阶段的高斯噪声,图像添加噪声的程度也是随机的,比如第一张图像加噪 10 步,第二张图像加噪 910 步,等等。然后分别输入加噪后的隐变量和时刻信息,神经网络的输出与每一张原始图像分别做二范数loss,最后平均。这样相比于只给一张图像加 1000 种不同的噪声的优势是防止在一张图像上过拟合,然后陷入局部极小。下面我们给出具体的训练算法流程:


Algorithm 1 . Training a Deniosing Diffusion Probabilistic Model. (Version: Predict image)

For every image x0 in your training dataset:

  • Repeat the following steps until convergence.
  • Pick a random time stamp tUniform[1,T].
  • Draw a sample xtq(xt|xt), i.e.

xt=αtx0+1αtϵ,ϵN(0,I)

  • Take gradient descent step on

θx~θ(xt,t)x022

You can do this in batches, just like how you train any other neural networks. Note that, here, you are training one denoising network x~θ for all noisy conditions.


采用batch来训练的话,就对每个图片分别同时进行上述操作,值得注意的是,神经网络参数只有一个,无论是哪一个 t 步去噪,其不同只有输入的不同,而神经网络只有 x~θ 一个。训练示意图如下:

img

  说句题外话,其实DDPM的原文很具有误导性,如下图的DDPM的原图。从这张图上看,或许有些同学以为神经网络是输入 xt 来预测 xt1,实际上并非如此。是输入 xt 来预测 x0,原因就是我们采用 q(xt1|xt,x0) 来作为拟合目标,目标是匹配其均值 μ~t(xt,x0),而不是匹配 xt1。而 μ~t(xt,x0) 恰好是 x0 的函数,所以我们在训练上的时候实际上是输入 xt 用神经网络来预测 x0。而采样过程才是一步一步采样的。正因为训练时候神经网络拟合的对象并不是 xt1,所以就给了我们在采样过程中的加速的空间,这就是后话了。

img

5、推理过程

  大家先别翻论文,你觉得最简单的一个生成图像的想法是什么。我当时就想过,既然神经网络 x~θ 是输入 xt 来预测 x0,那么我们直接给他一个随机噪声,一步生成图像不行吗?这个问题存疑,因为最新的研究确实有单步图像生成的,不过笔者还没有精读,就暂不评价。

  按照马尔可夫性质,还是用 pθ(xt1|xt) 一步一步做蒙特卡洛生成:

(9)xt1pθ(xt1|xt)=N(xt1;μ~θ(xt,t),σ2(t)I)xt1=μ~θ(xt,t)+σ2(t)ϵ=(1αt1)αt(1αt)xt+(1αt)αt1(1αt)x~θ(xt,t)+σ2(t)ϵ

其中 σ2(t)=(1αt)(1αt1)1αt

  扩散模型给我的感觉就是,训练过程和推理过程的差别很大。或许这就是生成模型,训练算法和推理算法的形式有很大的区别,包括文本的自回归生成也是如此。他不像图像分类,推理的时候跟训练时是一样的计算方式,只是最后来一个取概率最大的类别就行。训练过程和推理过程的极大差异决定了此推理形式不是唯一的形式,一定有更优的推理算法。

这个推理过程由如下算法描述。


Algorithm 2. Inference on a Deniosing Diffusion Probabilistic Model. (Version: Predict image)

Input: the trained model x~θ.

  • You give us a white noise vector xTN(0,I)
  • Repeat the following for t=T,T1,...,1.
  • Update according to

xt1=(1αt1)αt(1αt)xt+(1αt)αt1(1αt)x~θ(xt,t)+σ2(t)ϵ,ϵN(0,I)

Output: x0.


img

  • 推理输出的 x0 还需要进行去归一化和离散化到 0 到 255 之间,这个我们留到下一篇文章讲。
  • 另外,在DDPM原文中,并没有直接预测 x0,而是对 x0 进行了重参数化,让神经网络预测噪声 ϵ,这是怎么做的呢,我们也留到下一篇文章讲。

下一篇文章 《从DDPM到DDIM(四) 预测噪声与生图后处理》

posted @   txdt  阅读(606)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· Vue3状态管理终极指南:Pinia保姆级教程
点击右上角即可分享
微信分享提示