Denoising Diffusion Probabilistic Models去噪扩散模型(DDPM)

Denoising Diffusion Probabilistic Models去噪扩散模型(DDPM)

2024/2/28

论文链接:Denoising Diffusion Probabilistic Models(neurips.cc)

这篇文章对DDPM写个大概,公式推导会放在以后的文章里。

一、引言 Introduction

各类深度生成模型在多种数据模态上展示了高质量的样本。生成对抗网络(GANs)、自回归模型、流模型变分自编码器(VAEs)已经合成了引人注目的图像和音频样本。此外,在基于能量的建模得分匹配方面也取得了显著进展,生成的图像与GANs生成的图像相当。

扩散概率模型是一个参数化马尔科夫链,使用变分推断(Variational Inference)进行训练,以便在有限时间内产生于数据相匹配的样本。这个链的转移是学习来逆转扩散过程的,扩散过程是一种马尔可夫链,它逐渐向与采样相反的方向添加噪声到数据中,直到信号被破坏。当扩散包含的是少量的高斯噪声时,只需将采样链转移设置为条件高斯分布,这样就可以实现一个特别简单的神经网络参数化。

变分推断(Variational Inference):这是一种用于估计概率模型参数的统计方法。它通过优化一个目标函数来近似真实的后验分布,这个目标函数通常是真实后验分布与一个易于计算的分布(变分分布)之间的差异。

流模型(Flows):流模型是一种生成模型,它通过一系列可逆的变换(称为流)将数据从高维空间映射到低维空间,然后再映射回高维空间,以生成新的数据样本。流模型的优势在于其变换是可逆的,这有助于保持数据的多样性。

能量基建模(Energy-based Modeling):这是一种基于能量函数的建模方法,通常用于二分类问题。能量函数定义了输入数据与特定标签的不匹配程度。在图像生成的背景下,能量基模型可以用来评估和改进生成图像的质量。

得分匹配(Score Matching):这是一种用于训练生成模型的技术,特别是在概率密度估计中。它涉及计算真实数据分布的得分函数,并使生成模型的得分函数与之匹配,以此来提高生成样本的质量。

二、模型具体细节

扩散是指物质粒子从高浓度区域向低浓度区域移动的过程,扩散模型的灵感来自非平衡热力学,扩散模型想做的就是通过向图片中加入高斯噪声模拟这个过程,最后通过逆向过程从随机噪声中生成图片。

2.1 前向加噪

我们需要进行随机采样生成和图片尺寸大小相同的噪声图片。噪声图片中所有通道数值遵从正态分布。我们根据T步将生成的噪声图片与原图片进行混合,每一步的混合方式满足以下公式:

β×ϵ+1β×x

其中,x为原始图片,ϵ是高斯噪声,β是一个介于[0.0,1.0]之间的数字,用于产生xϵ前的系数。

我们输入x0套用公式后我们得到了x1

x1=β1×ϵ1+1β1×x0

image

输入x1套用公式后我们得到了x2

x2=β2×ϵ2+1β2×x1

image

......

以此类推,我们可以得到前一时刻与后一时刻的关系:

xt=βt×ϵt+1βt×xt1

其中ϵt都是基于标准正态分布重新采样的随机数,而其中的βt是从一个接近0的数字逐步递增,最后趋近于1,0<β1<β2<β3<βt1<βt<1​​.

有:

q(xt|xt1)=N(xt;1βtxt1,βtI)

随着步长t增加,原来的样本x0的特征变得不可区分。当$T\to\infty \mathbf{x}_T$等价于各相同性高斯分布。

image

过程如上图所示,上诉过程有一个很好的特性,可以使用重参数化技巧(reparameterization trick)(参见VAE),在任何任意时间步长t上采样xt​。

为了简化后续的推导,我们引入一个新变量αt=1βt,上诉公式变为:

xt=1αt×ϵt+αt×xt1

接下来需要思考的是通过公式能否使x0直接得到xT,我们从

xt=1αt×ϵt+αt×xt1

向后推,得到:

