Part3: Dive into DDPM

背景

整个系列有相对完整的公式推导,若正文中有涉及到的省略部分,皆额外整理在Part4,并会在正文中会指明具体位置。

Part2基于Variational Inference,找到原目标函数lnpθ(x0)的上界L,定义如下:

(1)L:=Eq[logp(xT)q(xTx0)t>1logpθ(xt1xt)q(xt1xt,x0)logpθ(x0x1)]=Eq[DKL(q(xTx0)p(xT))LT+t>1DKL(q(xt1xt,x0)pθ(xt1xt))Lt1logpθ(x0x1)L0]

沿着论文的思路对L继续精简,得到最终在代码层面实现的损失函数Lsimple。同样的,补充的推导见Part4;“扩散过程”的梗概介绍见Part1

简化过程

不难看出L中的每一项皆为KL散度。回顾forward processreverse process两个阶段的定义,马尔可夫链的状态转移皆服从高斯分布,如下所示:

(2)q(xtxt1):=N(xt;1βtxt1,βtI)pθ(xt1xt):=N(xt1;μθ(xt,t),Σθ(xt,t))

同时,经过推导(见Part4推导二),易知:

(3)q(xt1xt,x0)=N(xt1;μ~t(xt,x0),β~tI) where μ~t(xt,x0):=α¯t1βt1α¯tx0+αt(1α¯t1)1α¯txt,  β~t:=1α¯t11α¯tβt

KL散度比较皆发生在两个Gaussian间。

Lt的简化

可以看到,(1)式中的LT代表前向扩散过程,与待求解的参数项θ无关,因此可被忽略:

argminθ(L)argminθ(Eq[t>1DKL(q(xt1xt,x0)pθ(xt1xt))Lt1logpθ(x0x1)L0])

注:在2015年提出diffusion框架的论文,前向扩散过程中的βt是可以被学习的参数,故此处可视作DDPM第一处简化

Lt1的简化

对于反向扩散过程的分布pθ(xt1xt),共涉及到两组参数μθΣθDDPM第二处简化是定义Σ为常数σt2,在计算中使用βtβ~t代替,故pθ(xt1xt)=N(xt1;μθ(xt,t),σt2I)

基于Part4两个高斯的KL散度,对于Lt1,有:

(4)Lt1=Eq[12σt2μ~t(xt,x0)μθ(xt,t)2]+C

其中C是个常数项。

仔细观察(4)不难发现,想要目标函数最小化,则μ~t(xt,x0)μθ(xt,t)间的“距离必须要近”。也就是说,深度网络通过训练,使得μθ(xt,t)趋近于μ~t(xt,x0)为了使训练更加简单,尝试对(4)式改写。

Foward Process中,xt可由x0ϵ表示(见Part4推导一),不妨将xt记作xt(x0,ϵ),故x0可以展开表示为xt(x0,ϵ)ϵ的差:

(5)xt=α¯tx0+1α¯tϵ  x0=1α¯t(xt1α¯tϵ)

又因为(3)式,故有:

(6-1)μ~t(xt,x0)=α¯t1βt1α¯tx0+αt(1α¯t1)1α¯txt=α¯t1βt1α¯txt1α¯tϵα¯t+αt(1α¯t1)1α¯txt=1αt(xtβt1α¯tϵ)

前文提到,要优化(4)式,则必然有:μθ(xt,t)μ~t(xt,x0)。其中,μθ(xt,t)是深度网络的输出(预测)结果,xtt作为模型的输入参数。

(6-1)可知μ~t(xt,x0)能展开为xtϵ的表达,xt已知,那不妨令原本要预测μθ(xt,t)的深度网络直接预测ϵθ(xt,t),变换前后依然等价。即

(6-2)μθ(xt,t)1αt(xtβt1α¯tϵθ(xt,t))

此处以θθ对变换前后的深度网络参数进行区分,故1αt(xtβt1α¯tϵθ(xt,t))需要无限趋近于μ~t(xt,x0)

(6-1)(6-2)代入(4)式,有:

