扩散模型DDPM根据代码尝试进行形象的理解

扩散模型可以用于生成图形。SBM(Score-Based Model)和DDPM(Denoising Diffusion Probabilistic Model)是两种常见的扩散模型。本文依据Github上的极简代码,尝试理解DDPM。

DDPM 的基本思想是在训练阶段将数据逐渐加上噪声(扩散过程),然后在预测阶段再一步步去除噪声(反向扩散过程),得到真实数据。

训练阶段#

假设一张图片通过加噪的过程完全变成噪声需要T步,每一步加一点点噪声,直到最后完全变成噪声。

训练阶段,给模型的输入值是加噪的步数t、以及加噪后的带噪声图像x_t,预测输出值是所加的噪声。这样,在预测阶段,根据每一步预测的噪声,就可以一步步去噪、还原图像。

输入

MAE监督预测

原始图x_0

...

带噪声的图x_t-1

加权和

随机噪声

带噪声的图x_t

Model

....

纯噪声图x_T

初始化#

以下是训练器的类的初始化定义:

class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

其中:

  • self.model:传入的 model 是U-Net。U-Net是一种全卷积网络,先通过编码器提取深层特征,然后通过解码器逐步恢复图像的空间分辨率,并通过跳跃连接来融合编码器和解码器的特征,以生成高质量的输出图像。
  • self.T:扩散的步数,此处默认是1000。
  • self.betas:从1e-4到2e-2的长度为 T 的等差数列。
  • alphas_bar:alphas 的累乘,是0.9999到约4e-5的递减数列。
  • self.sqrt_alphas_bar:alphas_bar 的开根号,递减数列。
  • self.sqrt_one_minus_alphas_bar:(1 - alphas_bar) 的开根号,递增数列。

前向传播#

    def forward(self, x_0):
        """
        Algorithm 1.
        """
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss

其中:

  • x_0:传入的图像矩阵,形为 B C H W,如 [80, 3, 32, 32]
  • t:为每张图像随机选取一个 T 以内的扩散步数,训练这一步的模型预测能力。
  • noise:随机噪声。
  • x_t:第t步的带噪声的图像。计算方式是xt=1βtx0+βtϵt,其中 βt​ 是一个时间相关的递增参数,ϵt​ 是一个标准高斯噪声。随着步数t的增加,真实图像x0的权重越来越小,噪声的权重越来越大,象征着一步步加噪的过程。

通过这部分的训练,可以使self.model具有预测噪声的能力。

预测阶段#

class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))

        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        # below: only log_variance is used in the KL computations
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)

        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Algorithm 2.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var= self.p_mean_variance(x_t=x_t, t=t)
            # no noise when t == 0
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)   

关注Algorithm 2部分。首先,输入值x_T是完全的噪声。去噪过程通过逆向时间步循环来实现,从最高时间步 T-1 开始,逐步回到第0时间步。

在每个时间步,计算均值和方差。然后,最新的x_t一部分由均值mean组成,另一部分由一个新的噪声组成,噪声权重是var。可以预见,随着时间步的推移,噪声权重会越来越低,直至0。最后一步的均值就是去噪完成的图像。

关于均值和方差,参见函数 p_mean_variance()

方差是从一开始就计算好的方差序列中取值,该序列 self.posterior_var == [0, 5.4532e-5, ..., 2e-2] 是一个递增数列,那么逆向时间步循环下就是个递减数列,代表方差越来越小。

eps 是模型预测出的图像附带的噪声。

均值的计算利用函数 predict_xt_prev_mean_from_eps() 进行。均值 mean = coeff1 * x_t - coeff2 * eps,表示图像去除一部分噪声。随着时间步的推移,coeff1 从1.0102逐渐递减到1.0001,coeff2 从0.0202缓慢递减到0.0063再快速增加到0.01。

通俗来讲,去噪过程就是去一点噪、加一点噪、去一点噪、加一点噪,加的噪声越来越小,最终实现完全去噪的过程。

输入

预测

纯噪声图x_T

...

带噪声的图x_t

Model

图中的噪声

加权差

mean

加权和

标准差*noise

带噪声的图x_t-1

....

还原的图x_0

推荐后续阅读#

Stable Diffusion 解读(三):原版实现及Diffusers实现源码解读 | 周弈帆的博客 (zhouyifan.net)

posted @   板子~  阅读(204)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示
主题色彩
选择模式