xt=αtxt1+1αtϵt1;其中, ϵt1,ϵt2,N(0,I)=αtαt1xt2+1αtαt1ϵ¯t2;其中, ϵ¯t2 合并两个高斯量 ().==α¯tx0+1α¯tϵ

其中,α¯t=i=1tαi

()当我们合并两个具有不同方差的高斯量N(0,σ12I)N(0,σ22I)时,新的分布是N(0,(σ12+σ22)I),这里合并的标准差是(1αt)+αt(1αt1)=1αtαt1

经过推导我们可以得到公式:

xt=1α¯t×ϵ+α¯t×x0

通常,当样本变得更嘈杂时,我们可以承受更大的更新步骤,因此

β1<β2<<βT

α¯1>>α¯T

2.2 反向过程

反向过程的目的是将有噪声的图片恢复成原始图片,如果我们可以反转上述过程,从q(xt1|xt)中采样,将可以从高斯噪声中生成图片。因为前向加噪是一个随机过程,所以反向过程也是一个随机过程,所以我们可以用P(xt1|xt)表示在给定xt的情况下,前一时刻xt1的概率,根据贝叶斯公式有:

P(xt1|xt,x0)=P(xt|xt1,x0)P(xt1|x0)P(xt|x0)

根据公式:

xt=1αt×ϵt+αt×xt1xt=1α¯t×ϵ+α¯t×x0

我们可以得到xt是分别满足N(αtxt1,1αt)N(α¯tx0,1α¯t)的正态分布(因为噪声ϵ是满足高斯分布的),xt1是满足N(α¯t1x0,1α¯t1)的正态分布。我们可以将上式改为:

q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)(q(xtxt1,x0)N(xt;αtxt1,(1αt)I))(q(xt1x0)N(xt1;α¯t1x0,(1α¯t1)I))(q(xtx0)N(xt;α¯tx0,(1α¯t)I))exp(12((xtαtxt1)2βt+(xl1α¯l1x0)21α¯t1(xtα¯tx0)21α¯t))=exp(12(xl22αtxtxt1+αtxt12βt+xt122α¯t1x0xt1+α¯t1x021α¯t1(xtα¯tx0)21α¯t))=exp(12((αlβt+11α¯t)xt12(2αlβtxt+2α¯t11α¯tx0)xl1+C(xl,x0)))

其中C(xt,x0)不涉及xt1某些功能,省略了详细信息。

从中我们可以得知P(xt1|xt,x0)是满足N(at(1a¯t1)1a¯txt+a¯t1(1at)1a¯t×xt1a¯t×ϵa¯t,(βt(1a¯t1)1a¯t)2)

这里只要我们知道了ϵ就可以知道前一个时刻的图像,这里我们训练一个神经网络模型,来预测此图像相对于x0原图所加入的噪声。

根据实验可知,xT是一任何张满足标准正态分布的噪声图片。我们使用标准正态分布随机采样就能得到xT​。

反向过程通过T步从p(xT)=N(xT;0,I)开始的噪声。

pθ(xt1|xt)=N(xt1;μθ(xt,t),Σθ(xt,t))pθ(x0:T)=pθ(xT)t=1Tpθ(xt1|xt)pθ(x0)=pθ(x0:T)dx1:T

其中θ是我们训练的参数。

2.3 Loss损失

文中对负对数似然上优化了ELBO(来自琴生不等式)

E[logpθ(x0)]Eq[logpθ(x0:T)q(x1:T|x0)]=L

损失可以按如下方式重写:

L=Eq[logpθ(x0:T)q(x1:T|x0)]=Eq[logp(xT)t=1Tlogpθ(xt1|xt)q(xt|xt1)]=Eq[logp(xT)q(xT|x0)t=2Tlogpθ(xt1|xt)q(xt1|xt,x0)logpθ(x0|x1)]=Eq[DKL(q(xT|x0)||p(xT))+t=2TDKL(q(xt1|xt,x0)||pθ(xt1|xt))logpθ(x0|x1)]

