生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪

到目前为止,笔者给出了生成扩散模型DDPM的两种推导,分别是《生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼》中的通俗类比方案和《生成扩散模型漫谈(二):DDPM = 自回归式VAE》中的变分自编码器方案。两种方案可谓各有特点,前者更为直白易懂,但无法做更多的理论延伸和定量理解,后者理论分析上更加完备一些,但稍显形式化,启发性不足。

 

贝叶斯定理(来自维基百科)

贝叶斯定理(来自维基百科)

 

在这篇文章中,我们再分享DDPM的一种推导,它主要利用到了贝叶斯定理来简化计算,整个过程的“推敲”味道颇浓,很有启发性。不仅如此,它还跟我们后面将要介绍的DDIM模型有着紧密的联系。

模型绘景 #

再次回顾,DDPM建模的是如下变换流程:
(1)x=x0x1x2xT1xT=z
其中,正向就是将样本数据x逐渐变为随机噪声z的过程,反向就是将随机噪声z逐渐变为样本数据x的过程,反向过程就是我们希望得到的“生成模型”。

正向过程很简单,每一步是
(2)xt=αtxt1+βtεt,εtN(0,I)
或者写成p(xt|xt1)=N(xt;αtxt1,βt2I)。在约束αt2+βt2=1之下,我们有
(3)xt=αtxt1+βtεt=αt(αt1xt2+βt1εt1)+βtεt==(αtα1)x0+(αtα2)β1ε1+(αtα3)β2ε2++αtβt1εt1+βtεtN(0,(1αt2α12)I)
从而可以求出p(xt|x0)=N(xt;α¯tx0,β¯t2I),其中α¯t=α1αt,而β¯t=1α¯t2

DDPM要做的事情,就是从上述信息中求出反向过程所需要的p(xt1|xt),这样我们就能实现从任意一个xT=z出发,逐步采样出xT1,xT2,,x1,最后得到随机生成的样本数据x0=x

请贝叶斯 #

下面我们请出伟大的贝叶斯定理。事实上,直接根据贝叶斯定理我们有
(4)p(xt1|xt)=p(xt|xt1)p(xt1)p(xt)
然而,我们并不知道p(xt1),p(xt)的表达式,所以此路不通。但我们可以退而求其次,在给定x0的条件下使用贝叶斯定理:
(5)p(xt1|xt,x0)=p(xt|xt1)p(xt1|x0)p(xt|x0)
这样修改自然是因为p(xt|xt1),p(xt1|x0),p(xt|x0)都是已知的,所以上式是可计算的,代入各自的表达式得到:
(6)p(xt1|xt,x0)=N(xt1;αtβ¯t12β¯t2xt+α¯t1βt2β¯t2x0,β¯t12βt2β¯t2I)

推导:上式的推导过程并不难,就是常规的展开整理而已,当然我们也可以找点技巧加快计算。首先,代入各自的表达式,可以发现指数部分除掉1/2因子外,结果是:
(7)xtαtxt12βt2+xt1α¯t1x02β¯t12xtα¯tx02β¯t2
它关于xt1是二次的,因此最终的分布必然也是正态分布,我们只需要求出其均值和协方差。不难看出,展开式中xt12项的系数是
(8)αt2βt2+1β¯t12=αt2β¯t12+βt2β¯t12βt2=αt2(1α¯t12)+βt2β¯t12βt2=1α¯t2β¯t12βt2=β¯t2β¯t12βt2
所以整理好的结果必然是β¯t2β¯t12βt2xt1μ~(xt,x0)2的形式,这意味着协方差矩阵是β¯t12βt2β¯t2I。另一边,把一次项系数拿出来是2(αtβt2xt+α¯t1β¯t12x0),除以2β¯t2β¯t12βt2后便可以得到
(9)μ~(xt,x0)=αtβ¯t12β¯t2xt+α¯t1βt2β¯t2x0
这就得到了p(xt1|xt,x0)的所有信息了,结果正是式(6)

去噪过程 #

现在我们得到了p(xt1|xt,x0),它有显式的解,但并非我们想要的最终答案,因为我们只想通过xt来预测xt1,而不能依赖x0x0是我们最终想要生成的结果。接下来,一个“异想天开”的想法是

如果我们能够通过xt来预测x0,那么不就可以消去p(xt1|xt,x0)中的x0,使得它只依赖于xt了吗?

说干就干,我们用μ¯(xt)来预估x0,损失函数为x0μ¯(xt)2。训练完成后,我们就认为
(10)p(xt1|xt)p(xt1|xt,x0=μ¯(xt))=N(xt1;αtβ¯t12β¯t2xt+α¯t1βt2β¯t2μ¯(xt),β¯t12βt2β¯t2I)
x0μ¯(xt)2中,x0代表原始数据,xt代表带噪数据,所以这实际上在训练一个去噪模型,这也就是DDPM的第一个“D”的含义(Denoising)。

