GAN(生成对抗网络)以及keras实现
由于笔者水平有限,如有错,欢迎指正。
0 GAN的思想
GAN,全称为 Generative Adversarial Nets,直译为生成式对抗网络,是一种非监督式模型。
GAN的主要灵感来源于博弈论中零和博弈的思想,应用到深度学习神经网络上来说,就是通过生成网络G(Generator)和判别网络D(Discriminator)不断博弈,进而使G学习到数据的分布,
GAN网络最强大的地方就是可以帮助我们建立模型,而不像传统的网络那样是在已有模型上帮我们更新参数而已。同时,因为GAN网络是一种无监督的学习方式,它的泛化性非常好。
1 GAN模型
1.1网络结构
上图都描述了GAN的核心网络,在生成网络中,得到假的数据,然后和真的数据一起喂入判别模型,判别模型判断输入的样本是真是假,先训练识别网络,再训练生成网络,再训练识别网络,如此反复,直到平衡。
1.2具体过程
-
生成模型:比作是一个样本生成器,输入一个噪声/样本,然后把它包装成一个逼真的样子,也就是输出。
-
生成网络是造样本,它的目的就是使得自己造样本的能力尽可能强,强到什么程度呢,判别网络没法判断我是真样本还是假样本。
-
通常这个网络选用最普通的多层随机网络即可,网络太深容易引起梯度消失或者梯度爆炸。
-
-
判别模型:比作一个二分类器(如同0-1分类器),来判断输入的样本是真是假。(就是输出值大于0.5还是小于0.5)
- 判别出来属于的一张图它是来自真实样本集还是假样本集。若输入的是真样本,输出就接近1,输出的是假样本,输出接近0。
训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量辨别出G生成的假图像和真实的图像。这样,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点.。
纳什均衡是指博弈中这样的局面,对于每个参与者来说,只要其他人不改变策略,他就无法改善自己的状况。
上图是是论文中的一张过程图,判别分布(蓝色,虚线) ,生成数据的实际分布(黑色,虚线),数据的生成分布(绿色,实线)
(a) 对于D(判别网络)刚开始训练,有波动,但基本可以区分实际数据和生成数据;
(b) 随着训练的进行,D可以明显的区分实际数据和生成数据;
(c) 随着G的更新,绿色的线能够趋近于黑色的线;
(d) 经过几步训练,如果G和D有足够的能力,他们将达到平衡,辨别器无法区分两个分布,即D(x)= 1;
1.3训练结果
最终,训练结束后,生成模型 G 恢复了训练数据的分布(造出了和真实数据一模一样的样本),判别模型再也判别不出来结果,准确率为 50%,约等于乱猜。这是双方网路都得到利益最大化,不再改变自己的策略,也就是不再更新自己的权重。如果loss值很低,则生成器成功欺骗了识别器(把假数据当成和label一样也是1了),如果loss很大(label尽管是1,但是识别器还是预测为0,识别器判断出了真假),说明生成器还需提升)。
3 代码实现
- 避免使用RELU和pooling层,减少稀疏梯度的可能性,使用leakrelu激活函数;
- 最后一层的激活函数使用tanh;
- 在鉴别器中使用dropout;
3.1 Generative model:
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
noise = Input(shape=(self.latent_dim,))
img = model(noise)
3.2 Discriminator model:
model = Sequential()
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
3.3 GAN
discriminator.trainable = False
gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=4e-4, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')