从DDPM到DDIM (一) 极大似然估计与证据下界
从DDPM到DDIM (一) 极大似然估计与证据下界
现在网络上关于DDPM和DDIM的讲解有很多,但无论什么样的讲解,都不如自己推到一遍来的痛快。笔者希望就这篇文章,从头到尾对扩散模型做一次完整的推导。本文的很多部分都参考了 Calvin Luo[1] 和 Stanley Chan[2] 写的经典教程。也推荐大家取阅读学习。
DDPM[3]是一个双向马尔可夫模型,其分为扩散过程和采样过程。
扩散过程是对于图片不断加噪的过程,每一步添加少量的高斯噪声,直到图像完全变为纯高斯噪声。为什么逐步添加小的高斯噪声,而不是一步到位,直接添加很强的噪声呢?这一点我们留到之后来探讨。
采样过程则相反,是对纯高斯噪声图像不断去噪,逐步恢复原始图像的过程。
下图展示了DDPM原文中的马尔可夫模型。
其中
扩散模型首先需要大量的图片进行训练,训练的目标就是估计图像的概率分布。训练完毕后,生成图像的过程就是在计算出的概率分布中采样。因此生成模型一般都有训练算法和采样算法,VAE、GAN、diffusion,还有如今大火的大预言模型(LLM)都不例外。本文讨论的DDPM和DDIM在训练方法上是一样的,只是DDIM在采样方法上与前者有所不同[4]。
估计生成样本的概率分布的最经典的方法就是极大似然估计,我们从极大似然估计开始。
1、从极大似然估计开始
首先简单回顾一下概率论中的一些基本概念,边缘概率密度、联合概率密度、概率乘法公式和马尔可夫链,最后回顾一个强大的数学工具:Jenson 不等式。对这些熟悉的同学可以不需要看1.1节。
1.1、概念回顾
边缘概率密度和联合概率密度: 大家可能还记得概率论中的边缘概率密度,忘了也不要紧,我们简单回顾一下。对于二维随机变量
概率乘法公式: 对于联合概率
概率乘法公式可以用条件概率的定义和数学归纳法证明。
马尔可夫链定义: 随机过程
其中
Jenson 不等式。Jenson 不等式有多种形式,我们这里采用其积分形式:
若
则有:
更进一步地,若
关于 Jenson 不等式的证明,用凸函数的定义证明即可。网上有很多,这里不再赘述。
1.2、概率分布表示
生成模型的主要目标是估计需要生成的数据的概率分布。这里就是
这里
显然,这个积分很不好求。Sohl-Dickstein等人在2015年的扩散模型的开山之作[5]中,采用的是这个方法:
Sohl-Dickstein等人借鉴的是统计物理中的技巧:退火重要性采样(annealed importance sampling) 和 Jarzynski equality。这两个就涉及到笔者的知识盲区了,感兴趣的同学可以自行找相关资料学习。(果然数学物理基础不牢就搞不好科研~)。
这里有的同学可能会有疑问,为什么用分子分母都为
这里自然就引出了问题,这么一堆随机变量的联合概率密度,我们还是不知道啊,
利用概率乘法公式,有:
我们这里是单独把
(3)式这样表示明显不如(2)式,因为我们最初就是要求
因为扩散模型是马尔可夫链,某一时刻的随机变量只和前一个时刻有关,所以:
于是有:
文章一开始说到,在扩散模型的采样过程中,单步转移概率是不知道的,需要用神经网络来拟合,所以我们给采样过程的单步转移概率都加一个下标
类似地,我们来计算
于是得到了以
数学推导的一个很重要的事情就是分清楚哪些是已知量,哪些是未知量。
1.3、极大似然估计
既然我们知道了
计算(1)式的下界一般有两种办法。分别是Jenson不等式法和KL散度方法。下面我们分别给出两种方法的推导。
Jenson不等式法。在进行极大似然估计的时候,一般会对概率分布取对数,于是我们对(1)式取对数可得:
KL散度方法。当然,我们也可以不采用Jenson不等式,利用KL散度的非负性,同样也可以得出证据下界。将证据下界中的数学期望展开,写为积分形式为:
另外,我们定义一个KL散度:
下面我们将验证:
具体地,有:
因此,(6)式成立。由于
个人还是更喜欢Jenson不等式法,因为此方法的思路一气呵成;而KL散度法像是先知道最终的答案,然后取验证答案的正确性。而且KL的非负性也是可以用Jenson不等式证明的,所以二者在数学原理上本质是一样的。KL散度法有一个优势,就是能让我们知道
下面,我们的方向就是逐步简化证据下界,直到简化为我们编程可实现的形式。
2、简化证据下界
对证据下界的化简,需要用到三个我们之前推导出来的表达式。为了方便阅读,我们把(4)式,(5)式,还有证据下界重写到这里。
下面我们将(4)式和(5)式代入到(7)式中,有:
蓝色部分的操作,是将分子分母的表示的随机变量概率保持一致,对同一个随机变量的概率分布的描述才具备可比性。而且,我们希望分子分母的连乘号下标保持一致,这样才能进一步化简。下面我们继续:
上式第一项,第二项,第三项分别被称为 重建项(Reconstruction Term)、先验匹配项(Prior Matching Term)、一致项(Consistency Term)。
- 重建项。顾名思义,这是对最后构图的预测概率。给定预测的最后一个隐变量
,预测生成的图像 的对数概率。 - 先验匹配项。 这一项描述的是扩散过程的最后一步生成的高斯噪声与纯高斯噪声的相似度,因为这一项并没有神经网络参数,所以不需要优化,后续网络训练的时候可以将这一项舍去。
- 一致项。这一项描述的是采样过程的单步转移概率
和扩散过程的单步转移概率 的距离。由于 是服从高斯分布的(加噪过程自己定义的),所以我们希望采样过程的单步转移概率 也服从高斯分布,这样才能使得二者的KL散度更加接近。我们之后会看到,最小化二者的KL散度等价于最大似然估计。
到这里我们通过观察可以发现,乘上
可能有的同学搞不清楚积分微元是哪个变量,如果不知道的话就把所有的随机变量都算为积分微元。如果哪个微元不需要的话,是可以被积分积掉的。注意到,
类似地,我们看 先验匹配项 和 一致项。先验匹配项中除了
一致项也用类似的操作:
下面我们继续化简,化简乘KL散度的形式。因为两个高斯分布的KL散度可以写成二范数Loss的形式,这是我们编程可实现的。我们先给出KL散度的定义。设两个概率分布和
注意,KL散度没有对称性,即
下面,我们以 一致项 中的其中一项为例子,来写成KL散度形式:
上式红色的部分,是参考的开篇提到的那两个教程。笔者自己推了一下,并没有得出相应的结果。
如果有好的解释,欢迎讨论。不过这里是否严格并不重要,之后我们会解释,事实上我们使用的是另外一种推导方式。
类似地,先验匹配项 也可以用类似的方法表示成KL散度的形式:
这里红色的部分,我们可以详细验证一下:
没有什么问题。
下面我们整理一下结果。我们简化的证据下界为:
我们看第三项,也就是 一致项。我们发现两个概率分布
如何优化证据下界呢。我们放到下一篇文章中来讲:
从DDPM到DDIM (二) 前向过程与反向过程的概率分布。
Luo C. Understanding diffusion models: A unified perspective[J]. arXiv preprint arXiv:2208.11970, 2022. ↩︎
Chan S H. Tutorial on Diffusion Models for Imaging and Vision[J]. arXiv preprint arXiv:2403.18103, 2024. ↩︎
Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models[J]. Advances in neural information processing systems, 2020, 33: 6840-6851. ↩︎
Song J, Meng C, Ermon S. Denoising diffusion implicit models[J]. arXiv preprint arXiv:2010.02502, 2020. ↩︎
Sohl-Dickstein J, Weiss E, Maheswaranathan N, et al. Deep unsupervised learning using nonequilibrium thermodynamics[C]//International conference on machine learning. PMLR, 2015: 2256-2265. ↩︎
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· Vue3状态管理终极指南:Pinia保姆级教程