Loading web-font TeX/Math/Italic

Diffusion系列 - DDIM 公式推导 + 代码 -(三)

DENOISING DIFFUSION IMPLICIT MODELS (DDIM)

从DDPM中我们知道,其扩散过程(前向过程、或加噪过程)被定义为一个马尔可夫过程,其去噪过程(也有叫逆向过程)也是一个马尔可夫过程。对马尔可夫假设的依赖,导致重建每一步都需要依赖上一步的状态,所以推理需要较多的步长。

q(xt|xt1):=N(xt;αtxt1,1αtI)q(xt|x0):=N(xt;ˉαtx0,(1ˉαt)I)

q(xt1|xt,x0)Bayes=q(xt|xt1,x0)q(xt1|x0)q(xt|x0)Markov=q(xt|xt1)q(xt1|x0)q(xt|x0)

DDPM中对于其逆向分布的建模使用马尔可夫假设,这样做的目的是将式子中的未知项 q(xt|xt1,x0),转化成了已知项 q(xt|xt1),最后求出 q(xt1|xt,x0) 的分布也是一个高斯分布 N(xt1;μq(xt,x0),Σq(t))

从DDPM的结论出发,我们不妨直接假设 q(xt1|xt,x0) 的分布为高斯分布,在不使用马尔可夫假设的情况下,尝试求解 q(xt1|xt,x0)

由 DDPM 中 q(xt1|xt,x0) 的分布 N(xt1;μq(xt,x0),Σq(t)) 可知,均值为 一个关于 xt,x0 的函数,方差为一个关于 t 的函数。

我们可以把 q(xt1|xt,x0) 设计成如下分布:

q(xt1|xt,x0):=N(xt1;ax0+bxt,σ2tI)

这样,只要求解出 a,b,σt 这三个待定系数,即可确定 q(xt1|xt,x0) 的分布。
重参数化 q(xt1|xt,x0)

xt1=ax0+bxt+σtεt1

假设训练模型时输入噪声图片的加噪参数与DDPM完全一致
q(xt|x0):=N(xt;ˉαtx0,(1ˉαt)I)

xt=ˉαtx0+1ˉαtεt

代入 xt 有:

xt1=ax0+b(ˉαtx0+1ˉαtεt)+σtεt1=(a+bˉαt)x0+(b1ˉαtεt+σtεt1)=(a+bˉαt)x0+(b2(1ˉαt)+σ2t)ˉεt1

又:

xt1=ˉαt1x0+1ˉαt1εt1

观察系数可以得到方程组:

