【学习笔记】生成对抗网络

举个简单例子来形象描述一下 Generative adversarial network:MNIST 是手写数字数据集,我们可以训练一个检测器来检测给定图片是 UGC/PGC 还是 AIGC。生成对抗网络在训练检测器的同时训练了生成器。

过程若干轮,每轮都一样:用生成器的原始随即参数生成 m 个图,混上 m 个真实的图,贴上标签给检测器拿去训练;训练完检测器之后可以得到损失函数再训练生成器。


生成AIGC图像骗检测器的原始方法是 VAE,Encoder 输出 σi,mi ,另外找一组 ei ,将 σi×ei+mi 传入 Decoder。目标是最小化函数 ieσi(1+σi)+mi2 (防止σi=0,mi 表示原数据的情况),至于为什么是这个函数这里不会证明


上面题到的骗检测器本质上就是 Maximum Likelihood Estimation,可能可以理解成拟合离散分布或者连续曲线。

定义 Likelihood 为 PG(xi,θ) 取对数之后再把观测值的求和变成求 argmaxθExPdata[logPG(x;θ)]

等价转换为下式(观测点的平均近似估计为全局期望)

=argmaxθxPdata(x)logPG(x;θ)δxxPdata(x)logPdata(x)δx

最后一项之和 Pdata 有关,是另外添上去的。但是这样可以变成 KL 散度(相对熵,离散形式为 DKL(P||Q)=xP(x)log(P(x)Q(x)),给我的感受就是两个随机变量分布的“距离”)

=argminKL(Pdata(x)||PG(x))

但是 θ 很难搞,所以用 Neural Network 训练,把原来的 Gaussian Mixture Model(可以理解为若干正态分布的线性组合) 作为模型输入,输出的东西扔到 Decoder 得到 output

从训练角度来讲就是你把期望值用 decoder 逆向计算出来 output 的期望值再计算损失函数,反向传播。

但是 decoder 逆向计算是很复杂的,于是引入一个 Discriminator(可以被理解为类似 checker 的东西)来决断搞出来的分布是不是好使。此时可以形成了有向图结构且处处可微,那么可以反向传播。

至此我们有一个 Discriminator 一个 Generateor ,目的是训练一个 Generator 出来。 形式化的讲,我们要求出 G=argminG(maxDV(G,D)),其中 V(G,D)=ExPdata[logD(x)]+ExPG[log(1D(x))] ,我们将证明 V(G,D) 可以作为损失函数在网络中发挥作用。

G 固定时,将 V(G,D) 写成关于 D 的函数:

V(G,D)=xPdata(x)logD(x)+PG(x)log[1D(x)]argminDV(G,D)=Pdata(x)Pdata(x)+PG(x)

D 加一个 sigmoid 就可以实现 0argminDV(G,D)1

回代并展开得到

maxV(G,D)=xKL(Pdata||(PG(x)+Pdata(x)))+KL(PG||(PG(x)+Pdata(x)))

另外引入 JSD(P||Q)=12(KL(P||P+Q2)+KL(Q||P+Q2)) (不难发现这是一个 P,Q 地位相同的函数),于是上式转化成

=log4+2JSD(Pdata(x)||PG(x))

根据 JSD 函数的性质,如果 PG(x)Pdata(x)=0 则结果是 log2,如果 PG=Pdata 那么结果是 0 其它情况下 0<JSD(Pdata||PG)<log2

那么这时候全局最优解唯一并以常量形式存在且可以被取到(PG=Pdata),所以我们的网络训练可以把最小化 V(G,D) 作为目标。下图还证明了该梯度下降可以收敛到最优解

image

更新网络的方式为 θGi+1θGiηδV(Gi,Di+1)δGi,这时候有个问题就是 Gi+1Gi 的最大值取值点发生了变化,不一定能满足 argmaxDV(Gi+1,D)<argmaxDV(Gi,D) 。不过普通网络梯度下降的时候也会遇到损失函数不减反增的问题,不妨先暂时搁置

注意到上面的计算中我们将 PdataPG 视为了连续函数,在实现中我们还是用观测点来近似全局期望,形式化的就是从 PdataPG 中分别选择 m 个点 (x1xm),(x1,xm)maximize V=1m(i=1mlogD(xi)+log[1D(xi)])

注意这时候我们找到的不是真实的 maxV(G,D) 而是一个 V(G,D) 的下界

总结来说算法流程本质上是下图:

image

实际代码书写的时候说要把 V(G,D) 改写成 1mi=1mlog(xi) ,原因是保证在 xi 比较小的时候(接近于 0 )梯度会更大一些。

另外显然的问题就是 Loss function 可能会趋近于 0 。因为选 m 个点有多少能选到交集里面是很不确定的(或者说概率很小的),所以可以大幅削弱 discriminator 来避免 overfitting;或者说向 discriminatior 的输入中加入额外的人工噪声/给标签加噪声(remember,noise decay over time)。

还有一个问题是ModeCollapse(你不知到模型不能做什么),直观的:

image


据说还有 Conditional GAN/WGAN。别急,让我学了再更

posted @   没学完四大礼包不改名  阅读(86)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示