很多次翻看DDPM,始终不太能理解论文中提到的\(\text{Variational Inference}\)到底是如何在这个工作中起到作用。近期无意间又刷到徐亦达老师早些年录制的理论视频,没想到其中也有介绍这部分的内容。很推荐老师录制的视频课程,把每一步都讲解得很仔细。
Background
本文记录一下个人对开头问题的思考,即DDPM
是如何使用\(\text{Variational Inference}\)进行优化的?
整个系列有相对完整的公式推导,若正文中有涉及到的省略部分,皆额外整理在Part4,并会在正文中会指明具体位置。
对于生成任务来说,希望从给定数据中学习到的是数据的潜在信息。比如图片生成,在给定一些图片后,模型学习到的是“正常图片长什么样子”,如:
- 一张包含手机正面的图片会有【手机屏幕】;
- 一张包含猫咪的图片会有人们观察到的猫咪模样;
- ...
对于图片中每个像素点和附近的像素点,进行“合理”布局,才能生成“符合人们认知的图片”。
图片生成能像常见的机器学习任务如分类任务、回归任务,能基于maximize likelihood
的形式来训练么?
结论是很难,先回顾如何做maximum likelihood
。给定一批数据,首先需要假定数据服从的分布,接着写出似然函数,之后直接通过解析解的形式或是梯度下降的形式,求出分布。
问题就出在假定分布这一步,没有人知道图片客观上服从的分布\(p(x)\)。在深度学习时代,大家也常常会基于神经网络模型直接模拟数据分布,借助Softmax
的思路求解概率,但是底部的正则项依然Intractable
。(宋飏提出的Score based generative model
尝试解决这一问题)
基于此,我们尝试借助变分推断\(\text{(Variational Inference)}\)引入其它分布,将原本难以优化的问题变为可优化问题。
ELBO
首先,抛开上述提到的所有背景,单纯研究一下\(p(x)\),看看能得到什么有意思的结论。
a. 基于贝叶斯定理,引入新的随机变量\(z\):\(p(x) = \frac{p(x, z)}{p(z\mid x)}\);
b. 对于两边同时取\(\ln\),等式依然成立,因此有:\(\ln{p(x)} = \ln{\frac{p(x, z)}{p(z \mid x)}}\);
c. 右边分子分母同乘以\(q(z)\):$$\ln{p(x)} = \ln{\frac{p(x, z) * q(z)}{p(z \mid x) * q(z)}} = \ln{\left(\frac{p(x, z)}{q(z)} * \frac{q(z)}{p(z \mid x)}\right)} = \ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}}$$
d. 再次,对于上式左右两边求关于\(q(z)\)的期望,等式依然成立:
\[\begin{equation}
\begin{aligned}
&\mathbb{E}_{z\sim q(z)}{[\ln{p(x)}]} = \mathbb{E}_{z\sim q(z)}{(\ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}})} \\
\iff & \int_z q(z)\ln{p(x)}dz = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz \\
\iff & \ln{p(x)} = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz
\end{aligned}
\end{equation}
\]
一系列变换后,\((1)\)式是最后的推导结果,等式右边由两个项组成。第二个项\(\int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz\),叫做KL散度,它被用来衡量两个分布之间的“距离”,性质是值不小于0。
这样一来,通过\((1)\)可以得到不等式\((2)\):
\[\begin{equation}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz
\end{equation}
\]
\((1)\)式右边的第一项,同时也是\((2)\)式的右边项,被学者们叫做\(\text{ELBO(Evidence Lower Bound)}\)。
Objective Function
上述推导的\((2)\)式可以被视作“定理”一般的存在,即对于某个分布的对数形式,总可以找到它的下界。
那\((2)\)式可以用来做什么?在Background
中提到,图片生成任务中的\(p(x)\)想要对它做maximum likelihood
根本无法做起。目标依然是最大化\(p(x)\),但有了\((2)\)式,求解的目标可以转移到最大化它的下界\(\text{ELBO}\)。
沿着这个思路,尝试对DDPM
中reverse process
待优化的目标\(\log p_\theta\left(\mathbf{x}_0\right)\)进行改写。将\((2)\)式从上方复制下来,并将不等式右边写成期望的形式。
\[\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz = \mathbb{E}_{z \sim q(z)}\left[\ln{\frac{p(x,z)}{q(z)}}\right]
\]
由\((2)\)可得\((3)\)式:
\[\begin{equation}
\begin{aligned}
& \log{p_\theta(\mathbf{x}_0)} \geq \int_z q(z)\ln{\frac{p_\theta(\mathbf{x}_0, z)}{q(z)}}dz \\
\overset{\text{增加负号}}{\iff} & -\log{p_\theta(\mathbf{x}_0)} \leq -\int_z q(z)\log{\frac{p_\theta(\mathbf{x}_0, z)}{q(z)}}dz \\
\overset{定义z := \mathbf{x}_{1: T}}{\iff} & -\log{p_\theta(\mathbf{x}_0)} \leq -\int_{\mathbf{x}_{1:T}} q(x_{1:T})\log{\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T})}}d\mathbf{x}_{1: T} = \int_{\mathbf{x}_{1:T}} q(x_{1:T})\log{\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}}d\mathbf{x}_{1: T} \\
\overset{期望}{\iff} & \mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right] \leq \mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]
\end{aligned}
\end{equation}
\]
其中,\(p_\theta\left(\mathbf{x}_0\right)\)便是模型要优化的最终目标即图像的分布,\(\theta\)是模型参数,\(\mathbf{x}_0\)是图片。
很明显,\(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\)对应\((2)\)中引入的额外分布\(q(z)\),\(z\)是隐变量\((\text{latent})\)。对于diffusion models
,\(\mathbf{x}_0\)依次加噪后的\(\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\)就可以看作隐变量,记作\(z := \{\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\}\);
\(p_\theta\left(\mathbf{x}_{0: T}\right) = p_\theta\left(\mathbf{x}_{0}, \mathbf{x}_{1}, \ldots, \mathbf{x}_{T}\right)\),是关于\(\mathbf{x}_0, z\)的联合概率分布,因为选用马尔可夫链建模,那么依据马尔可夫链的性质,有:
\[\begin{equation}
\begin{aligned}
q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)&:=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) \\
p_\theta\left(\mathbf{x}_{0: T}\right)&:=p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)
\end{aligned}
\end{equation}
\]
将\((4)\)带入\((3)\)不等式右边的第一项,得到\(L\):
\[\begin{aligned}
&\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\
=&\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\
=&\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] := L
\end{aligned}
\]
到目前为止,经过多轮的变换得到了\(L\)。\(L\)是一个替代的优化目标:
\[\mathop{\arg\min}{(L)} \iff \mathop{\arg\min}{(-\ln{p}_{\theta}(\mathbf{x}_0))} \iff \mathop{\arg\max}{(\ln{p}_{\theta}(\mathbf{x}_0))}
\]
接下来,论文中对\(L\)进行改写,摘录论文\(\text{Appendix A}\)如下所示:
\[\begin{equation}\begin{aligned} L & =\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right] \\
&=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \left[\frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\right]-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right]
\end{aligned}
\end{equation}\]
倒数两步的变换发生在第二项,具体依据为:
\[\begin{aligned}
q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1}\right)}{q\left(\mathbf{x}_{t-1}\right)} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) *q(\mathbf{x}_{0})}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) * q(\mathbf{x}_{0})} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) }{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}
\end{aligned}
\Rightarrow
\begin{aligned}
&\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) } \cdot {q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}
\end{aligned}
\]
接着对\((5)\)进行改写得到最终形式\((6)\):
\[\begin{equation}
\begin{aligned}
L &=\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\
&=\mathbb{E}_q[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}]
\end{aligned}
\end{equation}\]
Summary
论文中定义马尔可夫链相邻状态的转变是服从高斯分布的,故\((6)\)式最起码是个可以优化的目标函数。而实际上,\((6)\)式还会进一步被改写,得到更精简的\(L_{simple}\)。这部分内容在下一部分Part3: Dive into DDPM中有更加详细的介绍。
DDPM
是应用\(\text{variational inference}\)进行优化求解的典型例子,很值得学习。
Reference