接着扩散模型 简述训练扩散模型过程中用到的损失函数形式。完整的观察数据\(x\)的对数似然如下:
\[\begin{aligned}
\mathrm{log}\ p(x)
&\geq \mathbb{E}_{q_{\phi}(z_{1:T}|z_0)} \mathrm{log} \frac{p(z_T)\prod_{t=0}^{T-1}p_{\theta}(z_t|z_{t+1})}{\prod_{t=0}^{T-1}q_{\phi}(z_{t+1}|z_t)} \\
&= \mathbb{E}_{q_{\phi}(z_{1}|z_0)} [\mathrm{log}\ p_{\theta}(z_0|z_1) ] - \mathbb{D}_{KL}(q_{\phi}(z_T|z_0)||p(z_T)) - \sum_{t=2}^{T} \mathbb{E}_{q_{\phi}(z_t|z_0)} [ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ]
\end{aligned}
\tag {1}
\]
其中,\(q_{\phi}(z_{t-1}|z_t,z_0)\)为了便于计算,已经近似为高斯分布
\[\mathcal N(\mu_q(z_t,z_0), \Sigma_q(t)) \tag {2}\]
\[\mu_q(z_t, z_0) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_0 }{ 1 - \bar {\alpha}_t^2 }
\tag {3}
\]
\[\Sigma_q(t) = \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }I
\tag {4}
\]
形式一
为了使得去噪过程\(p_{\theta}(z_{t-1}|z_t)\)和“真实”的\(q_{\phi}(z_{t-1}|z_t,z_0)\)尽可能接近,因此也可以将\(p_{\theta}(z_{t-1}|z_t)\)建模为一个高斯分布。又由于所有的\(\alpha\)项在每个时间步都是固定的,因此可以将其方差设计与“真实”的\(q(z_{t-1}|z_t,z_0)\)的方差是一样的。且这个高斯分布与初始值\(z_0\)是无关的,因此可以将其均值设计为关于\(z_t, t\)的函数,即设为\(\mu_{\theta}(z_t,t)\).
考虑两个高斯分布的KL散度等于
\[\begin{aligned}
& \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(x;\mu_x,\Sigma_x) || \mathcal N(y;\mu_y,\Sigma_y)) \\
& = \frac{1}{2}[log\frac{|\Sigma_y|}{|\Sigma_x|} - d + tr(\Sigma_y^{-1}\Sigma_x) + (\mu_y-\mu_x)^T\Sigma_y^{-1}(\mu_y-\mu_x)]
\end{aligned}
\tag {5}
\]
应用到公式(1)中的第三项,因此有
\[\begin{aligned}
& \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\
& = \frac{1}{2\sigma_{q}^2(t)}||\mu_{\theta}(x_t,t) - \mu_{q}(x_t,x_0)||^2
\end{aligned}
\tag {6}
\]
其中\(\sigma_{q}^2(t)\)是公式(4)前的系数即\(\sigma_{q}^2(t)= \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }\)
由于\(\mu_{\theta}(x_t,t)\)也是\(x_t\)的函数,因此,可以参考公式(3)的形式,将进一步假设
\[\mu_{\theta}(x_t, t) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_{\theta}(z_t, t) }{ 1 - \bar {\alpha}_t^2 }
\tag {7}
\]
这样公式(6)进一步化简为
\[\begin{aligned}
& \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\
& = \frac{1}{2\sigma_{q}^2(t)} \frac{\bar{\alpha}_{t-1}^2( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)^2} ||z_{\theta}(z_t,t) - z_0||^2
\end{aligned}
\tag {8}
\]
至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出其原来的样本。最终最小化公式(1)中的第三项,等价于最小化关于时间步的期望,因此有
\[arg min \mathbb{E}_{t \sim U\{2,T\}} [ \mathbb{E}_{q_{\phi}(z_t|z_0)}[ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ] ]
\]
形式二
由
\[z_t = \bar \alpha_t z_0 + \sqrt{1-\bar {\alpha}_t^2} \bar \epsilon_t
\tag {9}
\]
可得
\[z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}_t}{\bar {\alpha}_t}
\tag {10}
\]
再代入公式(3)得
\[\mu_q(x_t,x_0) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \bar \epsilon_t
\tag{11}
\]
参考形式一中的假设方式,可以假设
\[\mu_{\theta}(x_t,t) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \epsilon_{\theta}(z_t, t)
\tag{12}
\]
再代入公式(6)可以得到
\[\begin{aligned}
& \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\
& = \frac{1}{2\sigma_{q}^2(t)} \frac{( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)\alpha_t^2} ||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2
\end{aligned}
\tag {12}
\]
至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出按照公式(10)添加的原始噪音。
形式三
由公式(8)和公式(12)可以得到
\[||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2 = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2} ||z_{\theta}(z_t,t) - z_0||^2
\tag{13}
\]
由于\(\bar {\alpha_t}, \sqrt{1-\bar {\alpha_t}^2}\) 分别是\(t\)时间步的加噪信号公式(9)中的原始信号和噪音信号系数,因此将信噪比SNR(t)定义为系数平方之比,即
\[SNR(t) = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2}
\tag {14}
\]
这个信噪比在时间步初期其值较大,代表真实信号占比多噪音占比少;在时间步后期其值较小,代表真实信号占比少噪音占比多。因为推理过程是完全从高斯分布随机取样,为了保证推理与训练保持一致,训练过程采取特定的\(\bar {\alpha}_t\)使得T步得到的是完全噪音,不包含任何原始信号。此时信噪比是0.
当预测发送在信噪比接近0(\(\bar \alpha_t \to 0\))时,模型原始预测是噪音\(\bar \epsilon\),因此根据公式(10)预估对应的原始信号
\[\bar z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}}{\bar {\alpha}_t}
\]
这样网络预测的微小差异就会被放大很多倍,因此在论文[3]模型蒸馏过程,这就不是一个稳定的设计。为了避免这个问题,作者提出了3种解决办法。
- 直接预测\(z\),而非噪音\(\epsilon\)
- 同时预测\(z, \epsilon\),通过两个独立的输出通道\(z, \epsilon\)。由于根据公式(10)可以再由\(\epsilon\)再推断出\(z^{'}\),然后可以根据\(\bar \alpha_t^2, 1-\bar \alpha_t^2\)对这两个值进行差值。
- 预测混合体 \(v=\alpha_t\epsilon - \sqrt{1-\alpha_t^2}z\)
参考
[1]. https://www.cnblogs.com/wolfling/p/17938102
[2]. Understanding Diffusion Models: A Unified Perspective
[3]. Progressive Distillation for Fast Sampling of Diffusion Models