扩散模型DDPM根据代码尝试进行形象的理解
扩散模型可以用于生成图形。SBM(Score-Based Model)和DDPM(Denoising Diffusion Probabilistic Model)是两种常见的扩散模型。本文依据Github上的极简代码,尝试理解DDPM。
DDPM 的基本思想是在训练阶段将数据逐渐加上噪声(扩散过程),然后在预测阶段再一步步去除噪声(反向扩散过程),得到真实数据。
训练阶段
假设一张图片通过加噪的过程完全变成噪声需要T步,每一步加一点点噪声,直到最后完全变成噪声。
训练阶段,给模型的输入值是加噪的步数t、以及加噪后的带噪声图像x_t,预测输出值是所加的噪声。这样,在预测阶段,根据每一步预测的噪声,就可以一步步去噪、还原图像。
初始化
以下是训练器的类的初始化定义:
class GaussianDiffusionTrainer(nn.Module):
def __init__(self, model, beta_1, beta_T, T):
super().__init__()
self.model = model
self.T = T
self.register_buffer(
'betas', torch.linspace(beta_1, beta_T, T).double())
alphas = 1. - self.betas
alphas_bar = torch.cumprod(alphas, dim=0)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_bar', torch.sqrt(alphas_bar))
self.register_buffer(
'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
其中:
- self.model:传入的
model
是U-Net。U-Net是一种全卷积网络,先通过编码器提取深层特征,然后通过解码器逐步恢复图像的空间分辨率,并通过跳跃连接来融合编码器和解码器的特征,以生成高质量的输出图像。 - self.T:扩散的步数,此处默认是1000。
- self.betas:从1e-4到2e-2的长度为
T
的等差数列。 - alphas_bar:
alphas
的累乘,是0.9999到约4e-5的递减数列。 - self.sqrt_alphas_bar:
alphas_bar
的开根号,递减数列。 - self.sqrt_one_minus_alphas_bar:
(1 - alphas_bar)
的开根号,递增数列。
前向传播
def forward(self, x_0):
"""
Algorithm 1.
"""
t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
noise = torch.randn_like(x_0)
x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
return loss
其中:
- x_0:传入的图像矩阵,形为
B C H W
,如[80, 3, 32, 32]
。 - t:为每张图像随机选取一个
T
以内的扩散步数,训练这一步的模型预测能力。 - noise:随机噪声。
- x_t:第t步的带噪声的图像。计算方式是\(x_{t} = \sqrt{1-\beta_t} x_{0} + \sqrt{\beta_t} \mathbf{\epsilon}_t\),其中 \(\beta_t\) 是一个时间相关的递增参数,\(\mathbf{\epsilon}_t\) 是一个标准高斯噪声。随着步数t的增加,真实图像\(x_0\)的权重越来越小,噪声的权重越来越大,象征着一步步加噪的过程。
通过这部分的训练,可以使self.model具有预测噪声的能力。
预测阶段
class GaussianDiffusionSampler(nn.Module):
def __init__(self, model, beta_1, beta_T, T):
super().__init__()
self.model = model
self.T = T
self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
alphas = 1. - self.betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
self.register_buffer('coeff1', torch.sqrt(1. / alphas))
self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
extract(self.coeff1, t, x_t.shape) * x_t -
extract(self.coeff2, t, x_t.shape) * eps
)
def p_mean_variance(self, x_t, t):
# below: only log_variance is used in the KL computations
var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
var = extract(var, t, x_t.shape)
eps = self.model(x_t, t)
xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
return xt_prev_mean, var
def forward(self, x_T):
"""
Algorithm 2.
"""
x_t = x_T
for time_step in reversed(range(self.T)):
print(time_step)
t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
mean, var= self.p_mean_variance(x_t=x_t, t=t)
# no noise when t == 0
if time_step > 0:
noise = torch.randn_like(x_t)
else:
noise = 0
x_t = mean + torch.sqrt(var) * noise
assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
x_0 = x_t
return torch.clip(x_0, -1, 1)
关注Algorithm 2部分。首先,输入值x_T是完全的噪声。去噪过程通过逆向时间步循环来实现,从最高时间步 T-1 开始,逐步回到第0时间步。
在每个时间步,计算均值和方差。然后,最新的x_t一部分由均值mean组成,另一部分由一个新的噪声组成,噪声权重是\(\sqrt{var}\)。可以预见,随着时间步的推移,噪声权重会越来越低,直至0。最后一步的均值就是去噪完成的图像。
关于均值和方差,参见函数 p_mean_variance()
。
方差是从一开始就计算好的方差序列中取值,该序列 self.posterior_var == [0, 5.4532e-5, ..., 2e-2]
是一个递增数列,那么逆向时间步循环下就是个递减数列,代表方差越来越小。
eps
是模型预测出的图像附带的噪声。
均值的计算利用函数 predict_xt_prev_mean_from_eps()
进行。均值 mean = coeff1 * x_t - coeff2 * eps
,表示图像去除一部分噪声。随着时间步的推移,coeff1
从1.0102逐渐递减到1.0001,coeff2
从0.0202缓慢递减到0.0063再快速增加到0.01。
通俗来讲,去噪过程就是去一点噪、加一点噪、去一点噪、加一点噪,加的噪声越来越小,最终实现完全去噪的过程。
推荐后续阅读
Stable Diffusion 解读(三):原版实现及Diffusers实现源码解读 | 周弈帆的博客 (zhouyifan.net)