Denoising Diffusion Probabilistic Models (DDPM)

2024/2/28

Paper link: Denoising Diffusion Probabilistic Models (neurips.cc)

This post gives an overview of DDPM; the detailed derivations will appear in a later post.

1. Introduction

Deep generative models of all kinds have demonstrated high-quality samples across many data modalities. Generative adversarial networks (GANs), autoregressive models, flows, and variational autoencoders (VAEs) have synthesized striking image and audio samples. There has also been remarkable progress in energy-based modeling and score matching, producing images comparable to those of GANs.

A diffusion probabilistic model is a parameterized Markov chain trained with variational inference to produce samples matching the data after finite time. Transitions of this chain are learned to reverse a diffusion process, which is itself a Markov chain that gradually adds noise to the data in the direction opposite to sampling, until the signal is destroyed. When the diffusion consists of small amounts of Gaussian noise, it suffices to set the sampling-chain transitions to conditional Gaussians too, allowing for a particularly simple neural network parameterization.

Variational Inference: a statistical method for estimating the parameters of probabilistic models. It approximates the true posterior distribution by optimizing an objective, typically the discrepancy between the true posterior and a tractable distribution (the variational distribution).

Flows: flow models are generative models that map data through a sequence of invertible transformations (the "flow") to a simple base distribution, and back again to generate new samples. Because every transformation is invertible, the exact data likelihood can be computed.

Energy-based Modeling: a modeling approach built around an energy function that measures how poorly an input matches a given configuration. In the context of image generation, energy-based models can be used to evaluate and improve the quality of generated images.

Score Matching: a technique for training generative models, particularly in probability density estimation. It matches the model's score function to that of the true data distribution, thereby improving the quality of generated samples.

2. Model Details

Diffusion is the physical process by which particles move from regions of high concentration to regions of low concentration. Diffusion models, inspired by non-equilibrium thermodynamics, simulate this process by adding Gaussian noise to images, and then generate images from random noise through the reverse process.

2.1 Forward Noising

We randomly sample a noise image of the same size as the input image, with every channel value following a normal distribution. The generated noise is then mixed into the original image over \(T\) steps, where each step follows the formula:

\[\begin{aligned}\sqrt{\beta}\times\epsilon+\sqrt{1-\beta}\times x\end{aligned} \]

where \(x\) is the original image, \(\epsilon\) is Gaussian noise, and \(\beta\) is a number in \([0.0, 1.0]\) that determines the coefficients in front of \(x\) and \(\epsilon\).
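In code, one mixing step might look like this (a minimal sketch for illustration; x is assumed to be a normalized image tensor and beta a scalar, neither taken from the paper's implementation):

import torch

def mix_step(x: torch.Tensor, beta: float) -> torch.Tensor:
    # Sample Gaussian noise with the same shape as the image
    eps = torch.randn_like(x)
    # Mix: sqrt(beta) * noise + sqrt(1 - beta) * image
    return (beta ** 0.5) * eps + ((1 - beta) ** 0.5) * x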

Plugging \(x_0\) into the formula gives \(x_1\):

\[x_1=\sqrt{\beta_1}\times\epsilon_1+\sqrt{1-\beta_1}\times x_0 \]

Plugging in \(x_1\) gives \(x_2\):

\[x_2=\sqrt{\beta_2}\times\epsilon_2+\sqrt{1-\beta_2}\times x_1 \]

...

Continuing in this way, we obtain the relation between consecutive steps:

\[x_t=\sqrt{\beta_t}\times\epsilon_t+\sqrt{1-\beta_t}\times x_{t-1} \]

where each \(\epsilon_t\) is freshly sampled from a standard normal distribution, and \(\beta_t\) increases monotonically with \(t\), starting from a value close to 0 (in the paper, from \(10^{-4}\) to \(0.02\)): \(0<\beta_1<\beta_2<\cdots<\beta_{t-1}<\beta_t<1\).

That is:

\(q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\mathbf{x}_{t-1},\beta_t\mathbf{I})\)

As the step \(t\) increases, the features of the original sample \(x_0\) become indistinguishable. When \(T\to\infty\), \(\mathbf{x}_T\) is equivalent to an isotropic Gaussian distribution.

This process has a convenient property: using the reparameterization trick (see VAE), we can sample \(x_t\) at any arbitrary timestep \(t\) directly.

To simplify the later derivations, we introduce a new variable \(\alpha_t=1-\beta_t\); the formula above becomes:

\[x_t=\sqrt{1-\alpha_t}\times\epsilon_t+\sqrt{\alpha_t}\times x_{t-1} \]

The next question is whether the formula lets us obtain \(x_t\) directly from \(x_0\). Starting from

\[x_t=\sqrt{1-\alpha_t}\times\epsilon_t+\sqrt{\alpha_t}\times x_{t-1} \]

and unrolling it backwards, we obtain:

\[\begin{aligned} \mathbf{x}_{t}& =\sqrt{\alpha_t}\mathbf{x}_{t-1}+\sqrt{1-\alpha_t}\boldsymbol{\epsilon}_{t-1} & ;\text{where }\boldsymbol{\epsilon}_{t-1},\boldsymbol{\epsilon}_{t-2},\cdots\sim\mathcal{N}(\mathbf{0},\mathbf{I}) \\ &=\sqrt{\alpha_t\alpha_{t-1}}\mathbf{x}_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\bar{\boldsymbol{\epsilon}}_{t-2}& ;\text{where }\bar{\boldsymbol{\epsilon}}_{t-2}\text{ merges two Gaussians }(*). \\ &=\ldots \\ &=\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon} \end{aligned}\]

where \(\bar{\alpha}_t=\prod_{i=1}^t\alpha_i\).

\((*)\) When we merge two Gaussians with different variances, \(\mathcal{N}(\mathbf{0},\sigma_1^2\mathbf{I})\) and \(\mathcal{N}(\mathbf{0},\sigma_2^2\mathbf{I})\), the new distribution is \(\mathcal{N}(\mathbf{0},(\sigma_1^2+\sigma_2^2)\mathbf{I})\). Here the merged standard deviation is \(\sqrt{(1-\alpha_t)+\alpha_t(1-\alpha_{t-1})}=\sqrt{1-\alpha_t\alpha_{t-1}}\).
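A quick numerical sanity check of this merging rule (a sketch with hypothetical schedule values alpha_t and alpha_prev, not taken from the paper):

import torch

alpha_t, alpha_prev = 0.98, 0.99  # hypothetical alpha_t and alpha_{t-1}
e1 = torch.randn(1_000_000)
e2 = torch.randn(1_000_000)
# sqrt(1 - alpha_t) * eps_{t-1} + sqrt(alpha_t * (1 - alpha_{t-1})) * eps_{t-2}
merged = (1 - alpha_t) ** 0.5 * e1 + (alpha_t * (1 - alpha_prev)) ** 0.5 * e2
print(merged.std())  # ~= sqrt(1 - alpha_t * alpha_prev) ~= 0.1726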

This derivation gives the closed-form formula:

\(\begin{aligned}x_t=\sqrt{1-\bar{\alpha}_t}\times\epsilon+\sqrt{\bar{\alpha}_t}\times x_0\end{aligned}\)

Generally, we can afford larger update steps as the sample becomes noisier, so

\(\begin{aligned}\beta_1<\beta_2<\cdots<\beta_T\end{aligned}\)

\(\bar{\alpha}_1>\cdots>\bar{\alpha}_T\)
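Putting this subsection together, here is a short sketch (assuming the paper's linear \(\beta\) schedule) of jumping directly from \(x_0\) to \(x_t\) with the closed-form formula:

import torch

T = 1000
beta = torch.linspace(0.0001, 0.02, T)   # the paper's linear schedule
alpha = 1 - beta
alpha_bar = torch.cumprod(alpha, dim=0)  # alpha_bar_t = prod of alpha_1..alpha_t

x0 = torch.randn(3, 64, 64)              # stand-in for a normalized image
t = 500
eps = torch.randn_like(x0)
# x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
xt = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps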

2.2 Reverse Process

The goal of the reverse process is to recover the original image from a noisy one. If we could reverse the process above and sample from \(q(\mathbf{x}_{t-1}|\mathbf{x}_t)\), we would be able to generate images from Gaussian noise. Since forward noising is a stochastic process, the reverse is stochastic as well: let \(P(x_{t-1}|x_t)\) denote the probability of the previous step \(x_{t-1}\) given \(x_t\). By Bayes' rule:

\(P(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t},x_0)=\frac{P(\boldsymbol{x}_t|\boldsymbol{x}_{t-1},x_0)P(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}{P(\boldsymbol{x}_t|\boldsymbol{x}_0)}\)

From the formulas:

\(\begin{gathered} x_t=\sqrt{1-\alpha_t}\times\epsilon_t+\sqrt{\alpha_t}\times x_{t-1} \\ x_t=\sqrt{1-\bar{\alpha}_t}\times\epsilon+\sqrt{\bar{\alpha}_t}\times x_0 \end{gathered}\)

we know that \(x_t\) follows the normal distributions \(N(\sqrt{\alpha_t}x_{t-1},1-\alpha_t)\) and \(N(\sqrt{\bar{\alpha}_t}x_0,1-\bar{\alpha}_t)\) respectively (because the noise \(\epsilon\) is Gaussian), and that \(x_{t-1}\) follows \(N(\sqrt{\bar{\alpha}_{t-1}}x_0,1-\bar{\alpha}_{t-1})\). The expression above can therefore be rewritten as:

\[\begin{aligned} q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) &=q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q(\mathbf{x}_t \mid \mathbf{x}_0)} \\ & \quad \left(q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0) \sim \mathcal{N}\left(\mathbf{x}_t ; \sqrt{\alpha_t} \mathbf{x}_{t-1},\left(1-\alpha_t\right) \mathbf{I}\right)\right) \\ & \quad \left(q(\mathbf{x}_{t-1} \mid \mathbf{x}_0) \sim \mathcal{N}\left(\mathbf{x}_{t-1} ; \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0,\left(1-\bar{\alpha}_{t-1}\right) \mathbf{I}\right)\right) \\ & \quad \left(q(\mathbf{x}_t \mid \mathbf{x}_0) \sim \mathcal{N}\left(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right)\right) \\ & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(\mathbf{x}_t-\sqrt{\alpha_t} \mathbf{x}_{t-1}\right)^2}{\beta_t}+\frac{\left(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\frac{\mathbf{x}_t^2-2 \sqrt{\alpha_t} \mathbf{x}_t \mathbf{x}_{t-1}+\alpha_t \mathbf{x}_{t-1}^2}{\beta_t}+\frac{\mathbf{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0 \mathbf{x}_{t-1}+\bar{\alpha}_{t-1} \mathbf{x}_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \mathbf{x}_{t-1}+C\left(\mathbf{x}_t, \mathbf{x}_0\right)\right)\right) \end{aligned}\]

where \(C(\mathbf{x}_t,\mathbf{x}_0)\) is some function that does not involve \(\mathbf{x}_{t-1}\); its details are omitted.

From this we can see that \(P(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t},x_0)\) follows \(\mathcal{N}\left(\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)}{1-\bar{\alpha}_t}\cdot\frac{x_t-\sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}},\;\frac{\beta_t(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\right)\), where the second argument is the variance.

Once we know \(\epsilon\), we can recover the image at the previous step. We therefore train a neural network to predict the noise that was added to the original image \(x_0\) to produce the current image.

In practice, \(x_T\) is simply a noise image following a standard normal distribution, so we can obtain \(x_T\) by sampling from a standard normal.

The reverse process removes the noise over \(T\) steps, starting from \(p(x_T)=\mathcal{N}(x_T;\mathbf{0},\mathbf{I})\):

\[\begin{aligned} \textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)& =\mathcal{N}\left(x_{t-1};\textcolor{lightgreen}{\mu_\theta}(x_t,t),{\Sigma_\theta(x_t,t)}\right) \\ \textcolor{lightgreen}{p_\theta}(x_{0:T})& =\textcolor{lightgreen}{p_\theta}(x_T)\prod_{t=1}^T\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t) \\ \textcolor{lightgreen}{p_\theta}(x_0)& =\int\textcolor{lightgreen}{p_\theta}(x_{0:T})dx_{1:T} \end{aligned}\]

where \(\textcolor{lightgreen}{\theta}\) denotes the parameters we train.

2.3 Loss

The paper optimizes the ELBO (from Jensen's inequality) on the negative log-likelihood:

\[\begin{gathered} \mathbb{E}[-\log \textcolor{lightgreen}{p_\theta}(x_0)] \leq\mathbb{E}_q[-\log\frac{\textcolor{lightgreen}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)}] \\ =L \end{gathered}\]

The loss can be rewritten as follows:

\[\begin{aligned} \text{L}& =\mathbb{E}_q[-\log\frac{\textcolor{lightgreen}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)}] \\ &=\mathbb{E}_q[-\log p(x_T)-\sum_{t=1}^T\log\frac{\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})}] \\ &=\mathbb{E}_q[-\log\frac{p(x_T)}{q(x_T|x_0)}-\sum_{t=2}^T\log\frac{\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}-\log\textcolor{lightgreen}{p_\theta}(x_0|x_1)] \\ &=\mathbb{E}_q[D_{KL}(q(x_T|x_0)||p(x_T))+\sum_{t=2}^TD_{KL}(q(x_{t-1}|x_t,x_0)||{\textcolor{lightgreen}{p_\theta}}(x_{t-1}|x_t))-\log \textcolor{lightgreen}{p_\theta}(x_0|x_1)] \end{aligned}\]

Since we keep \(\beta_1,\ldots,\beta_T\) constant, \(D_{KL}(q(x_T|x_0)\|p(x_T))\) is also a constant.

2.4 Computing \(D_{KL}(q(x_{t-1}|x_t,x_0)\|\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t))\)

Conditioned on the initial \(x_0\), the posterior of the forward process is:

\[\begin{aligned} q(x_{t-1}|x_t,x_0)& =\mathcal{N}\left(x_{t-1};\tilde{\mu}_t(x_t,x_0),\tilde{\beta}_t\mathbf{I}\right) \\ \tilde{\mu}_t(x_t,x_0)& =\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha_t}}x_0+\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha_t}}x_t \\ \tilde{\beta}_{t}& =\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha_t}}\beta_t \end{aligned}\]
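As a sketch, these posterior parameters can be computed directly (q_posterior is a hypothetical helper mirroring the formulas above, with alpha, alpha_bar, beta as 1-D schedule tensors):

import torch

def q_posterior(x0: torch.Tensor, xt: torch.Tensor, t: int,
                alpha: torch.Tensor, alpha_bar: torch.Tensor, beta: torch.Tensor):
    # alpha_bar_{t-1}, taken to be 1 when t = 0
    alpha_bar_prev = alpha_bar[t - 1] if t > 0 else torch.tensor(1.)
    # Coefficients of x0 and xt in mu_tilde
    coef0 = alpha_bar_prev.sqrt() * beta[t] / (1 - alpha_bar[t])
    coeft = alpha[t].sqrt() * (1 - alpha_bar_prev) / (1 - alpha_bar[t])
    mean = coef0 * x0 + coeft * xt
    # beta_tilde_t
    var = (1 - alpha_bar_prev) / (1 - alpha_bar[t]) * beta[t]
    return mean, var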

The paper sets \(\textcolor{lightgreen}{\Sigma_\theta}(x_t,t)=\sigma_t^2\mathbf{I}\), where \(\sigma_t^2\) is held constant at either \(\beta_t\) or \(\tilde{\beta}_t\).

Then,

\[\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\textcolor{lightgreen}{\mu_\theta}(x_t,t),\sigma_t^2\mathbf{I}) \]

For a given noise \(\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})\), using \(q(x_t|x_0)\):

\[\begin{aligned} x_t(x_0,\epsilon)& =\sqrt{\bar{\alpha_t}}x_0+\sqrt{1-\bar{\alpha_t}}\epsilon \\ {x_0}& =\frac1{\sqrt{\bar{\alpha}_t}}\Big(x_t(x_0,\epsilon)-\sqrt{1-\bar{\alpha}_t}\epsilon\Big) \end{aligned}\]

This gives,

\[\begin{aligned} L_{t-1}& =D_{KL}(q(x_{t-1}|x_t,x_0)\|\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)) \\ &=\mathbb{E}_q\left[\frac1{2\sigma_t^2}\left\|\tilde{\mu}(x_t,x_0)-\textcolor{lightgreen}{\mu_\theta}(x_t,t)\right\|^2\right] \\ &=\mathbb{E}_{x_0,\epsilon}\left[\frac1{2\sigma_t^2}\left\|\frac1{\sqrt{\alpha_t}}\left(x_t(x_0,\epsilon)-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\epsilon\right)-\textcolor{lightgreen}{\mu_\theta}(x_t(x_0,\epsilon),t)\right\|^2\right] \end{aligned}\]

Reparameterizing the model to predict the noise instead:

\[\begin{gathered} \textcolor{lightgreen}{\mu_\theta}(x_t,t) =\tilde{\mu}\left(x_t,\frac1{\sqrt{\bar{\alpha}_t}}\left(x_t-\sqrt{1-\bar{\alpha}_t}\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\right)\right) \\ =\frac1{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\Big) \end{gathered}\]

where \(\epsilon_\theta\) is a learned function that predicts \(\epsilon\) given \((x_t,t)\).

Substituting this in gives:

\[L_{t-1}=\mathbb{E}_{x_0,\epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}\left\|\epsilon-\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\right\|^2\right] \]

which is used to train the network to predict the noise.

2.5 Simplified Loss

\[L_{\mathrm{simple}}(\theta)=\mathbb{E}_{t,x_0,\epsilon}\left[\left\|\epsilon-\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\right\|^2\right] \]

This minimizes \(-\log\textcolor{lightgreen}{p_\theta}(x_0|x_1)\) when \(t=1\), and \(L_{t-1}\) when \(t>1\), with the weighting term of \(L_{t-1}\) discarded.

Discarding the weight \(\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}\) gives relatively more weight to larger \(t\) (higher noise levels), which improves sample quality.

3. Code Implementation

Denoise Diffusion

1. Code Walkthrough

1. Initialization

Note: all of the code blocks below belong to the DenoiseDiffusion class.

eps_model is the \(\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\) model

n_steps is \(T\), the total number of diffusion steps

device is the device on which the constants are placed

class DenoiseDiffusion:
    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        super().__init__()
        self.eps_model = eps_model
        
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        
        self.alpha = 1. - self.beta
        
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        
        self.n_steps = n_steps
        
        self.sigma2 = self.beta        

To make the code easier to follow, the DenoiseDiffusion class is broken into pieces below, explaining what each step does.

self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device) creates a tensor of n_steps evenly spaced values from 0.0001 to 0.02, representing \(\beta_1,\ldots,\beta_T\) from the formulas.

self.alpha = 1. - self.beta corresponds to \(\alpha_t=1-\beta_t\).

self.alpha_bar = torch.cumprod(self.alpha, dim=0) corresponds to \(\bar{\alpha}_t=\prod_{s=1}^t\alpha_s\).

self.n_steps = n_steps corresponds to \(T\).

self.sigma2 = self.beta corresponds to \(\sigma_t^2=\beta_t\).

2. The \(q(x_t|x_0)\) distribution

Code implementing the formula \(q(x_t|x_0)=\mathcal{N}\Big(x_t;\sqrt{\bar{\alpha}_t}x_0,(1-\bar{\alpha}_t)\mathbf{I}\Big)\):

    # Returns a tuple of two tensors: the mean and the variance of q(x_t|x_0)
    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        
        var = 1 - gather(self.alpha_bar, t)
        
        return mean, var

The gather operation extracts elements of self.alpha_bar at the indices given by t; t is an index tensor holding the timesteps to look up.
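The gather helper is imported from labml_nn.diffusion.ddpm.utils (see the complete code below). A minimal version consistent with how it is used here might look like this (a sketch, not necessarily the library's exact implementation):

import torch

def gather(consts: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Pick consts[t] for each batch element, then reshape to (batch, 1, 1, 1)
    # so it broadcasts over image tensors of shape (batch, channels, height, width)
    return consts.gather(-1, t).reshape(-1, 1, 1, 1)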

mean = gather(self.alpha_bar, t) ** 0.5 * x0 computes \(\sqrt{\bar{\alpha}_t}x_0\).

var = 1 - gather(self.alpha_bar, t) computes \((1-\bar{\alpha}_t)\mathbf{I}\).

3. Sampling from \(q(x_t|x_0)\)

Code that draws a sample from \(q(x_t|x_0)=\mathcal{N}\Big(x_t;\sqrt{\bar{\alpha}_t}x_0,(1-\bar{\alpha}_t)\mathbf{I}\Big)\):

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        if eps is None:
            eps = torch.randn_like(x0)
            
        mean, var = self.q_xt_x0(x0, t)
        
        return mean + (var ** 0.5) * eps

In the code above, the body of if eps is None: corresponds to \(\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})\).

mean, var = self.q_xt_x0(x0, t) obtains the parameters of \(q(x_t|x_0)\).

Finally, a sample from \(q(x_t|x_0)\) is returned.
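For example, a usage sketch (DummyEps is a trivial stand-in for the real noise-prediction UNet, which this post does not define):

import torch
from torch import nn

class DummyEps(nn.Module):
    # Hypothetical placeholder for the real eps_model, for illustration only
    def forward(self, x, t):
        return torch.zeros_like(x)

device = torch.device("cpu")
diffusion = DenoiseDiffusion(DummyEps(), n_steps=1000, device=device)
x0 = torch.randn(16, 3, 32, 32)               # stand-in batch of normalized images
t = torch.full((16,), 499, dtype=torch.long)  # the same timestep for every sample
xt = diffusion.q_sample(x0, t)                # noised images drawn from q(x_t|x_0)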

4. Sampling from \(\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)\)

This code implements the formulas:

\[\begin{aligned} \textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)& =\mathcal{N}\left(x_{t-1};\textcolor{lightgreen}{\mu_\theta}(x_t,t),\sigma_t^2\mathbf{I}\right) \\ \textcolor{lightgreen}{\mu_\theta}(x_t,t)& =\frac1{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\Big) \end{aligned}\]

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        eps_theta = self.eps_model(xt, t)
        
        alpha_bar = gather(self.alpha_bar, t)
        
        alpha = gather(self.alpha, t)
        
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        
        var  = gather(self.sigma2, t)
        
        eps = torch.randn(xt.shape, device=xt.device)
        
        return mean + (var ** .5) * eps

In the code above, eps_theta = self.eps_model(xt, t) is \(\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\).

alpha_bar = gather(self.alpha_bar, t) gathers \(\bar{\alpha}_t\).

alpha = gather(self.alpha, t) is \(\alpha_t\).

eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5 is \(\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\) (note that \(1-\alpha_t=\beta_t\)).

mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta) computes \(\frac1{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\Big)\).

var = gather(self.sigma2, t) is \(\sigma_t^2\).

eps = torch.randn(xt.shape, device=xt.device) corresponds to \(\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})\).

Finally, return mean + (var ** .5) * eps returns the sample.
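The post does not show the full generation loop, so here is a sketch: sampling starts from pure Gaussian noise and applies p_sample repeatedly for \(t=T-1,\ldots,0\). (Like the p_sample above, this simple version also adds noise at the final step, whereas the paper adds none at \(t=0\).)

import torch

@torch.no_grad()
def sample(diffusion, shape=(16, 3, 32, 32), device=torch.device("cpu")):
    # Start from x_T ~ N(0, I)
    x = torch.randn(shape, device=device)
    # Denoise step by step: t = T-1, ..., 0
    for step in reversed(range(diffusion.n_steps)):
        t = x.new_full((shape[0],), step, dtype=torch.long)
        x = diffusion.p_sample(x, t)
    return x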

5. Simplified Loss

This code implements the formula \(L_{\mathrm{simple}}(\theta)=\mathbb{E}_{t,x_0,\epsilon}\left[\left\|\epsilon-\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\right\|^2\right]\):

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        
        batch_size = x0.shape[0]
        
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        
        if noise is None:
            noise = torch.randn_like(x0)
        
        xt = self.q_sample(x0, t, eps=noise)
        
        eps_theta = self.eps_model(xt, t)
        
        return F.mse_loss(noise, eps_theta)

In the code above, batch_size = x0.shape[0] gets the batch size.

t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long) draws a random \(t\) for every sample in the batch.

The body of if noise is None: corresponds to \(\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})\).

xt = self.q_sample(x0, t, eps=noise): xt is the sample drawn from \(q(x_t|x_0)\).

eps_theta = self.eps_model(xt, t) computes \(\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\).

Finally, return F.mse_loss(noise, eps_theta) returns the MSE loss.
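To close the loop, a minimal training sketch (data_loader and device are assumptions not defined in this post; in practice eps_model would be a UNet):

import torch

optimizer = torch.optim.Adam(diffusion.eps_model.parameters(), lr=2e-4)
num_epochs = 10                       # hypothetical
for epoch in range(num_epochs):
    for x0 in data_loader:            # assumed to yield batches of normalized images
        optimizer.zero_grad()
        loss = diffusion.loss(x0.to(device))
        loss.backward()
        optimizer.step()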

2. Complete Code

Below is the complete Denoise Diffusion code:

from typing import Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

from labml_nn.diffusion.ddpm.utils import gather


class DenoiseDiffusion:
    """
    ## Denoise Diffusion
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        super().__init__()
        # The noise-prediction model, epsilon_theta(x_t, t)
        self.eps_model = eps_model

        # Linear schedule for beta_1, ..., beta_T
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

        # alpha_t = 1 - beta_t
        self.alpha = 1. - self.beta

        # alpha_bar_t = prod_{s=1}^t alpha_s
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

        # T
        self.n_steps = n_steps

        # sigma_t^2 = beta_t
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Mean of q(x_t|x_0): sqrt(alpha_bar_t) * x_0
        mean = gather(self.alpha_bar, t) ** 0.5 * x0

        # Variance of q(x_t|x_0): 1 - alpha_bar_t
        var = 1 - gather(self.alpha_bar, t)

        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        # epsilon ~ N(0, I)
        if eps is None:
            eps = torch.randn_like(x0)

        mean, var = self.q_xt_x0(x0, t)

        # Sample from q(x_t|x_0)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        # Predicted noise epsilon_theta(x_t, t)
        eps_theta = self.eps_model(xt, t)

        alpha_bar = gather(self.alpha_bar, t)

        alpha = gather(self.alpha, t)

        # Coefficient beta_t / sqrt(1 - alpha_bar_t)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

        # mu_theta(x_t, t)
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

        # sigma_t^2
        var = gather(self.sigma2, t)

        # epsilon ~ N(0, I)
        eps = torch.randn(xt.shape, device=xt.device)

        # Sample from p_theta(x_{t-1}|x_t)
        return mean + (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        batch_size = x0.shape[0]

        # Random t for every sample in the batch
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

        # epsilon ~ N(0, I)
        if noise is None:
            noise = torch.randn_like(x0)

        # Sample x_t from q(x_t|x_0)
        xt = self.q_sample(x0, t, eps=noise)

        # Predicted noise
        eps_theta = self.eps_model(xt, t)

        # L_simple: MSE between the actual and predicted noise
        return F.mse_loss(noise, eps_theta)

