Diffusion系列 - DDIM 公式推导 + 代码 -(三)
DENOISING DIFFUSION IMPLICIT MODELS (DDIM)
从DDPM中我们知道,其扩散过程(前向过程、或加噪过程)被定义为一个马尔可夫过程,其去噪过程(也有叫逆向过程)也是一个马尔可夫过程。对马尔可夫假设的依赖,导致重建每一步都需要依赖上一步的状态,所以推理需要较多的步长。
DDPM中对于其逆向分布的建模使用马尔可夫假设,这样做的目的是将式子中的未知项 q(xt|xt−1,x0),转化成了已知项 q(xt|xt−1),最后求出 q(xt−1|xt,x0) 的分布也是一个高斯分布 N(xt−1;μq(xt,x0),Σq(t))。
从DDPM的结论出发,我们不妨直接假设 q(xt−1|xt,x0) 的分布为高斯分布,在不使用马尔可夫假设的情况下,尝试求解 q(xt−1|xt,x0) 。
由 DDPM 中 q(xt−1|xt,x0) 的分布 N(xt−1;μq(xt,x0),Σq(t)) 可知,均值为 一个关于 xt,x0 的函数,方差为一个关于 t 的函数。
我们可以把 q(xt−1|xt,x0) 设计成如下分布:
这样,只要求解出 a,b,σt 这三个待定系数,即可确定 q(xt−1|xt,x0) 的分布。
重参数化 q(xt−1|xt,x0) :
假设训练模型时输入噪声图片的加噪参数与DDPM完全一致
由 q(xt|x0):=N(xt;√ˉαtx0,(1−ˉαt)I) :
代入 xt 有:
又:
观察系数可以得到方程组:
三个未知数 两个方程,可以用 σt 表示 a,b:
a,b 代入 q(xt−1|xt,x0):=N(xt−1;ax0+bxt,σ2tI)
又
代入 x0 有:
通过观察 xt−1 的分布,我们建模采样分布为高斯分布:
并且均值和方差也采用相似的形式:
其中 ϵθ(xt,t) 为预测的噪声。
此时,确定优化目标只需要 q(xt−1|xt,x0) 和 pθ(xt−1|xt) 两个分布尽可能相似,使用KL散度来度量,则有:
恰好与DDPM的优化目标一致,所以我们可以直接复用DDPM训练好的模型。
pθ 的采样步骤则为:
令 σt=η√(1−αt)(1−ˉαt−1)1−ˉαt
当 η=1 时,前向过程为 Markovian ,采样过程变为 DDPM 。
当 η=0 时,采样过程为确定过程,此时的模型 称为 隐概率模型(implicit probabilstic model)。
DDIM如何加速采样:
在 DDPM 中,基于马尔可夫链 t 与 t−1 是相邻关系,例如 t=100 则 t−1=99;
在 DDIM 中,t 与 t−1 只表示前后关系,例如 t=100 时,t−1 可以是 90 也可以是 80、70,只需保证 t−1<t 即可。
此时构建的采样子序列 τ=[τi,τi−1,⋯,τ1]≪[t,t−1,⋯,1] 。
例如,原序列 T=[100,99,98,⋯,1],采样子序列为 τ=[100,90,80,⋯,1] 。
DDIM 采样公式为:
当 η=0 时,DDIM 采样公式为:
代码实现
训练过程与 DDPM 一致,代码参考上一篇文章。采样代码如下:
device = 'cuda'
torch.cuda.empty_cache()
model = Unet().to(device)
model.load_state_dict(torch.load('ddpm_T1000_l2_epochs_300.pth'))
model.eval()
image_size=96
epochs = 500
batch_size = 128
T=1000
betas = torch.linspace(0.0001, 0.02, T).to('cuda') # torch.Size([1000])
# 每隔20采样一次
tau_index = list(reversed(range(0, T, 20))) #[980, 960, ..., 20, 0]
eta = 0.003
# train
alphas = 1 - betas # 0.9999 -> 0.98
alphas_cumprod = torch.cumprod(alphas, axis=0) # 0.9999 -> 0.0000
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1-alphas_cumprod)
def get_val_by_index(val, t, x_shape):
batch_t = t.shape[0]
out = val.gather(-1, t)
return out.reshape(batch_t, *((1,) * (len(x_shape) - 1))) # torch.Size([batch_t, 1, 1, 1])
def p_sample_ddim(model):
def step_denoise(model, x_tau_i, tau_i, tau_i_1):
sqrt_alphas_bar_tau_i = get_val_by_index(sqrt_alphas_cumprod, tau_i, x_tau_i.shape)
sqrt_alphas_bar_tau_i_1 = get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)
denoise = model(x_tau_i, tau_i)
if eta == 0:
sqrt_1_minus_alphas_bar_tau_i = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape)
sqrt_1_minus_alphas_bar_tau_i_1 = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i_1, x_tau_i.shape)
x_tau_i_1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * x_tau_i \
+ (sqrt_1_minus_alphas_bar_tau_i_1 - sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * sqrt_1_minus_alphas_bar_tau_i) \
* denoise
return x_tau_i_1
sigma = eta * torch.sqrt((1-get_val_by_index(alphas, tau_i, x_tau_i.shape)) * \
(1-get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)) / get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape))
noise_z = torch.randn_like(x_tau_i, device=x_tau_i.device)
# 整个式子由三部分组成
c1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * (x_tau_i - get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape) * denoise)
c2 = torch.sqrt(1 - get_val_by_index(alphas_cumprod, tau_i_1, x_tau_i.shape) - sigma) * denoise
c3 = sigma * noise_z
x_tau_i_1 = c1 + c2 + c3
return x_tau_i_1
img_pred = torch.randn((4, 3, image_size, image_size), device=device)
for k in range(0, len(tau_index)):
# print(tau_index)
# 因为 tau_index 是倒序的,tau_i = k, tau_i_1 = k+1,这里不能弄反
tau_i_1 = torch.tensor([tau_index[k+1]], device=device, dtype=torch.long)
tau_i = torch.tensor([tau_index[k]], device=device, dtype=torch.long)
img_pred = step_denoise(model, img_pred, tau_i, tau_i_1)
torch.cuda.empty_cache()
if tau_index[k+1] == 0: return img_pred
return img_pred
with torch.no_grad():
img = p_sample_ddim(model)
img = torch.clamp(img, -1.0, 1.0)
show_img_batch(img.detach().cpu())
DDIM
https://arxiv.org/pdf/2010.02502
https://github.com/ermongroup/ddim
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 零经验选手,Compose 一天开发一款小游戏!
· 因为Apifox不支持离线,我果断选择了Apipost!
· 通过 API 将Deepseek响应流式内容输出到前端