因为我们保持β1,,βT恒定,所以DKL(q(xT|x0)||p(xT))也是恒定的。

2.4 计算 DKL(q(xt1|xt,x0)pθ(xt1|xt))

在给定初始x0的条件下,前向过程的后验概率为:

q(xt1|xt,x0)=N(xt1;μ~t(xt,x0),β~tI)μ~t(xt,x0)=α¯t1βt1αt¯x0+αt(1α¯t1)1αt¯xtβ~t=1α¯t11αt¯βt

论文中设置Σθ(xt,t)=σt2I,其中σt2设置为常量βtβt~

然后,

pθ(xt1|xt)=N(xt1;μθ(xt,t),σt2I)

对于给定的噪声ϵN(0,I),使用q(xt|x0)

xt(x0,ϵ)=αt¯x0+1αt¯ϵx0=1α¯t(xt(x0,ϵ)1α¯tϵ)

这里,

Lt1=DKL(q(xt1|xt,x0)pθ(xt1|xt))=Eq[12σt2μ~(xt,x0)μθ(xt,t)2]=Ex0,ϵ[12σt21αt(xt(x0,ϵ)βt1αt¯ϵ)μθ(xt(x0,ϵ),t)2]

使用模型重新参数化以预测噪声

μθ(xt,t)=μ~(xt,1α¯t(xt1α¯tϵθ(xt,t)))=1αt(xtβt1αt¯ϵθ(xt,t))

其中 ϵθ 是预测 其中 ϵθ 是预测 ϵ 给定 (xt,t) 的学习函数。

这里给定,

Lt1=Ex0,ϵ[βt22σt2αt(1α¯t)ϵϵθ(α¯tx0+1α¯tϵ,t)2]

用来训练预测噪声。

2.5 简化损失

Lsimple(θ)=Et,x0,ϵ[ϵϵθ(α¯tx0+1α¯tϵ,t)2]

这在t=1时最小化logpθ(x0|x1),并且在t>1时最小化Lt1,同时丢弃Lt1中的权重。

丢弃权重βt22σt2αt(1αt¯)​会增加给予更高 t (具有更高噪声水平) 的权重,从而提高样本质量。

三、代码实现

Denoise Diffusion 降噪扩散

1. 代码解析

1. 初始化

注意:以下代码块都是在DenoiseDiffusion类中

eps_modelϵθ(xt,t)模型

n_stepst

device是放置常量的设备

class DenoiseDiffusion:
    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        super().__init__()
        self.eps_model = eps_model
        
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        
        self.alpha = 1. - self.beta
        
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        
        self.n_steps = n_steps
        
        self.sigma2 = self.beta        

为了方便代码理解,这里将class DenoiseDiffusion拆分进行解释,理解代码每一步在做什么。

self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)这里是生成了一个tensor,该tensor包含n_steps个数据,包含从 0.00010.02 的等间隔数值,代表了公式中的 β1,,βT

self.alpha = 1. - self.beta代表 αt=1βt

self.alpha_bar = torch.cumprod(self.alpha, dim=0)代表 αt¯=s=1tαs

self.n_steps = n_steps代表 T

self.sigma2 = self.beta代表 σ2=β

2. 获取q(xt|x0)​分布

关于公式 q(xt|x0)=N(xt;α¯tx0,(1α¯t)I) 的代码实现

    #该函数返回一个包含两个张量的元组
    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        
        var = 1 - gather(self.alpha_bar, t)
        
        return mean, var

gather 这个操作会根据 t 中的索引从 self.alpha_bar 中提取元素。t是索引张量,包含了要提取的元素的索引。

mean = gather(self.alpha_bar, t) ** 0.5 * x0计算 α¯tx0

var = 1 - gather(self.alpha_bar, t)计算 (1α¯t)I

3. 来自q(xt|x0)​的样本

关于公式 q(xt|x0)=N(xt;α¯tx0,(1α¯t)I) 的代码实现

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        if eps is None:
            eps = torch.randn_like(x0)
            
        mean, var = self.q_xt_x0(x0, t)
        
        return mean + (var ** 0.5) * eps

