Loading...

扩散模型DDPM根据代码尝试进行形象的理解

扩散模型可以用于生成图形。SBM(Score-Based Model)和DDPM(Denoising Diffusion Probabilistic Model)是两种常见的扩散模型。本文依据Github上的极简代码,尝试理解DDPM。

DDPM 的基本思想是在训练阶段将数据逐渐加上噪声(扩散过程),然后在预测阶段再一步步去除噪声(反向扩散过程),得到真实数据。

训练阶段

假设一张图片通过加噪的过程完全变成噪声需要T步,每一步加一点点噪声,直到最后完全变成噪声。

训练阶段,给模型的输入值是加噪的步数t、以及加噪后的带噪声图像x_t,预测输出值是所加的噪声。这样,在预测阶段,根据每一步预测的噪声,就可以一步步去噪、还原图像。

graph LR A{原始图x_0} --> ... ... --> B{带噪声的图x_t-1} B --> 加权和 随机噪声 --> 加权和 加权和 --> C{带噪声的图x_t} C -.输入.-> Model Model -.MAE监督预测.-> 随机噪声 C --> .... .... --> D{纯噪声图x_T}

初始化

以下是训练器的类的初始化定义:

class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

其中:

  • self.model:传入的 model 是U-Net。U-Net是一种全卷积网络,先通过编码器提取深层特征,然后通过解码器逐步恢复图像的空间分辨率,并通过跳跃连接来融合编码器和解码器的特征,以生成高质量的输出图像。
  • self.T:扩散的步数,此处默认是1000。
  • self.betas:从1e-4到2e-2的长度为 T 的等差数列。
  • alphas_bar:alphas 的累乘,是0.9999到约4e-5的递减数列。
  • self.sqrt_alphas_bar:alphas_bar 的开根号,递减数列。
  • self.sqrt_one_minus_alphas_bar:(1 - alphas_bar) 的开根号,递增数列。

前向传播

    def forward(self, x_0):
        """
        Algorithm 1.
        """
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss

其中:

  • x_0:传入的图像矩阵,形为 B C H W,如 [80, 3, 32, 32]
  • t:为每张图像随机选取一个 T 以内的扩散步数,训练这一步的模型预测能力。
  • noise:随机噪声。
  • x_t:第t步的带噪声的图像。计算方式是\(x_{t} = \sqrt{1-\beta_t} x_{0} + \sqrt{\beta_t} \mathbf{\epsilon}_t​\),其中 \(\beta_t\)​ 是一个时间相关的递增参数,\(\mathbf{\epsilon}_t\)​ 是一个标准高斯噪声。随着步数t的增加,真实图像\(x_0\)的权重越来越小,噪声的权重越来越大,象征着一步步加噪的过程。

通过这部分的训练,可以使self.model具有预测噪声的能力。

预测阶段

class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))

        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        # below: only log_variance is used in the KL computations
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)

        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Algorithm 2.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var= self.p_mean_variance(x_t=x_t, t=t)
            # no noise when t == 0
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)   

关注Algorithm 2部分。首先,输入值x_T是完全的噪声。去噪过程通过逆向时间步循环来实现,从最高时间步 T-1 开始,逐步回到第0时间步。

在每个时间步,计算均值和方差。然后,最新的x_t一部分由均值mean组成,另一部分由一个新的噪声组成,噪声权重是\(\sqrt{var}\)。可以预见,随着时间步的推移,噪声权重会越来越低,直至0。最后一步的均值就是去噪完成的图像。

关于均值和方差,参见函数 p_mean_variance()

方差是从一开始就计算好的方差序列中取值,该序列 self.posterior_var == [0, 5.4532e-5, ..., 2e-2] 是一个递增数列,那么逆向时间步循环下就是个递减数列,代表方差越来越小。

eps 是模型预测出的图像附带的噪声。

均值的计算利用函数 predict_xt_prev_mean_from_eps() 进行。均值 mean = coeff1 * x_t - coeff2 * eps,表示图像去除一部分噪声。随着时间步的推移,coeff1 从1.0102逐渐递减到1.0001,coeff2 从0.0202缓慢递减到0.0063再快速增加到0.01。

通俗来讲,去噪过程就是去一点噪、加一点噪、去一点噪、加一点噪,加的噪声越来越小,最终实现完全去噪的过程。

graph LR A{纯噪声图x_T} --> ... ... --> B{带噪声的图x_t} B -.输入.-> Model Model -.预测.-> 图中的噪声 B --> 加权差 图中的噪声 --> 加权差 加权差 --> mean mean --> 加权和 标准差*noise --> 加权和 加权和 --> C{带噪声的图x_t-1} C --> .... .... --> D{还原的图x_0}

推荐后续阅读

Stable Diffusion 解读(三):原版实现及Diffusers实现源码解读 | 周弈帆的博客 (zhouyifan.net)

posted @ 2024-07-09 10:34  板子~  阅读(43)  评论(0编辑  收藏  举报