接着扩散模型 简述训练扩散模型过程中用到的损失函数形式。完整的观察数据x x 的对数似然如下:
l o g p ( x ) ≥ E q ϕ ( z 1 : T | z 0 ) l o g p ( z T ) ∏ T − 1 t = 0 p θ ( z t | z t + 1 ) ∏ T − 1 t = 0 q ϕ ( z t + 1 | z t ) = E q ϕ ( z 1 | z 0 ) [ l o g p θ ( z 0 | z 1 ) ] − D K L ( q ϕ ( z T | z 0 ) | | p ( z T ) ) − T ∑ t = 2 E q ϕ ( z t | z 0 ) [ D K L ( q ϕ ( z t − 1 | z t , z 0 ) | | p θ ( z t − 1 | z t ) ) ] (1) (1) l o g p ( x ) ≥ E q ϕ ( z 1 : T | z 0 ) l o g p ( z T ) ∏ t = 0 T − 1 p θ ( z t | z t + 1 ) ∏ t = 0 T − 1 q ϕ ( z t + 1 | z t ) = E q ϕ ( z 1 | z 0 ) [ l o g p θ ( z 0 | z 1 ) ] − D K L ( q ϕ ( z T | z 0 ) | | p ( z T ) ) − ∑ t = 2 T E q ϕ ( z t | z 0 ) [ D K L ( q ϕ ( z t − 1 | z t , z 0 ) | | p θ ( z t − 1 | z t ) ) ]
其中,q ϕ ( z t − 1 | z t , z 0 ) q ϕ ( z t − 1 | z t , z 0 ) 为了便于计算,已经近似为高斯分布
N ( μ q ( z t , z 0 ) , Σ q ( t ) ) (2) (2) N ( μ q ( z t , z 0 ) , Σ q ( t ) )
μ q ( z t , z 0 ) = α t ( 1 − ¯ α 2 t − 1 ) z t + ¯ α t − 1 ( 1 − α 2 t ) z 0 1 − ¯ α 2 t (3) (3) μ q ( z t , z 0 ) = α t ( 1 − α ¯ t − 1 2 ) z t + α ¯ t − 1 ( 1 − α t 2 ) z 0 1 − α ¯ t 2
Σ q ( t ) = ( 1 − α 2 t ) ( 1 − ¯ α 2 t − 1 ) 1 − ¯ α 2 t I (4) (4) Σ q ( t ) = ( 1 − α t 2 ) ( 1 − α ¯ t − 1 2 ) 1 − α ¯ t 2 I
形式一
为了使得去噪过程p θ ( z t − 1 | z t ) p θ ( z t − 1 | z t ) 和“真实”的q ϕ ( z t − 1 | z t , z 0 ) q ϕ ( z t − 1 | z t , z 0 ) 尽可能接近,因此也可以将p θ ( z t − 1 | z t ) p θ ( z t − 1 | z t ) 建模为一个高斯分布。又由于所有的α α 项在每个时间步都是固定的,因此可以将其方差设计与“真实”的q ( z t − 1 | z t , z 0 ) q ( z t − 1 | z t , z 0 ) 的方差是一样的。且这个高斯分布与初始值z 0 z 0 是无关的,因此可以将其均值设计为关于z t , t z t , t 的函数,即设为μ θ ( z t , t ) μ θ ( z t , t ) .
考虑两个高斯分布的KL散度等于
D K L ( N ( x ; μ x , Σ x ) | | N ( y ; μ y , Σ y ) ) = 1 2 [ l o g | Σ y | | Σ x | − d + t r ( Σ − 1 y Σ x ) + ( μ y − μ x ) T Σ − 1 y ( μ y − μ x ) ] (5) (5) D K L ( N ( x ; μ x , Σ x ) | | N ( y ; μ y , Σ y ) ) = 1 2 [ l o g | Σ y | | Σ x | − d + t r ( Σ y − 1 Σ x ) + ( μ y − μ x ) T Σ y − 1 ( μ y − μ x ) ]
应用到公式(1)中的第三项,因此有
D K L ( N ( z t − 1 ; μ q ( z t , z 0 ) , Σ q ( t ) ) | | N ( z t − 1 ; μ θ ( z t , t ) , Σ q ( t ) ) ) = 1 2 σ 2 q ( t ) | | μ θ ( x t , t ) − μ q ( x t , x 0 ) | | 2 (6) (6) D K L ( N ( z t − 1 ; μ q ( z t , z 0 ) , Σ q ( t ) ) | | N ( z t − 1 ; μ θ ( z t , t ) , Σ q ( t ) ) ) = 1 2 σ q 2 ( t ) | | μ θ ( x t , t ) − μ q ( x t , x 0 ) | | 2
其中σ 2 q ( t ) σ q 2 ( t ) 是公式(4)前的系数即σ 2 q ( t ) = ( 1 − α 2 t ) ( 1 − ¯ α 2 t − 1 ) 1 − ¯ α 2 t σ q 2 ( t ) = ( 1 − α t 2 ) ( 1 − α ¯ t − 1 2 ) 1 − α ¯ t 2
由于μ θ ( x t , t ) μ θ ( x t , t ) 也是x t x t 的函数,因此,可以参考公式(3)的形式,将进一步假设
μ θ ( x t , t ) = α t ( 1 − ¯ α 2 t − 1 ) z t + ¯ α t − 1 ( 1 − α 2 t ) z θ ( z t , t ) 1 − ¯ α 2 t (7) (7) μ θ ( x t , t ) = α t ( 1 − α ¯ t − 1 2 ) z t + α ¯ t − 1 ( 1 − α t 2 ) z θ ( z t , t ) 1 − α ¯ t 2
这样公式(6)进一步化简为
D K L ( N ( z t − 1 ; μ q ( z t , z 0 ) , Σ q ( t ) ) | | N ( z t − 1 ; μ θ ( z t , t ) , Σ q ( t ) ) ) = 1 2 σ 2 q ( t ) ¯ α 2 t − 1 ( 1 − α 2 t ) 2 ( 1 − ¯ α 2 t ) 2 | | z θ ( z t , t ) − z 0 | | 2 (8) (8) D K L ( N ( z t − 1 ; μ q ( z t , z 0 ) , Σ q ( t ) ) | | N ( z t − 1 ; μ θ ( z t , t ) , Σ q ( t ) ) ) = 1 2 σ q 2 ( t ) α ¯ t − 1 2 ( 1 − α t 2 ) 2 ( 1 − α ¯ t 2 ) 2 | | z θ ( z t , t ) − z 0 | | 2
至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出其原来的样本。最终最小化公式(1)中的第三项,等价于最小化关于时间步的期望,因此有
a r g m i n E t ∼ U { 2 , T } [ E q ϕ ( z t | z 0 ) [ D K L ( q ϕ ( z t − 1 | z t , z 0 ) | | p θ ( z t − 1 | z t ) ) ] ] a r g m i n E t ∼ U { 2 , T } [ E q ϕ ( z t | z 0 ) [ D K L ( q ϕ ( z t − 1 | z t , z 0 ) | | p θ ( z t − 1 | z t ) ) ] ]
形式二
由
z t = ¯ α t z 0 + √ 1 − ¯ α 2 t ¯ ϵ t (9) (9) z t = α ¯ t z 0 + 1 − α ¯ t 2 ϵ ¯ t
可得
z 0 = z t − √ ( 1 − ¯ α 2 t ) ¯ ϵ t ¯ α t (10) (10) z 0 = z t − ( 1 − α ¯ t 2 ) ϵ ¯ t α ¯ t
再代入公式(3)得
μ q ( x t , x 0 ) = 1 α t x t − 1 − α 2 t √ 1 − ¯ α 2 t α t ¯ ϵ t (11) (11) μ q ( x t , x 0 ) = 1 α t x t − 1 − α t 2 1 − α ¯ t 2 α t ϵ ¯ t
参考形式一中的假设方式,可以假设
μ θ ( x t , t ) = 1 α t x t − 1 − α 2 t √ 1 − ¯ α 2 t α t ϵ θ ( z t , t ) (12) (12) μ θ ( x t , t ) = 1 α t x t − 1 − α t 2 1 − α ¯ t 2 α t ϵ θ ( z t , t )
再代入公式(6)可以得到
D K L ( N ( z t − 1 ; μ q ( z t , z 0 ) , Σ q ( t ) ) | | N ( z t − 1 ; μ θ ( z t , t ) , Σ q ( t ) ) ) = 1 2 σ 2 q ( t ) ( 1 − α 2 t ) 2 ( 1 − ¯ α 2 t ) α 2 t | | ϵ θ ( z t , t ) − ϵ t | | 2 (12) (12) D K L ( N ( z t − 1 ; μ q ( z t , z 0 ) , Σ q ( t ) ) | | N ( z t − 1 ; μ θ ( z t , t ) , Σ q ( t ) ) ) = 1 2 σ q 2 ( t ) ( 1 − α t 2 ) 2 ( 1 − α ¯ t 2 ) α t 2 | | ϵ θ ( z t , t ) − ϵ t | | 2
至此,优化VDM就变成了学习一个神经网络,从样本任意时刻的加噪版本预测出按照公式(10)添加的原始噪音。
形式三
由公式(8)和公式(12)可以得到
| | ϵ θ ( z t , t ) − ϵ t | | 2 = ¯ α t 2 1 − ¯ α t 2 | | z θ ( z t , t ) − z 0 | | 2 (13) (13) | | ϵ θ ( z t , t ) − ϵ t | | 2 = α t ¯ 2 1 − α t ¯ 2 | | z θ ( z t , t ) − z 0 | | 2
由于¯ α t , √ 1 − ¯ α t 2 α t ¯ , 1 − α t ¯ 2 分别是t t 时间步的加噪信号公式(9)中的原始信号和噪音信号系数,因此将信噪比SNR(t)定义为系数平方之比,即
S N R ( t ) = ¯ α t 2 1 − ¯ α t 2 (14) (14) S N R ( t ) = α t ¯ 2 1 − α t ¯ 2
这个信噪比在时间步初期其值较大,代表真实信号占比多噪音占比少;在时间步后期其值较小,代表真实信号占比少噪音占比多。因为推理过程是完全从高斯分布随机取样,为了保证推理与训练保持一致,训练过程采取特定的¯ α t α ¯ t 使得T步得到的是完全噪音,不包含任何原始信号。此时信噪比是0.
当预测发送在信噪比接近0(¯ α t → 0 α ¯ t → 0 )时,模型原始预测是噪音¯ ϵ ϵ ¯ ,因此根据公式(10)预估对应的原始信号
¯ z 0 = z t − √ ( 1 − ¯ α 2 t ) ¯ ϵ ¯ α t z ¯ 0 = z t − ( 1 − α ¯ t 2 ) ϵ ¯ α ¯ t
这样网络预测的微小差异就会被放大很多倍,因此在论文[3]模型蒸馏过程,这就不是一个稳定的设计。为了避免这个问题,作者提出了3种解决办法。
直接预测z z ,而非噪音ϵ ϵ
同时预测z , ϵ z , ϵ ,通过两个独立的输出通道z , ϵ z , ϵ 。由于根据公式(10)可以再由ϵ ϵ 再推断出z ′ z ′ ,然后可以根据¯ α 2 t , 1 − ¯ α 2 t α ¯ t 2 , 1 − α ¯ t 2 对这两个值进行差值。
预测混合体 v = α t ϵ − √ 1 − α 2 t z v = α t ϵ − 1 − α t 2 z
参考
[1]. https://www.cnblogs.com/wolfling/p/17938102
[2]. Understanding Diffusion Models: A Unified Perspective
[3]. Progressive Distillation for Fast Sampling of Diffusion Models
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律