DDMP中的损失函数

接着扩散模型 简述训练扩散模型过程中用到的损失函数形式。完整的观察数据\(x\)的对数似然如下:

\[\begin{aligned} \mathrm{log}\ p(x) &\geq \mathbb{E}_{q_{\phi}(z_{1:T}|z_0)} \mathrm{log} \frac{p(z_T)\prod_{t=0}^{T-1}p_{\theta}(z_t|z_{t+1})}{\prod_{t=0}^{T-1}q_{\phi}(z_{t+1}|z_t)} \\ &= \mathbb{E}_{q_{\phi}(z_{1}|z_0)} [\mathrm{log}\ p_{\theta}(z_0|z_1) ] - \mathbb{D}_{KL}(q_{\phi}(z_T|z_0)||p(z_T)) - \sum_{t=2}^{T} \mathbb{E}_{q_{\phi}(z_t|z_0)} [ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ] \end{aligned} \tag {1} \]

其中,\(q_{\phi}(z_{t-1}|z_t,z_0)\)为了便于计算,已经近似为高斯分布

\[\mathcal N(\mu_q(z_t,z_0), \Sigma_q(t)) \tag {2}\]

\[\mu_q(z_t, z_0) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_0 }{ 1 - \bar {\alpha}_t^2 } \tag {3} \]

\[\Sigma_q(t) = \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }I \tag {4} \]

形式一

为了使得去噪过程\(p_{\theta}(z_{t-1}|z_t)\)和“真实”的\(q_{\phi}(z_{t-1}|z_t,z_0)\)尽可能接近,因此也可以将\(p_{\theta}(z_{t-1}|z_t)\)建模为一个高斯分布。又由于所有的\(\alpha\)项在每个时间步都是固定的,因此可以将其方差设计与“真实”的\(q(z_{t-1}|z_t,z_0)\)的方差是一样的。且这个高斯分布与初始值\(z_0\)是无关的,因此可以将其均值设计为关于\(z_t, t\)的函数,即设为\(\mu_{\theta}(z_t,t)\).

  考虑两个高斯分布的KL散度等于

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(x;\mu_x,\Sigma_x) || \mathcal N(y;\mu_y,\Sigma_y)) \\ & = \frac{1}{2}[log\frac{|\Sigma_y|}{|\Sigma_x|} - d + tr(\Sigma_y^{-1}\Sigma_x) + (\mu_y-\mu_x)^T\Sigma_y^{-1}(\mu_y-\mu_x)] \end{aligned} \tag {5} \]

应用到公式(1)中的第三项,因此有

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)}||\mu_{\theta}(x_t,t) - \mu_{q}(x_t,x_0)||^2 \end{aligned} \tag {6} \]

其中\(\sigma_{q}^2(t)\)是公式(4)前的系数即\(\sigma_{q}^2(t)= \frac{ (1 - \alpha_t^2) (1 - \bar{\alpha}_{t-1}^2) }{ 1 - \bar{\alpha}_{t}^2 }\)

由于\(\mu_{\theta}(x_t,t)\)也是\(x_t\)的函数,因此,可以参考公式(3)的形式,将进一步假设

\[\mu_{\theta}(x_t, t) = \frac{\alpha_t(1-\bar{\alpha}_{t-1}^2) z_t + \bar{\alpha}_{t-1}( 1 - \alpha_t^2 ) z_{\theta}(z_t, t) }{ 1 - \bar {\alpha}_t^2 } \tag {7} \]

这样公式(6)进一步化简为

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)} \frac{\bar{\alpha}_{t-1}^2( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)^2} ||z_{\theta}(z_t,t) - z_0||^2 \end{aligned} \tag {8} \]

至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出其原来的样本。最终最小化公式(1)中的第三项,等价于最小化关于时间步的期望,因此有

