GAN-生成式对抗网络(keras实现)
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是最近超级火的一个无监督学习方法,它主要由两部分组成,一部分是生成模型G(generator),另一部分是判别模型D(discriminator),它的训练过程可大致描述如下:
生成模型通过接收一个随机噪声来生成图片,判别模型用来判断这个图片是不是“真实的”,也就是说,生成网络的目标是尽量生成真实的图片去欺骗判别网络,判别网络的目标就是把G生成的图片和真实的图片区分开来,从而构成一个动态的博弈过程。
GAN主要用来解决的问题是:在数据量不足的情况下,通过小型数据集去生成一些数据
从理论上来说,GAN系列神经网络可以用来模拟任何数据分布,但是目前更主要用于图像。
而事实也证明,GAN生成的数据是可以直接用在实际的图像问题上的,如行人重识别数据集,细粒度识别等。
(GAN的网络结构及训练流程)
下面是用keras实现的GAN:
1 from __future__ import print_function, division 2 3 from keras.datasets import mnist 4 from keras.layers import Input, Dense, Reshape, Flatten, Dropout 5 from keras.layers import BatchNormalization, Activation, ZeroPadding2D 6 from keras.layers.advanced_activations import LeakyReLU 7 from keras.layers.convolutional import UpSampling2D, Conv2D 8 from keras.models import Sequential, Model 9 from keras.optimizers import Adam 10 11 import matplotlib.pyplot as plt 12 13 import sys 14 15 import numpy as np 16 17 class GAN(): 18 def __init__(self): 19 # 定义输入图像尺寸及通道 20 self.img_rows = 28 21 self.img_cols = 28 22 self.channels = 1 23 self.img_shape = (self.img_rows, self.img_cols, self.channels) 24 self.latent_dim = 100 25 26 # 设置网络优化器 27 optimizer = Adam(0.0002, 0.5) 28 29 # 构建判别网络 30 self.discriminator = self.build_discriminator() 31 self.discriminator.compile(loss='binary_crossentropy', 32 optimizer=optimizer, 33 metrics=['accuracy']) 34 35 # 构建生成网络 36 self.generator = self.build_generator() 37 38 # 生成器根据噪声生成图像 39 z = Input(shape=(self.latent_dim,)) 40 img = self.generator(z) 41 42 # 在联合模型中,设置判别器参数不可训练 43 self.discriminator.trainable = False 44 45 # 判别器验证生成图像 46 validity = self.discriminator(img) 47 48 # 训练生成器来欺骗判别器 49 self.combined = Model(z, validity) 50 self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) 51 52 53 # 生成器结构 54 def build_generator(self): 55 56 model = Sequential() 57 58 model.add(Dense(256, input_dim=self.latent_dim)) 59 model.add(LeakyReLU(alpha=0.2)) 60 model.add(BatchNormalization(momentum=0.8)) 61 model.add(Dense(512)) 62 model.add(LeakyReLU(alpha=0.2)) 63 model.add(BatchNormalization(momentum=0.8)) 64 model.add(Dense(1024)) 65 model.add(LeakyReLU(alpha=0.2)) 66 model.add(BatchNormalization(momentum=0.8)) 67 model.add(Dense(np.prod(self.img_shape), activation='tanh')) 68 model.add(Reshape(self.img_shape)) 69 70 model.summary() 71 72 noise = Input(shape=(self.latent_dim,)) 73 img = model(noise) 74 75 return Model(noise, img) 76 77 # 判别器结构 78 def build_discriminator(self): 79 80 model = Sequential() 81 82 model.add(Flatten(input_shape=self.img_shape)) 83 model.add(Dense(512)) 84 model.add(LeakyReLU(alpha=0.2)) 85 model.add(Dense(256)) 86 model.add(LeakyReLU(alpha=0.2)) 87 model.add(Dense(1, activation='sigmoid')) 88 model.summary() 89 90 img = Input(shape=self.img_shape) 91 validity = model(img) 92 93 return Model(img, validity) 94 95 # 定义训练过程 96 def train(self, epochs, batch_size=128, sample_interval=50): 97 (X_train, _), (_, _) = mnist.load_data() 98 99 X_train = X_train / 127.5 - 1. 100 X_train = np.expand_dims(X_train, axis=3) 101 102 valid = np.ones((batch_size, 1)) 103 fake = np.zeros((batch_size, 1)) 104 105 for epoch in range(epochs): 106 107 idx = np.random.randint(0, X_train.shape[0], batch_size) 108 imgs = X_train[idx] 109 110 noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) 111 112 gen_imgs = self.generator.predict(noise) 113 114 d_loss_real = self.discriminator.train_on_batch(imgs, valid) 115 d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) 116 d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) 117 118 noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) 119 120 # 根据判别器valid训练生成器 121 g_loss = self.combined.train_on_batch(noise, valid) 122 123 print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) 124 125 # 保存生成图像 126 if epoch % sample_interval == 0: 127 self.sample_images(epoch) 128 129 def sample_images(self, epoch): 130 r, c = 5, 5 131 noise = np.random.normal(0, 1, (r * c, self.latent_dim)) 132 gen_imgs = self.generator.predict(noise) 133 134 gen_imgs = 0.5 * gen_imgs + 0.5 135 136 fig, axs = plt.subplots(r, c) 137 cnt = 0 138 for i in range(r): 139 for j in range(c): 140 axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray') 141 axs[i,j].axis('off') 142 cnt += 1 143 fig.savefig("images/%d.png" % epoch) 144 plt.close() 145 146 147 if __name__ == '__main__': 148 gan = GAN() 149 gan.train(epochs=30000, batch_size=32, sample_interval=200)
程序初始运行结果如下:
训练完成后效果如下: