GAN(生成对抗网络)以及keras实现

由于笔者水平有限,如有错,欢迎指正。

论文原文:https://arxiv.org/pdf/1406.2661.pdf


0 GAN的思想

GAN,全称为 Generative Adversarial Nets,直译为生成式对抗网络,是一种非监督式模型。

GAN的主要灵感来源于博弈论中零和博弈的思想,应用到深度学习神经网络上来说,就是通过生成网络G(Generator)和判别网络D(Discriminator)不断博弈,进而使G学习到数据的分布,

GAN网络最强大的地方就是可以帮助我们建立模型,而不像传统的网络那样是在已有模型上帮我们更新参数而已。同时,因为GAN网络是一种无监督的学习方式,它的泛化性非常好。


1 GAN模型

1.1网络结构

上图都描述了GAN的核心网络,在生成网络中,得到假的数据,然后和真的数据一起喂入判别模型,判别模型判断输入的样本是真是假,先训练识别网络,再训练生成网络,再训练识别网络,如此反复,直到平衡。

1.2具体过程

  1. 生成模型:比作是一个样本生成器,输入一个噪声/样本,然后把它包装成一个逼真的样子,也就是输出。

    • 生成网络是造样本,它的目的就是使得自己造样本的能力尽可能强,强到什么程度呢,判别网络没法判断我是真样本还是假样本。

    • 通常这个网络选用最普通的多层随机网络即可,网络太深容易引起梯度消失或者梯度爆炸。

  2. 判别模型:比作一个二分类器(如同0-1分类器),来判断输入的样本是真是假。(就是输出值大于0.5还是小于0.5)

    • 判别出来属于的一张图它是来自真实样本集还是假样本集。若输入的是真样本,输出就接近1,输出的是假样本,输出接近0。

训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量辨别出G生成的假图像和真实的图像。这样,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点.。

纳什均衡是指博弈中这样的局面,对于每个参与者来说,只要其他人不改变策略,他就无法改善自己的状况。


上图是是论文中的一张过程图,判别分布(蓝色,虚线) ,生成数据的实际分布(黑色,虚线),数据的生成分布(绿色,实线)

(a) 对于D(判别网络)刚开始训练,有波动,但基本可以区分实际数据和生成数据;

(b) 随着训练的进行,D可以明显的区分实际数据和生成数据;

(c) 随着G的更新,绿色的线能够趋近于黑色的线;

(d) 经过几步训练,如果G和D有足够的能力,他们将达到平衡,辨别器无法区分两个分布,即D(x)= 1;

1.3训练结果

最终,训练结束后,生成模型 G 恢复了训练数据的分布(造出了和真实数据一模一样的样本),判别模型再也判别不出来结果,准确率为 50%,约等于乱猜。这是双方网路都得到利益最大化,不再改变自己的策略,也就是不再更新自己的权重。如果loss值很低,则生成器成功欺骗了识别器(把假数据当成和label一样也是1了),如果loss很大(label尽管是1,但是识别器还是预测为0,识别器判断出了真假),说明生成器还需提升)。


3 代码实现

  1. 避免使用RELU和pooling层,减少稀疏梯度的可能性,使用leakrelu激活函数;
  2. 最后一层的激活函数使用tanh;
  3. 在鉴别器中使用dropout;

3.1 Generative model:

model = Sequential()

model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()

noise = Input(shape=(self.latent_dim,))
img = model(noise)

3.2 Discriminator model:

model = Sequential()     
model.add(Flatten(input_shape=self.img_shape))     
model.add(Dense(512))     
model.add(LeakyReLU(alpha=0.2))     
model.add(Dense(256))     
model.add(LeakyReLU(alpha=0.2))     
model.add(Dense(1, activation='sigmoid'))     
model.summary()    

img = Input(shape=self.img_shape)     
validity = model(img)

3.3 GAN

discriminator.trainable = False
gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=4e-4, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

4 参考资料

https://www.jianshu.com/p/998cf8e52209

https://zhuanlan.zhihu.com/p/34287744

posted @ 2020-08-02 16:57  _LLLL  阅读(1434)  评论(0编辑  收藏  举报