DDMP中的损失函数

接着扩散模型 简述训练扩散模型过程中用到的损失函数形式。完整的观察数据x的对数似然如下:

(1)log p(x)Eqϕ(z1:T|z0)logp(zT)t=0T1pθ(zt|zt+1)t=0T1qϕ(zt+1|zt)=Eqϕ(z1|z0)[log pθ(z0|z1)]DKL(qϕ(zT|z0)||p(zT))t=2TEqϕ(zt|z0)[DKL(qϕ(zt1|zt,z0)||pθ(zt1|zt))]

其中,qϕ(zt1|zt,z0)为了便于计算,已经近似为高斯分布

(2)N(μq(zt,z0),Σq(t))

(3)μq(zt,z0)=αt(1α¯t12)zt+α¯t1(1αt2)z01α¯t2

(4)Σq(t)=(1αt2)(1α¯t12)1α¯t2I

形式一

为了使得去噪过程pθ(zt1|zt)和“真实”的qϕ(zt1|zt,z0)尽可能接近,因此也可以将pθ(zt1|zt)建模为一个高斯分布。又由于所有的α项在每个时间步都是固定的,因此可以将其方差设计与“真实”的q(zt1|zt,z0)的方差是一样的。且这个高斯分布与初始值z0是无关的,因此可以将其均值设计为关于zt,t的函数,即设为μθ(zt,t).

  考虑两个高斯分布的KL散度等于

(5)    DKL(N(x;μx,Σx)||N(y;μy,Σy))=12[log|Σy||Σx|d+tr(Σy1Σx)+(μyμx)TΣy1(μyμx)]

应用到公式(1)中的第三项,因此有

(6)    DKL(N(zt1;μq(zt,z0),Σq(t))||N(zt1;μθ(zt,t),Σq(t)))=12σq2(t)||μθ(xt,t)μq(xt,x0)||2

其中σq2(t)是公式(4)前的系数即σq2(t)=(1αt2)(1α¯t12)1α¯t2

由于μθ(xt,t)也是xt的函数,因此,可以参考公式(3)的形式,将进一步假设

(7)μθ(xt,t)=αt(1α¯t12)zt+α¯t1(1αt2)zθ(zt,t)1α¯t2

这样公式(6)进一步化简为

(8)    DKL(N(zt1;μq(zt,z0),Σq(t))||N(zt1;μθ(zt,t),Σq(t)))=12σq2(t)α¯t12(1αt2)2(1α¯t2)2||zθ(zt,t)z0||2

至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出其原来的样本。最终最小化公式(1)中的第三项,等价于最小化关于时间步的期望,因此有

argminEtU{2,T}[Eqϕ(zt|z0)[DKL(qϕ(zt1|zt,z0)||pθ(zt1|zt))]]

形式二

(9)zt=α¯tz0+1α¯t2ϵ¯t

可得

(10)z0=zt(1α¯t2)ϵ¯tα¯t

再代入公式(3)得

(11)μq(xt,x0)=1αtxt1αt21α¯t2αtϵ¯t

参考形式一中的假设方式,可以假设

(12)μθ(xt,t)=1αtxt1αt21α¯t2αtϵθ(zt,t)

再代入公式(6)可以得到

(12)    DKL(N(zt1;μq(zt,z0),Σq(t))||N(zt1;μθ(zt,t),Σq(t)))=12σq2(t)(1αt2)2(1α¯t2)αt2||ϵθ(zt,t)ϵt||2

至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出按照公式(10)添加的原始噪音。

形式三

由公式(8)和公式(12)可以得到

(13)||ϵθ(zt,t)ϵt||2=αt¯21αt¯2||zθ(zt,t)z0||2

由于αt¯,1αt¯2 分别是t时间步的加噪信号公式(9)中的原始信号和噪音信号系数,因此将信噪比SNR(t)定义为系数平方之比,即

(14)SNR(t)=αt¯21αt¯2

这个信噪比在时间步初期其值较大,代表真实信号占比多噪音占比少;在时间步后期其值较小,代表真实信号占比少噪音占比多。因为推理过程是完全从高斯分布随机取样,为了保证推理与训练保持一致,训练过程采取特定的α¯t使得T步得到的是完全噪音,不包含任何原始信号。此时信噪比是0.

当预测发送在信噪比接近0(α¯t0)时,模型原始预测是噪音ϵ¯,因此根据公式(10)预估对应的原始信号

z¯0=zt(1α¯t2)ϵ¯α¯t

这样网络预测的微小差异就会被放大很多倍,因此在论文[3]模型蒸馏过程,这就不是一个稳定的设计。为了避免这个问题,作者提出了3种解决办法。

  • 直接预测z,而非噪音ϵ
  • 同时预测z,ϵ,通过两个独立的输出通道z,ϵ。由于根据公式(10)可以再由ϵ再推断出z,然后可以根据α¯t2,1α¯t2对这两个值进行差值。
  • 预测混合体 v=αtϵ1αt2z

参考

[1]. https://www.cnblogs.com/wolfling/p/17938102
[2]. Understanding Diffusion Models: A Unified Perspective
[3]. Progressive Distillation for Fast Sampling of Diffusion Models

posted @   星辰大海,绿色星球  阅读(158)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示