(7)Lt1C=Eq[12σt2μ~t(xt,x0)μθ(xt,t)2]Ex0,ϵ[12σt21αt(xt(x0,ϵ)βt1α¯tϵ)1αt(xt(x0,ϵ)βt1α¯tϵθ(xt,t))2]=Ex0,ϵ[βt22σt2αt(1α¯t)ϵϵθ(α¯tx0+1α¯tϵ,t)2]

对比(4)(7),不难发现参数θ作用的对象发生变化。在(4)中,θ的参数化对象为高斯分布的均值μ;而在(7)中,θ的参数化对象转移到ϵ实际上,不仅可以参数化μϵ,也可以参数化x0,只需要对(5)中表示的主体进行变换即可。

并且,重新审视(6-2),该式与Part1中的采样算法联系上了。上述目标函数的设定及推理,皆是为了获取反向过程的分布pθ(xt1xt):=N(xt1;μθ(xt,t),Σθ(xt,t))

通过公式(6-2),按照反向过程相邻状态间的图像转换服从高斯分布的定义,反向过程中知晓xtt后,通过深度网络预测出ϵ,再基于此求出μt,结合自定义的σt,可采样得到xt1,便实现反向过程的一次“降噪”。

L0的简化

这一项对应着信息由隐变量转变回x0,故而需要特殊考虑。

真实图片中各个像素由0到255的数值组成,在处理时通常将所有像素值归一化到区间[-1,1]。论文中将该项对应的优化目标定义为:

(8)pθ(x0x1)=i=1Dδ(x0i)δ+(x0i)N(x;μθi(x1,1),σ12)dxδ+(x)={ if x=1x+1255 if x<1δ(x)={ if x=1x1255 if x>1

其中,积分项是为了与图片真实像素的离散特性保持一致,D为像素点的个数。

该项的优化目标是:对于输入图片x0的所有像素位置,使得基于神经网络产生的高斯分布在该位置的采样结果,与x0对应位置的真实值相差不大。

直接文字阐述并不好理解,下方是对于单个位置的具体实例,截图来自视频
image

当前有一张真实的图片x0,对应上图内靠左边的图片,经过缩放后,在位置i的值为x0i=10255

并且,中间图片表示在x1i(此时还处于有噪声状态),经过神经网络模型,预测出该位置的值是服从均值为11255的高斯分布N1

在左下角画出该N1的概率密度曲线,此时积分的上下界为(9255,11255),从图上可以直观地看出积分对应的阴影面积相对来说比较大。故基于此采样得到的x0^i与输入图片x0i接近的置信度很高。在训练时反映出来的是,神经网络在该位置的预测表现对(8)式即Loss的贡献程度较低;

但如果神经网络预测出该位置的值服从服从均值为105255的高斯分布N2,此时概率密度曲线整体会往右平移,(9255,11255)区域属于长尾位置,显然积分结果比较小,从侧面来说,基于此采样得到的x0^i与输入图片x0i接近的置信度很低,在训练时对Loss的贡献程度高,在反向传播时的梯度也大。

实际代码实现中,该项被省略,这是第三处简化

简化的损失函数

回顾(1)式,目前只剩下以Lt1为主体的求和部分,如下所示:

(9)argminθ(L)argminθEq[t>1DKL(q(xt1xt,x0)pθ(xt1xt))Lt1]argminθEx0,ϵ[βt22σt2αt(1α¯t)ϵϵθ(α¯tx0+1α¯tϵ,t)2]

对于(9)式,DDPM第四处简化在于省略了均方差损失项的权重,最终的损失函数Lsimple为:

Lsimple (θ):=Et,x0,ϵ[ϵϵθ(α¯tx0+1α¯tϵ,t)2]

总结

回顾本文,DDPM在损失函数上做了很多简化,对于代码侧的实现非常友好。同时,论文作者也给出实验对比,验证简化并不会使得结果变差,有些简化(比如设置reverse过程中的Σ为非参数项)甚至取得大幅度的提升效果。

Reference

posted @   小王点点  阅读(79)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
历史上的今天:
2020-02-25 如何利用dokcer提交我的比赛代码
点击右上角即可分享
微信分享提示