三、为什么扩散模型使用均方误差损失(选看)

高能预警:这篇文章难度很大,包含很多的数学推导,如果不想接触太多的数学内容,那么可以跳过不看。

看这篇文章之前,你需要了解:什么是马尔科夫链,什么是极大似然估计,什么是KL散度,两个正态分布的KL散度,什么是贝叶斯公式

以下内容参考了主要参考了博客What are Diffusion Models? 以及李宏毅老师的课程

1. 马尔科夫链与pθ(x)

本节推导得出的结论:

  • q(x1:Tx0)=t=1Tq(xtxt1)p(x0:T)=p(xT)t=1Tpθ(xt1|xt)
  • pθ(x0:T)=p(xT)t=1Tpθ(xt1|xt)

在扩散模型中,为了方便计算,我们假设前向过程中的图片x0,x1,xT构成一个马尔科夫链,并将前向过程中图片x的概率分布记作q(x)

因此,我们有

q(x1:Tx0)=t=1Tq(xtxt1)

同时,我们令pθ(x)表示:在反向过程中,模型生成图片x的概率。

因此,在对扩散模型使用极大似然估计时,样本是没有噪音的图片x0,似然函数pθ(x0)表示模型最终生成x0的概率。自然的,极大似然估计的目标是找到使得pθ(x0)最大的模型。

注意到在反向过程中,xT是噪音图片,直接采样自标准正态分布,并不需要通过模型生成,pθ(xT)和模型选取无关,因此可以记作p(xT)

由于x0,x1,xT构成一个马尔科夫链,因此

pθ(x0:T)=p(xT)t=1Tpθ(xt1|xt)

2. 极大似然估计

本节推导得出的结论:minlogpθ(x0)等价于minLT+LT1++L0,其中

LT=DKL(q(xTx0)pθ(xT))Lt1=DKL(q(xt1xt,x0)pθ(xt1xt)) for 2tTL0=logpθ(x0x1)

上文中,我们说到,极大似然估计的目标是maxpθ(x0),为了方便起见,可以将目标转换为minlogpθ(x0)

我们对logpθ(x0)进行一些变形,得到

logpθ(x0)logpθ(x0)+DKL(q(x1:Tx0)pθ(x1:Tx0))=logpθ(x0)+Eq(x1:Tx0)[logq(x1:Tx0)pθ(x0:T)+logpθ(x0)]=Eq(x1:Tx0)[logq(x1:Tx0)pθ(x0:T)]

(1)logpθ(x0)Eq(x1:Tx0)[logq(x1:Tx0)pθ(x0:T)]

其中,DKL(q||pθ)表示分布q和分布pθ的KL散度;期望Eq(x1:Tx0)(f)=q(x1:Tx0)×f dx1:T


下面,我们对公式(1)左右两侧同时取期望

logp(x0)q(x0)dx0Eq(x1:Tx0)[logq(x1:Tx0)pθ(x0:T)]q(x0)dx0Eq(x0)logpθ(x0)[logq(x1:Tx0)pθ(x0:T)]q(x1:Tx0)q(x0)dx1:Tdx0=[logq(x1:Tx0)pθ(x0:T)]q(x0:T)dx0:T=Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]

(2)Eq(x0)logpθ(x0)Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]

为了方便表示,我们将Eq(x0)logpθ(x0)记作LCE,将Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]记作LVLB

minlogpθ(x0)等价于minLCE。而只要minLVLB,就会minLCE

因此,minlogp(x0)的问题就转换为了minLVLB的问题。


下面对LVLB进行变形

LVLB=Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]=Eq[logt=1Tq(xtxt1)p(xT)t=1Tpθ(xt1xt)]=Eq[logpθ(xT)+t=1Tlogq(xtxt1)pθ(xt1xt)]=Eq[logp(xT)+t=2Tlogq(xtxt1)pθ(xt1xt)+logq(x1x0)pθ(x0x1)]=Eq[logp(xT)+t=2Tlog(q(xt1xt,x0)pθ(xt1xt)q(xtx0)q(xt1x0))+logq(x1x0)pθ(x0x1)]=Eq[logp(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)+t=2Tlogq(xtx0)q(xt1x0)+logq(x1x0)pθ(x0x1)]=Eq[logp(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)+logq(xTx0)q(x1x0)+logq(x1x0)pθ(x0x1)]=Eq[logq(xTx0)p(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)logpθ(x0x1)]=Eq[DKL(q(xTx0)p(xT))LT+t=2TDKL(q(xt1xt,x0)pθ(xt1xt))Lt1logpθ(x0x1)L0]

对于一个期望来说,如果我们使它的每一项都最小化,那么期望的值也会最小化,因此有

(3)minLVLBminLT+LT1++L0

其中

LT=DKL(q(xTx0)p(xT))Lt1=DKL(q(xt1xt,x0)pθ(xt1xt)) for 2tTL0=logpθ(x0x1)


注意到,对于LT而言,其中q(xTx0)p(xT)的取值均与参数θ无关,因此LT可以看成常数,我们只需要最小化LtL0即可。