具体来说,p(xt|x0)=N(xt;α¯tx0,β¯t2I)意味着xt=α¯tx0+β¯tε,εN(0,I),或者写成x0=1α¯t(xtβ¯tε),这启发我们将μ¯(xt)参数化为
(11)μ¯(xt)=1α¯t(xtβ¯tϵθ(xt,t))
此时损失函数变为
(12)x0μ¯(xt)2=β¯t2α¯t2εϵθ(α¯tx0+β¯tε,t)2
省去前面的系数,就得到DDPM原论文所用的损失函数了。可以发现,本文是直接得出了从xtx0的去噪过程,而不是像之前两篇文章那样,通过xtxt1的去噪过程再加上积分变换来推导,相比之下本文的推导可谓更加一步到位了。

另一边,我们将式(11)代入到式(10)中,化简得到
(13)p(xt1|xt)p(xt1|xt,x0=μ¯(xt))=N(xt1;1αt(xtβt2β¯tϵθ(xt,t)),β¯t12βt2β¯t2I)
这就是反向的采样过程所用的分布,连同采样过程所用的方差也一并确定下来了。至此,DDPM推导完毕~提示:出于推导的流畅性考虑,本文的ϵθ跟前两篇介绍不一样,反而跟DDPM原论文一致。)

推导:将式(11)代入到式(10)的主要化简难度就是计算
(14)αtβ¯t12β¯t2+α¯t1βt2α¯tβ¯t2=αtβ¯t12+βt2/αtβ¯t2=αt2(1α¯t12)+βt2αtβ¯t2=1α¯t2αtβ¯t2=1αt

预估修正 #

不知道读者有没有留意到一个有趣的地方:我们要做的事情,就是想将xT慢慢地变为x0,而我们在借用p(xt1|xt,x0)近似p(xt1|xt)时,却包含了“用μ¯(xt)来预估x0”这一步,要是能预估准的话,那就直接一步到位了,还需要逐步采样吗?

真实情况是,“用μ¯(xt)来预估x0”当然不会太准的,至少开始的相当多步内不会太准。它仅仅起到了一个前瞻性的预估作用,然后我们只用p(xt1|xt)来推进一小步,这就是很多数值算法中的“预估-修正”思想,即我们用一个粗糙的解往前推很多步,然后利用这个粗糙的结果将最终结果推进一小步,以此来逐步获得更为精细的解。

由此我们还可以联想到Hinton三年前提出的《Lookahead Optimizer: k steps forward, 1 step back》,它同样也包含了预估(k steps forward)和修正(1 step back)两部分,原论文将其诠释为“快(Fast)-慢(Slow)”权重的相互结合,快权重就是预估得到的结果,慢权重则是基于预估所做的修正结果。如果愿意,我们也可以用同样的方式去诠释DDPM的“预估-修正”过程~

遗留问题 #

最后,在使用贝叶斯定理一节中,我们说式(4)没法直接用的原因是p(xt1)p(xt)均不知道。因为根据定义,我们有
(15)p(xt)=p(xt|x0)p~(x0)dx0
其中p(xt|x0)是知道的,而数据分布p~(x0)无法提前预知,所以不能进行计算。不过,有两个特殊的例子,是可以直接将两者算出来的,这里我们也补充计算一下,其结果也正好是上一篇文章遗留的方差选取问题的答案。

第一个例子是整个数据集只有一个样本,不失一般性,假设该样本为0,此时p~(x0)为狄拉克分布δ(x0),可以直接算出p(xt)=p(xt|0)。继而代入式(4),可以发现结果正好是p(xt1|xt,x0)x0=0的特例,即
(16)p(xt1|xt)=p(xt1|xt,x0=0)=N(xt1;αtβ¯t12β¯t2xt,β¯t12βt2β¯t2I)
我们主要关心其方差为β¯t12βt2β¯t2,这便是采样方差的选择之一。

第二个例子是数据集服从标准正态分布,即p~(x0)=N(x0;0,I)。前面我们说了p(xt|x0)=N(xt;α¯tx0,β¯t2I)意味着xt=α¯tx0+β¯tε,εN(0,I),而此时根据假设还有x0N(0,I),所以由正态分布的叠加性,xt正好也服从标准正态分布。将标准正态分布的概率密度代入式(4)后,结果的指数部分除掉1/2因子外,结果是:
(17)xtαtxt12βt2+xt12xt2
跟推导p(xt1|xt,x0)的过程类似,可以得到上述指数对应于
(18)p(xt1|xt)=N(xt1;αtxt,βt2I)
我们同样主要关心其方差为βt2,这便是采样方差的另一个选择。

文章小结 #

本文分享了DDPM的一种颇有“推敲”味道的推导,它借助贝叶斯定理来直接推导反向的生成过程,相比之前的“拆楼-建楼”类比和变分推断理解更加一步到位。同时,它也更具启发性,跟接下来要介绍的DDIM有很密切的联系。

转载到请包括本文地址:https://spaces.ac.cn/archives/9164

更详细的转载事宜请参考:《科学空间FAQ》

posted @   jasonzhangxianrong  阅读(169)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· [翻译] 为什么 Tracebit 用 C# 开发
· Deepseek官网太卡,教你白嫖阿里云的Deepseek-R1满血版
· DeepSeek崛起:程序员“饭碗”被抢,还是职业进化新起点?
· 2分钟学会 DeepSeek API,竟然比官方更好用!
· .NET 使用 DeepSeek R1 开发智能 AI 客户端
点击右上角即可分享
微信分享提示