\[arg min \mathbb{E}_{t \sim U\{2,T\}} [ \mathbb{E}_{q_{\phi}(z_t|z_0)}[ \mathbb{D}_{KL}(q_{\phi}(z_{t-1}|z_t,z_0)||p_{\theta}(z_{t-1}|z_t)) ] ] \]

形式二

\[z_t = \bar \alpha_t z_0 + \sqrt{1-\bar {\alpha}_t^2} \bar \epsilon_t \tag {9} \]

可得

\[z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}_t}{\bar {\alpha}_t} \tag {10} \]

再代入公式(3)得

\[\mu_q(x_t,x_0) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \bar \epsilon_t \tag{11} \]

参考形式一中的假设方式,可以假设

\[\mu_{\theta}(x_t,t) = \frac{1}{\alpha_t}x_t - \frac{1-\alpha_t^2}{\sqrt{1-\bar{\alpha}_t^2} \alpha_t} \epsilon_{\theta}(z_t, t) \tag{12} \]

再代入公式(6)可以得到

\[\begin{aligned} & \ \ \ \ \mathbb{D}_{KL} ( \mathcal N(z_{t-1};\mu_q(z_t,z_0),\Sigma_q(t)) || \mathcal N(z_{t-1};\mu_{\theta}(z_t,t),\Sigma_q(t))) \\ & = \frac{1}{2\sigma_{q}^2(t)} \frac{( 1 - \alpha_t^2 )^2}{ (1 - \bar {\alpha}_t^2)\alpha_t^2} ||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2 \end{aligned} \tag {12} \]

至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出按照公式(10)添加的原始噪音。

形式三

由公式(8)和公式(12)可以得到

\[||\epsilon_{\theta}(z_t,t) - \epsilon_t||^2 = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2} ||z_{\theta}(z_t,t) - z_0||^2 \tag{13} \]

由于\(\bar {\alpha_t}, \sqrt{1-\bar {\alpha_t}^2}\) 分别是\(t\)时间步的加噪信号公式(9)中的原始信号和噪音信号系数,因此将信噪比SNR(t)定义为系数平方之比,即

\[SNR(t) = \frac{\bar{\alpha_t}^2}{1-\bar{\alpha_t}^2} \tag {14} \]

这个信噪比在时间步初期其值较大,代表真实信号占比多噪音占比少;在时间步后期其值较小,代表真实信号占比少噪音占比多。因为推理过程是完全从高斯分布随机取样,为了保证推理与训练保持一致,训练过程采取特定的\(\bar {\alpha}_t\)使得T步得到的是完全噪音,不包含任何原始信号。此时信噪比是0.

当预测发送在信噪比接近0(\(\bar \alpha_t \to 0\))时,模型原始预测是噪音\(\bar \epsilon\),因此根据公式(10)预估对应的原始信号

\[\bar z_0 = \frac{z_t - \sqrt{(1-\bar {\alpha}_t^2)} \bar {\epsilon}}{\bar {\alpha}_t} \]

这样网络预测的微小差异就会被放大很多倍,因此在论文[3]模型蒸馏过程,这就不是一个稳定的设计。为了避免这个问题,作者提出了3种解决办法。

  • 直接预测\(z\),而非噪音\(\epsilon\)
  • 同时预测\(z, \epsilon\),通过两个独立的输出通道\(z, \epsilon\)。由于根据公式(10)可以再由\(\epsilon\)再推断出\(z^{'}\),然后可以根据\(\bar \alpha_t^2, 1-\bar \alpha_t^2\)对这两个值进行差值。
  • 预测混合体 \(v=\alpha_t\epsilon - \sqrt{1-\alpha_t^2}z\)

参考

[1]. https://www.cnblogs.com/wolfling/p/17938102
[2]. Understanding Diffusion Models: A Unified Perspective
[3]. Progressive Distillation for Fast Sampling of Diffusion Models

posted @ 2024-06-16 18:54  星辰大海,绿色星球  阅读(145)  评论(0编辑  收藏  举报