又因为minDKL(q(x0|x1,x0)||pθ(x0|x1))minDKL(1||pθ(x0|x1))minlogpθ(x0x1),因此,可以将L0转换为Lt1的形式,那么只需要最小化Lt1即可。

3. Lt1中的q(xt1xt,x0)

本节推导得出的结论:q(xt1xt,x0)=N(xt1;μ~t,β~tI),其中μ~t=1αt(xt1αt1α¯tϵ)β~t=1α¯t11α¯tβt

使用贝叶斯公式,我们可以将q(xt1xt,x0)转换为

q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)

又因为x0,x1,xT构成一个马尔科夫链,因此q(xtxt1,x0)=q(xtxt1)

我们在上篇文章的末尾提到过,在前向过程中,概率q(xtxt1)=N(xt;1βtxt1,βtI)q(xtx0)=N(xt;α¯tx0,(1α¯t)I)

于是有

q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)exp(12((xtαtxt1)2βt+(xt1α¯t1x0)21α¯t1(xtαtx0)21α¯t))=exp(12(xt22αtxtxt1+αtxt12βt+xt122α¯t1x0xt1+α¯t1x021α¯t1(xtα¯tx0)21α¯t))=exp(12((αtβt+11α¯t1)xt12(2αtβtxt+2α¯t11α¯t1x0)xt1+C(xt,x0)))

其中 C(xt,x0) 是不含 xt1 的常数,因此可以被忽略。

根据正态分布的概率公式,我们可以得到

注意 αt=1βtα¯t=i=1tαi

β~t=1/(αtβt+11α¯t1)=1/(αtα¯t+βtβt(1α¯t1))=1α¯t11α¯tβt

μ~t=(αtβtxt+αt11α¯t1x0)/(αtβt+11α¯t1)=(αtβtxt+α¯t11α¯t1x0)1α¯t11α¯tβt=αt(1α¯t1)1α¯txt+α¯t1βt1α¯tx0

在上一篇文章中,我们得到xt=α¯tx0+1α¯tϵ,因此有

x0=1α¯t(xt1α¯tϵ)

其中ϵ表示从x0xt添加的噪音之和。

我们将μ~t表达式中的x0进行替换,可以得到

μ~t=αt(1α¯t1)1α¯txt+α¯t1βt1α¯t1α¯t(xt1α¯tϵ)=1α(xt1αt1α¯tϵ)

因此,我们有

(4)q(xt1xt,x0)=N(xt1;1αt(xt1αt1α¯tϵ),1α¯t11α¯tβtI)

4. 最小化Lt1

本节推导得出的结论:最小化Lt1等价于最小化ϵϵθ(xt,t)2

其中,ϵ表示从x0xt添加的噪音之和,ϵθ表示预测噪音的模型,模型有两个输入:t时刻的图片xt以及时刻t

在上一小节,我们推出:q(xt1xt,x0)符合正态分布。又由于x0,x1,xT构成一个马尔科夫链,因此q(xt1xt,x0)=q(xt1xt),也就是说q(xt1xt)符合正态分布。

我们的目的是让反向过程尽可能和正向过程一致。因此我们可以合理假设,在反向过程中,pθ(xt1xt)也符合正态分布,并且和q(xt1xt,x0)的分布近似。

pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t)),因为σ~t2为常数,因此我们直接令Σθ(xt,t)=σ~t2。同时,我们还要尽可能的令μθ(xt,t)接近μ~t

注意到,μ~t里面唯一一个,在反向过程中不知道的量就是从x0xt添加的噪音之和 ϵ,因此我们可以训练一个模型来预测ϵ

这个模型就是我们在第一篇文章中提到的Noise Predicter,我们将Noise Predicter记作ϵθ,它有两个输入:t时刻的图片xt以及时刻t。模型的预测值记作ϵθ(xt,t)

因此,

μθ(xt,t)=1αt(xt1αt1α¯tϵθ(xt,t))

对于KL散度,我们有以下性质:

若有两个正态分布PQ,均值分别为μ1μ2;方差分别为σ12σ22,且σ12σ22都为常数,那么

minDKL(P||Q)min||μ1μ2||2

因此,

(5)minLt1min||μ~tμθ(xt,t)||2minϵϵθ(xt,t)2

5. 总结

至此,我们完成了使用极大似然估计来推导损失函数的过程。

我们得到的结论是

minlogpθ(x0)等价于minLT+LT1++L0,其中

LT=DKL(q(xTx0)p(xT))Lt1=DKL(q(xt1xt,x0)pθ(xt1xt)) for 2tTL0=logpθ(x0x1)

其中LT可以看作常数,L0可以转换为Lt1的形式,而最小化Lt1又相当于最小化ϵϵθ(xt,t)2

也就是说,我们的目标是

t=1T[ϵϵθ(xt,t)2]

因此我们知道:扩散模型的损失函数就是均方误差损失。

posted @   Brain404  阅读(564)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 三行代码完成国际化适配,妙~啊~
· .NET Core 中如何实现缓存的预热?
点击右上角即可分享
微信分享提示