实验16-使用GAN生成手写数字样本
from __future__ import print_function, division from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam import matplotlib.pyplot as plt import sys import os import numpy as np class GAN(): def __init__(self): # --------------------------------- # # 行28,列28,也就是mnist的shape # --------------------------------- # self.img_rows = 28 self.img_cols = 28 self.channels = 1 # 28,28,1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 # adam优化器 optimizer = Adam(0.0002, 0.5) self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) self.generator = self.build_generator() gan_input = Input(shape=(self.latent_dim,)) img = self.generator(gan_input) # 在训练generate的时候不训练discriminator self.discriminator.trainable = False # 对生成的假图片进行预测 validity = self.discriminator(img) self.combined = Model(gan_input, validity) self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) def build_generator(self): # --------------------------------- # # 生成器,输入一串随机数字 # --------------------------------- # model = Sequential() model.add(Dense(256, input_dim=self.latent_dim)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) noise = Input(shape=(self.latent_dim,)) img = model(noise) return Model(noise, img) def build_discriminator(self): # ----------------------------------- # # 评价器,对输入进来的图片进行评价 # ----------------------------------- # model = Sequential() # 输入一张图片 model.add(Flatten(input_shape=self.img_shape)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) # 判断真伪 model.add(Dense(1, activation='sigmoid')) img = Input(shape=self.img_shape) validity = model(img) return Model(img, validity) def train(self, epochs, batch_size=128, sample_interval=50): # 获得数据 (X_train, _), (_, _) = mnist.load_data() # 进行标准化 X_train = X_train / 127.5 - 1. X_train = np.expand_dims(X_train, axis=3) # 创建标签 valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) for epoch in range(epochs): # --------------------------- # # 随机选取batch_size个图片 # 对discriminator进行训练 # --------------------------- # idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx] noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) gen_imgs = self.generator.predict(noise) d_loss_real = self.discriminator.train_on_batch(imgs, valid) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------------- # # 训练generator # --------------------------- # noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) g_loss = self.combined.train_on_batch(noise, valid) print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) if epoch % sample_interval == 0: self.sample_images(epoch) def sample_images(self, epoch): r, c = 5, 5 noise = np.random.normal(0, 1, (r * c, self.latent_dim)) gen_imgs = self.generator.predict(noise) gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray') axs[i,j].axis('off') cnt += 1 fig.savefig("images/%d.png" % epoch) plt.close() if __name__ == '__main__': if not os.path.exists("./images"): os.makedirs("./images") gan = GAN() gan.train(epochs=3000, batch_size=256, sample_interval=200)
运行结果
分类:
知识图谱实验
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人