上述代码中if eps is None:所包含的内容代表 ϵN(0,I)

mean, var = self.q_xt_x0(x0, t)代表获取 q(xt|x0)

最后返回来自 q(xt|x0) 的样本

4. 来自pθ(xt1|xt)的样本

这段代码实现公式

pθ(xt1|xt)=N(xt1;μθ(xt,t),σt2I)μθ(xt,t)=1αt(xtβt1αt¯ϵθ(xt,t))

    def p_sample (self, xt: torch.Tensor, t: torch.Tensor):
        eps_theta = self.eps_model(xt, t)
        
        alpha_bar = gather(self.alpha_bar, t)
        
        alpha = gather(self.alpha, t)
        
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        
        var  = gather(self.sigma2, t)
        
        eps = torch.randn(xt.shape, device=xt.device)
        
        return mean + (var ** .5) * eps

上述代码中,eps_theta = self.eps_model(xt, t) 表示ϵθ(xt,t)

alpha_bar = gather(self.alpha_bar, t) 是在收集α¯t

alpha = gather(self.alpha, t) 表示αt

eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5 表示β1αt

mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta) 计算的是1αt(xtβt1αt¯ϵθ(xt,t))

var = gather(self.sigma2, t) 表示的是σ2

eps = torch.randn(xt.shape, device=xt.device)代表 ϵN(0,I)

最后return mean + (var ** .5) * eps返回样本。

5. 简化损失

这段代码实现的是 Lsimple(θ)=Et,x0,ϵ[ϵϵθ(α¯tx0+1α¯tϵ,t)2] 公式

    def loss(self, x0: Tensor, noise: Optional[torch.Tensor] = None):
        
        batch_size - x0.shape[0]
        
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        
        if noise is None:
            noise = torch.randn_like(x0)
        
        xt = self.q_sample(x0, t, eps=noise)
        
        eps_theta = self.eps_model(xt, t)
        
        return F.mse_loss(noise, eps_theta)

上述代码中,batch_size - x0.shape[0]是为了获取批量大小。

t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)是对批次中的每个样品得到随机的 t

if noise is None:中的代表着 ϵN(0,I)

xt = self.q_sample(x0, t, eps=noise)xtq(xt|x0)中得到的样本。

eps_theta = self.eps_model(xt, t)是获取公式 ϵθ(αt¯x0+1αt¯ϵ,t)

最后return F.mse_loss(noise, eps_theta)返回MSE损失。

2. 完整代码

下面是完整的Denoise Diffusion代码

from typing import Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

from labml_nn.diffusion.ddpm.utils import gather


class DenoiseDiffusion:
    """
    ## Denoise Diffusion
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        
        super().__init__()
        self.eps_model = eps_model
        
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        
        self.alpha = 1. - self.beta
        
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        
        self.n_steps = n_steps
        
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        
        var = 1 - gather(self.alpha_bar, t)
        
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        
        if eps is None:
            eps = torch.randn_like(x0)
            
        mean, var = self.q_xt_x0(x0, t)
        
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        
        eps_theta = self.eps_model(xt, t)
        
        alpha_bar = gather(self.alpha_bar, t)
        
        alpha = gather(self.alpha, t)
        
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        
        var = gather(self.sigma2, t)

        eps = torch.randn(xt.shape, device=xt.device)
        
        return mean + (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        
        batch_size = x0.shape[0]
        
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        
        if noise is None:
            noise = torch.randn_like(x0)
            
        xt = self.q_sample(x0, t, eps=noise)
        
        eps_theta = self.eps_model(xt, t)
        
        return F.mse_loss(noise, eps_theta)

参考文献

[1].Diffusion Models 10 篇必读论文(1)DDPM - 知乎 (zhihu.com)

[2].去噪扩散模型

[3].[What are Diffusion Models? | Lil'Log (lilianweng.github.io)](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#:~:text=Diffusion models are inspired by,data samples from the noise)

posted @   TTS-S  阅读(852)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
点击右上角即可分享
微信分享提示