Denoising Diffusion Implicit Models(去噪隐式模型)

DDPM有一个很麻烦的问题,就是需要迭代很多步,十分耗时。有人提出了一些方法,比如one-step dm等等。较著名、也比较早的是DDIM。

原文:https://arxiv.org/pdf/2010.02502

参考博文:https://zhuanlan.zhihu.com/p/666552214?utm_id=0

训练过程与ddpm一致,推理过程发生变化,加速了扩散过程,结果也变得稳定一些。

 DDIM假设

 DM假设

ddim给出了一个新的扩散假设,结合ddpm的原假设,直接往新假设代入xt得到:

 根据原假设,联系上式:

得到:

 DDIM假设被变形为:

 x0可以根据扩散模型假设消去,得到:

 当然你可以隔着很多步,所以有:

 

DDIM代码如下:

复制代码
#ddpm 
def sample_backward_step(self, x_t, t, net, simple_var=True,isUnsqueeze=True):

        n = x_t.shape[0]
        if isUnsqueeze:
            t_tensor = torch.tensor([t] * n,
                                dtype=torch.long).to(x_t.device).unsqueeze(1)
        else:
            t_tensor = torch.tensor([t] * n,
                                    dtype=torch.long).to(x_t.device)
        eps = net(x_t, t_tensor)


        if simple_var:
            var = self.betas[t]
        else:
            var = (1 - self.alpha_bars[t - 1]) / (
                1 - self.alpha_bars[t]) * self.betas[t]
        noise = torch.randn_like(x_t)
        noise *= torch.sqrt(var)

        mean = (x_t -
                (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
                eps) / torch.sqrt(self.alphas[t])
        x_t = mean + noise

        return x_t

#ddim
def time_backward_step(self, x_t, t, net, sample_step=5,isUnsqueeze=True):
        n = x_t.shape[0]
        if isUnsqueeze:
            t_tensor = torch.tensor([t] * n,
                                    dtype=torch.long).to(x_t.device).unsqueeze(1)
        else:
            t_tensor = torch.tensor([t] * n,
                                    dtype=torch.long).to(x_t.device)
        eps = net(x_t, t_tensor)

        xstar=torch.sqrt(1./self.alpha_bars[t])*x_t-torch.sqrt(1./self.alpha_bars[t]-1)*eps
        xstar=torch.clamp(xstar,-1,1)

        prev_t=t-sample_step if t-sample_step>0 else 0
        pred_xt=torch.sqrt(1-self.alpha_bars[prev_t])*eps
        x_prev=torch.sqrt(self.alpha_bars[prev_t])*xstar+pred_xt

        return x_prev
复制代码

 DDIM结果图

 DDPM结果图
 
ddim inverse:

 Null-text Inversion for Editing Real Images using Guided Diffusion Models

复制代码
    def ddim_inverse(self, x_t, t, net, label,sample_step=5,w=10,isUnsqueeze=True):

        n = x_t.shape[0]
        if isUnsqueeze:
            t_tensor = torch.tensor([t] * n,
                                    dtype=torch.long).to(x_t.device).unsqueeze(1)
        else:
            t_tensor = torch.tensor([t] * n,
                                    dtype=torch.long).to(x_t.device)
       
        cat = net(x_t, t_tensor, y=torch.ones_like(label) * 1)
       
       
        eps = (w + 1) * (cat) - w * un

        next_t = t + sample_step if t + sample_step < self.n_steps-1 else self.n_steps-1
        xstar=torch.sqrt(self.alpha_bars[next_t]/self.alpha_bars[t])*(x_t-torch.sqrt(1-self.alpha_bars[t])*eps)

        pred_xt=torch.sqrt(1-self.alpha_bars[next_t])*eps
        x_next=xstar+pred_xt

        return x_next
复制代码

 

posted @   澳大利亚树袋熊  阅读(39)  评论(0编辑  收藏  举报
编辑推荐:
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
阅读排行:
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
点击右上角即可分享
微信分享提示