Denoising Diffusion Probabilistic Models (DDPM)

Ho J., Jain A. and Abbeel P. Denoising diffusion probabilistic models. In Advances in Neural Information Processing Systems (NIPS), 2020.

[Page E. Approximating to the cumulative normal function and its inverse for use on a pocket calculator. Applied Statistics, vol. 26, pp. 75-76, 1977.]

Yerukala R., Boiroju N. K. Approximations to standard normal distribution function. Journal of Scientific and Engineering Research, vol. 6, pp. 515-518, 2015.

diffusion model和变分界的结合.
对抗鲁棒性上已经有多篇论文用DDPM生成的数据用于训练了, 可见其强大.

主要内容

Diffusion models

reverse process

p(xT)=N(xT;0,I)出发:

pθ(x0:T):=p(XT)t=1Tpθ(xt1|xt),pθ(xt1|xt):=N(xt1;μθ(xt,t),Σθ(xt,t)),

注意这个过程我们拟合均值μθ和协方差矩阵Σθ.

这部分的过程逐步将噪声'恢复'为图片(信号)x0.

forward process

q(x1:T|x0):=t=1Tq(xt|xt1),q(xt|xt1):=N(xt;1βtxt1,βtI).

其中βt是可训练的参数或者人为给定的超参数.

这部分为将图片(信号)逐步添加噪声的过程.

变分界

对于参数θ, 很自然地我们希望通过最小化其负对数似然来优化:

Epdata(x0)[logpθ(x0)]=Epdata(x0)[logpθ(x0:T)dx0:T]=Epdata(x0)[logq(x1:T|x0)pθ(x0:T)q(x1:T|x0)dx0:T]=Epdata(x0)[logEq(x1:T|x0)pθ(x0:T)q(x1:T|x0)]Epdata(x0)Eq(x1:T|x0)[logpθ(x0:T)q(x1:T|x0)]=Eq[logpθ(x0:T)q(x1:T|x0)]=Eq[logp(xT)+t=1Tlogpθ(xt1|xt)q(xt|xt1)]=Eq[logp(xT)+t=2Tlogpθ(xt1|xt)q(xt|xt1)+logpθ(x0|x1)q(x1|x0)]=Eq[logp(xT)+t=2Tlogpθ(xt1|xt)q(xt1|xt,x0)q(xt1|x0)q(xt|x0)+logpθ(x0|x1)q(x1|x0)]=Eq[logp(xT)q(xT|x0)+t=2Tlogpθ(xt1|xt)q(xt1|xt,x0)+logpθ(x0|x1)]

注: q=q(x1:T|x0)pdata(x0), 下面另q(x0):=pdata(x0).

Eq[logq(xT|x0)p(xT)]=q(x0,xT)logq(xT|x0)p(xT)dx0dxT=q(x0)q(xT|x0)logq(xT|x0)p(xT)dx0dxT=q(x0)DKL(q(xT|x0)p(xT))dx0=q(x0:T)DKL(q(xT|x0)p(xT))dx0:T=Eq[DKL(q(xT|x0)p(xT))].

Eq[logq(xt1|xt,x0)pθ(xt1|xt)]=q(x0,xt1,xt)logq(xt1|xt,x0)pθ(xt1|xt)dx0dxt1dxt=q(x0,xt)DKL(q(xt1|xt,x0)pθ(xt1|xt))dx0dxt=Eq[DKL(q(xt1|xt,x0)pθ(xt1|xt))].

故最后:

L:=Eq[DKL(q(xT|x0)p(xT))LT+t=2TDKL(q(xt1|xt,x0)pθ(xt1|xt))Lt1logpθ(x0|x1)L0.]

损失求解

因为无论forward, 还是 reverse process都是基于高斯分布的, 我们可以显示求解上面的各项损失:

首先, 对于forward process中的xt:

xt=1βtxt1+βtϵ,ϵN(0,I)=1βt(1βt1xt2+βt1ϵ)+βϵ=1βt1βt1xt2+1βtβt1ϵ+βϵ=1βt1βt1xt2+1(1βt)(1βt1)ϵ==(s=1t1βs)x0+1s=1t(1βs)ϵ,

