GAN

生成式对抗网络(GAN)

一、什么是生成式对抗网络GAN?

在知乎上看到一个比较有趣的例子:

女生让男生给自己拍照,可是一直不满意男生拍的照片,就对照“别人家的男朋友”拍的照片,一次次让男生去改,直到女生满意。

在这个例子中,

  • 男生可以被看作是GAN中的生成模型(Generative Model);

  • 女生可以被看作是GAN中的判别模型(Discriminator);

  • 整个拍照的过程可以被看作是博弈式的训练过程

  • 男生(生成模型)的目的:拍出女朋友满意的照片(生成一幅和真实图片极其相似的图片)

  • 女生(判别模型)的目的:分辨男朋友拍的照片,不满意的打回去(判别生成图片与真实图片是否相似,如果不够相似,打回去)

上述博弈过程,如果采用神经网络作为模型类型,则被称为生成式对抗网络(GAN)

正如视频中提到的两个问题:

  • 为什么罪犯制造的假币越来越逼真?

    为什么GAN可以生成数据?

二、GAN的详细介绍

GAN的框架

判别器D(Discriminator):区分真实样本和虚假样本。D是一个神经网络,经过运算后,如果是真实的图片,给出real(1);如果是假的图片,给出fake(0)

随机噪声z:从一个先验分布(人为定义,一般是均匀分布或者正态分布)中随机采样的向量

真实样本x:从数据库中采样的样本

合成样本G(z):生成模型G输出的样本

生成器G(Generator):欺骗判别器。生成虚假数据,使得判别器D能够尽可能给出高的评分。生成器不断改变自己,直到生成的很多图片能够欺骗判别器

GAN目标函数

训练算法:

1.随机初始化生成器和判别器

2.交替训练判别器D和生成器G,直到收敛

  • 步骤一:固定生成器G(不优化),训练判别器D区分真实图像与合成图像(赋予真实图像高分,赋予合成图像低分)(用监督训练二分类问题)
  • 步骤二:固定判别器D,训练生成器G欺骗判别器D(更新生成器的参数,使其合成的图片被生成器D赋予高分)(最大化问题)

训练一个生成模型

一个能够生成我们想要的数据的模型(图模型、函数、神经网络)

GAN通过一个低维向量 生成器(全连接神经网络)

cGAN生成可控的数据 生成器(全连接神经网络)

DCGAN 生成器(卷积神经网络)

WGAN 生成器(WGAN)重新设计目标函数,训练更稳定,生成数据质量更棒

KL散度和JS散度

  • KL散度(Kullback-Leibler divergence)

    一种衡量两个概率分布的匹配程度的指标,又称为KL距离,相对熵

当P(x)和Q(x)的相似度越高,KL散度越小

KL散度主要有两个性质:

(1)不对称性

(2)非负性

KL散度本质是用来衡量两个概率分布的差异一种数学计算方式;由于用到比值除法不具备对称性。

神经网络训练时为何不用KL散度,从数学上来讲,KL散度多减了一个H(P);P代表真实分布,Q代表估计的分布

极大似然估计等价于最小化生成数据分布和真实分布的KL散度

  • JS散度(Jensen-Shannon divergence)

    JS散度也称为JS距离,是KL散度的一种变形

JS散度主要性质:

(1)值域范围(JS散度的值域范围是[0,1],相同是0,相反为1)

(2)对称性

(3)交叉熵

很多情况下,假设数据符合高斯分布是不合理的,数据分布是无法用公式显示的写出来的

因此用高斯模型去拟合数据分布,我们需要一个更通用的生成模型,可以拟合任意数据分布,如下

GAN:生成式对抗网络通过对抗训练,间接计算出散度JS,使得模型可以优化

GAN做的事情:

1.最大化判别器损失,等价于计算合成数据分布和真实数据分布的JS散度

2.最小化生成器损失,等价于最小化JS散度(也就是优化生成模型 )

三、DCGAN

四、代码练习

(一)GAN

  1. 通过make_moons生成双半月形的数据,同时把数据点画出来

  1. 定义生成器、判别器、优化器

    判别器中使用了sigmoid函数(可能是因为需要判别生成的图片是否是真实图片,即相当于是一个二分类的问题,因此用sigmoid函数)

    优化器选择的是adam

  2. 对抗训练

    整个对抗训练可以分为两部分:

    • 第一部分(固定生成器G,改进判别器D)
    • 第二部分(固定判别器D,改进生成器G)
  3. 修改learning_rate和batch_size

    学习率为0.0001,batch_size为50的结果:

学习率为0.001,batch_size为250的结果:

可以明显看出随着batch_size的增大、loss的减小,效果明显改善。

(个人猜测:增大batch_size的值后,能够一次性处理更多的数据,从而能够更好地把握大方向,训练的波动程度更小)

(二)CGAN(条件生成-对抗网络)

  • 对比于GAN,CGAN在生成器以及判别器上都多了一个标签作为输入
  • 生成器的输入是噪声和标签,输出是生成图
  • 判别器的输入是生成图,真实图以及标签,输出是真和假

步骤与GAN相似,不同的是在生成器和判别器的定义中加入了10维的标签信息

全连接判别器:

全连接生成器:

epoch改为100后:

在epoch为100时,辨别器的损失为0.00030,效果不太好

(三)DCGAN(深度卷积对抗网络)

  • 对比于GAN,在判别器和生成器中使用了卷积结构(在第二个、第三个、第四个滑动卷积层中使用BN加快网络收敛),同样添加Sigmoid激活函数

滑动卷积判别器:

反滑动卷积生成器:

  • 第一层:把输入线性变换成256×4×4的矩阵,并在这个基础上做反卷积
  • 第四层:不使用BN,使用tanh激活函数

epoch为30时,结果如下:

在epoch改为100后,效果不如epoch为30的结果(不想明白什么原因)

posted @ 2020-09-12 18:27  陳半仙  阅读(328)  评论(0编辑  收藏  举报