{a+bˉαt=ˉαt1b2(1ˉαt)+σ2t=1ˉαt1

三个未知数 两个方程,可以用 σt 表示 a,b

{a=ˉαt1ˉαt1ˉαt1σ2t1ˉαtb=1ˉαt1σ2t1ˉαt

a,b 代入 q(xt1|xt,x0):=N(xt1;ax0+bxt,σ2tI)

q(xt1|xt,x0):=N(xt1;(ˉαt1ˉαt1ˉαt1σ2t1ˉαt)x0+(1ˉαt1σ2t1ˉαt)xtμq(xt,x0,t),σ2tI)

xt=ˉαtx0+1ˉαtˉε0x0=1ˉαtxt1ˉαtˉαtˉε0

代入 x0 有:

μq(xt,x0,t)=ˉαt1xt1ˉαtˉε0ˉαt+1ˉαt1σ2tˉε0

xt1=μq(xt,x0,t)+σtε0=ˉαt1xt1ˉαtˉε0ˉαtx0+1ˉαt1σ2tˉε0xt+σtε0

通过观察 xt1 的分布,我们建模采样分布为高斯分布:

pθ(xt1|xt):=N(xt1;μθ(xt,t),Σθ(xt,t)I)

并且均值和方差也采用相似的形式:

μθ(xt,t)=ˉαt1xt1ˉαtϵθ(xt,t)ˉαt+1ˉαt1σ2tϵθ(xt,t)Σθ(xt,t)=σ2t

其中 ϵθ(xt,t) 为预测的噪声。

此时,确定优化目标只需要 q(xt1|xt,x0)pθ(xt1|xt) 两个分布尽可能相似,使用KL散度来度量,则有:

 argminθDKL(q(xt1|xt,x0)||pθ(xt1|xt))=argminθDKL(N(xt1;μq,Σq(t))||N(xt1;μθ,Σq(t)))=argminθ12[log|Σq(t)||Σq(t)|k+tr(Σq(t)1Σq(t))+(μqμθ)TΣq(t)1(μqμθ)]=argminθ12[0k+k+(μqμθ)T(σ2tI)1(μqμθ)]ATA=argminθ12σ2t[||μqμθ||22]μq,μθ=argminθ12σ2t(1ˉαt1σ2tˉαt11ˉαtˉαt)[||ˉε0ϵθ(xt,t)||22]

恰好与DDPM的优化目标一致,所以我们可以直接复用DDPM训练好的模型。

pθ 的采样步骤则为:

xt1=ˉαt1xt1ˉαtϵθ(xt,t)ˉαtx0+1ˉαt1σ2tϵθ(xt,t)xt+σtε

σt=η(1αt)(1ˉαt1)1ˉαt

η=1 时,前向过程为 Markovian ,采样过程变为 DDPM 。

η=0 时,采样过程为确定过程,此时的模型 称为 隐概率模型(implicit probabilstic model)。

DDIM如何加速采样:
在 DDPM 中,基于马尔可夫链 tt1 是相邻关系,例如 t=100t1=99
在 DDIM 中,tt1 只表示前后关系,例如 t=100 时,t1 可以是 90 也可以是 80、70,只需保证 t1<t 即可。
此时构建的采样子序列 τ=[τi,τi1,,τ1][t,t1,,1]
例如,原序列 T=[100,99,98,,1],采样子序列为 τ=[100,90,80,,1]

DDIM 采样公式为:

xτi1=ˉατi1xτi1ˉατiϵθ(xτi,τi)ˉατi+1ˉατi1σ2τiϵθ(xτi,τi)+στiε

η=0 时,DDIM 采样公式为:

xτi1=ˉατi1ˉατixτi+(1ˉατi1ˉατi1ˉατi1ˉατi)ϵθ(xτi,τi)

代码实现

训练过程与 DDPM 一致,代码参考上一篇文章。采样代码如下:

device = 'cuda'
torch.cuda.empty_cache()
model = Unet().to(device)
model.load_state_dict(torch.load('ddpm_T1000_l2_epochs_300.pth'))
model.eval()

image_size=96
epochs = 500
batch_size = 128
T=1000
betas = torch.linspace(0.0001, 0.02, T).to('cuda') # torch.Size([1000])

# 每隔20采样一次
tau_index = list(reversed(range(0, T, 20))) #[980, 960, ..., 20, 0]
eta = 0.003


# train
alphas = 1 - betas # 0.9999 -> 0.98
alphas_cumprod = torch.cumprod(alphas, axis=0) # 0.9999 -> 0.0000
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1-alphas_cumprod)

def get_val_by_index(val, t, x_shape):
    batch_t = t.shape[0]
    out = val.gather(-1, t)
    return out.reshape(batch_t, *((1,) * (len(x_shape) - 1))) # torch.Size([batch_t, 1, 1, 1])

def p_sample_ddim(model):
    def step_denoise(model, x_tau_i, tau_i, tau_i_1):
        sqrt_alphas_bar_tau_i = get_val_by_index(sqrt_alphas_cumprod, tau_i, x_tau_i.shape)
        sqrt_alphas_bar_tau_i_1 = get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)

        denoise = model(x_tau_i, tau_i)
        
        if eta == 0:
            sqrt_1_minus_alphas_bar_tau_i = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape)
            sqrt_1_minus_alphas_bar_tau_i_1 = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i_1, x_tau_i.shape)
            x_tau_i_1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * x_tau_i \
                + (sqrt_1_minus_alphas_bar_tau_i_1 - sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * sqrt_1_minus_alphas_bar_tau_i) \
                * denoise            
            return x_tau_i_1

        sigma = eta * torch.sqrt((1-get_val_by_index(alphas, tau_i, x_tau_i.shape)) * \
        (1-get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)) / get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape))
        
        noise_z = torch.randn_like(x_tau_i, device=x_tau_i.device)
        
        # 整个式子由三部分组成
        c1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * (x_tau_i - get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape) * denoise)  
        c2 = torch.sqrt(1 - get_val_by_index(alphas_cumprod, tau_i_1, x_tau_i.shape) - sigma) * denoise
        c3 = sigma * noise_z
        x_tau_i_1 = c1 + c2 + c3

        return x_tau_i_1

    
    img_pred = torch.randn((4, 3, image_size, image_size), device=device)

    for k in range(0, len(tau_index)):
        # print(tau_index)
        # 因为 tau_index 是倒序的,tau_i = k, tau_i_1 = k+1,这里不能弄反
        tau_i_1 = torch.tensor([tau_index[k+1]], device=device, dtype=torch.long)
        tau_i = torch.tensor([tau_index[k]], device=device, dtype=torch.long)
        img_pred = step_denoise(model, img_pred, tau_i, tau_i_1)

        torch.cuda.empty_cache()
        if tau_index[k+1] == 0: return img_pred

    return img_pred

with torch.no_grad():
    img = p_sample_ddim(model)
    img = torch.clamp(img, -1.0, 1.0)

show_img_batch(img.detach().cpu())

DDIM
https://arxiv.org/pdf/2010.02502
https://github.com/ermongroup/ddim

posted @   gaobowen  阅读(522)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 零经验选手,Compose 一天开发一款小游戏!
· 因为Apifox不支持离线,我果断选择了Apipost!
· 通过 API 将Deepseek响应流式内容输出到前端
点击右上角即可分享
微信分享提示