q(xt|x0)=N(xt|α¯tx0,(1α¯t)I),α¯t:=s=1tαs,αs:=1βs.

对于后验分布q(xt1|xt,x0), 我们有

q(xt1|xt,x0)=q(xt|xt1)q(xt1|x0)q(xt|x0)q(xt|xt1)q(xt1|x0)exp{12(1α¯t1)βt[(1α¯t1)xt1βtxt12+βtxt1α¯t1x02]}exp{12(1α¯t1)βt[(1α¯t)xt122(1α¯t1)αtxtTxt12α¯t1βt]}

所以

q(xt1|xt,x0)N(xt1|u~t(xt,x0),β~tI),

其中

u~t(xt,x0):=α¯t1βt1α¯tx0+αt(1α¯t1)1α¯txt,

β~t=1α¯t11α¯tβt.

Lt

LTθ无关, 舍去.

作者假设Σθ(xt,t)=σt2I非训练的参数, 其中

σt2=βt|β~t,

分别为x0N(0,I)x0为固定值时, 期望下KL散度的最优参数(作者说在实验中二者差不多).

Lt=12σt2μθ(xt,t)u~t(xt,x0)2+C,t=1,2,,T1.

xt=α¯tx0+1α¯tϵx0=1α¯txt1α¯tα¯tϵ.

所以

Eq[Lt1C]=Ex0,ϵ{12σt2μθ(xt,t)u~t(xt,(1α¯txt1α¯tα¯tϵ))2}=Ex0,ϵ{12σt2μθ(xt,t)1αt(xtβt1α¯tϵ)}

注: 上式子中xtx0,ϵ决定, 实际上xt=xt(x0,ϵ), 故期望实际上是对xt求期望.

既然如此, 我们不妨直接参数化μθ

μθ(xt,t):=1αt(xtβt1α¯tϵθ(xt,t)),

即直接建模残差ϵ.

此时损失可简化为:

Ex0,ϵ{βt22σt2αt(1α¯t)ϵθ(α¯tx0+1α¯tϵ,t)ϵ2}

这个实际上时denoising score matching.

类似地, 从pθ(xt1|xt)采样则为:

xt1=1αt(xtβt1α¯tϵθ(xt,t))+σtz,zN(0,I),

这是Langevin dynamic的形式(步长和权重有点变化)

注: 这部分见here.

L0

最后我们要处理L0, 这里作者假设x0|x1满足一个离散分布, 首先图片取值于{0,1,2,,255}, 并标准化至[1,1]. 假设

pθ(x0|x1)=i=1Dδ(x0i)δ+(x0i)N(x;μθi(x1,1),σ12)dx,δ+(x)={+if x=1,x+1255if x<1.δ(x){if x=1,x1255if x>1.

实际上就是将普通的正态分布划分为:

(,1+1/255],(1+1/255,1+3/255],,(13/255,11/255],(11/255,+)

各取值落在其中之一.
在实际代码编写中, 会遇到高斯函数密度函数估计的问题(直接求是做不到的), 作者选择用下列的估计方式:

Φ(x)12{1+tanh(2/π(1+0.044715x2))}.

这样梯度也就能够回传了.

注: 该估计属于Page.

最后的算法

注: t=1对应L0, t=2,,T对应L1,,LT1.
注: 对于Lt作者省略了开始的系数, 这反而是一种加权.
作者在实际中是采样损失用以训练的.

细节

注意到, 作者的ϵθ(,t)是有显示强调t, 作者在实验中是通过attention中的位置编码实现的, 假设位置编码为P:

  1. t=Linear(ACT(Linear(tP))), 即通过两层的MLP来转换得到time_steps;
  2. 作者用的是U-Net结构, 在每个residual 模块中:

x+=Linear(ACT(t)).

参数
T 1000
βt [0.0001,0.02], 线性增长1,2,,T.
backbone U-Net

注: 作者在实现中还用到了EMA等技巧.

代码

原文代码

lucidrains-denoising-diffusion-pytorch

posted @   馒头and花卷  阅读(3938)  评论(0编辑  收藏  举报
编辑推荐:
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
点击右上角即可分享
微信分享提示