Tensorflow2.0实战之GAN

本文主要带领读者了解生成对抗神经网络(GAN),并使用提供的face数据集训练网络

GAN 入门

自 2014 年 Ian Goodfellow 的《生成对抗网络(Generative Adversarial Networks)》论文发表以来,GAN 的进展突飞猛进,生成结果也越来越具有照片真实感。
就在三年前,Ian Goodfellow 在 reddit 上回答 GAN 是否可以应用在文本领域的问题时,还认为 GAN 不能扩展到文本领域。
在这里插入图片描述

“由于 GAN 定义在实值数据上,因此 GAN 不能应用于 NLP。
GAN 的工作原理是训练一个生成网络,输出合成数据,然后利用判别网络判别合成数据。判别网络根据合成数据输出的梯度告诉你该如何对合成数据进行微调,使其更真实。
因此只有当合成数据是基于连续数字时,才能对其进行微调。如果是基于离散的数字,就没有办法做微小的改变。
例如,如果输出像素值为 1.0 的图像,则下一步可以将该像素值更改为 1.0001。
但如果输出单词‘penguin’,不能在下一步直接将其更改为‘penguin+.001’,因为没有‘penguin+.001’这样的单词。你必须从‘penguin’直接转变到‘ostrich’。
由于所有的 NLP 都是基于离散的值,如单词、字符或字节,所以目前还没有人知道该如何将 GAN 应用于 NLP。”

但是现在,GAN 已经可用于生成各种内容,包括图像、视频、音频和文本。这些输出的合成数据既可以用于训练其他的模型,也可以用于创建一些有趣的项目。

GAN 原理

GAN 由两个神经网络组成,一个是合成新样本的生成器,另一个是对比训练样本与生成样本的判别器。判别器的目标是区分“真实”和“虚假”的输入(对样本来自模型分布还是真实分布进行分类)。这些样本可以是图像、视频、音频片段和文本。
在这里插入图片描述
为了合成这些新的样本,生成器的输入为随机噪声,然后尝试从训练数据中学习到的分布中生成真实的图像。
判别器网络(卷积神经网络)输出相对于合成数据的梯度,其中包含着如何改变合成数据以使其更具真实感的信息。最终生成器收敛,它可以生成符合真实数据分布的样本,而判别器无法区分生成数据和真实数据。
ok,接下来我们就来实现一下

准备阶段

下载数据集
数据集,笔者这里已经为大家提供了,链接如下:
链接: https://pan.baidu.com/s/15wFZAANvr8gajiVY_1mI0A
提取码: c9vy
解压数据集
将下载好的数据集解压,放在工程目录下
在这里插入图片描述
加载数据集
加载数据集的代码,笔者这里直接提供给大家了,下面只是展示部分代码,文末会提供完整项目的代码链接

import multiprocessing
import tensorflow as tf
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1
        return img
    dataset = disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          drop_remainder=drop_remainder,
                                          map_fn=_map_fn,
                                          shuffle=shuffle,
                                          repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size
    return dataset, img_shape, len_dataset
def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):

构建网络
搭建Generator,Generator包含两个部分,init部分和前向传播的call部分,代码如下

class Generator(keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        # z:[b,100]-->[b,3*3*512]-->[b,3,3,512]-->[b,64,64,3]
        self.fc=keras.layers.Dense(3*3*512)

        self.conv1=keras.layers.Conv2DTranspose(256,3,3,'valid')  # 反卷积
        self.bn1=keras.layers.BatchNormalization()

        self.conv2=keras.layers.Conv2DTranspose(128,5,2,'valid')
        self.bn2=keras.layers.BatchNormalization()

        self.conv3=keras.layers.Conv2DTranspose(3,4,3,'valid')

    def call(self, inputs, training=None, mask=None):
        # [z,100]-->[z,3*3*512]
        x=self.fc(inputs)
        x=tf.reshape(x,[-1,3,3,512])
        x=tf.nn.leaky_relu(x)

        x=tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))
        x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
        x=self.conv3(x)
        x=tf.tanh(x)
        return x

搭建Discriminator,同上

class Discriminator(keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        # [b,64,64,3]-->[b,1]
        self.conv1=keras.layers.Conv2D(64,5,3,'valid')

        self.conv2=keras.layers.Conv2D(128,5,3,'valid')
        self.bn2=keras.layers.BatchNormalization()

        self.conv3=keras.layers.Conv2D(256,5,3,'valid')
        self.bn3=keras.layers.BatchNormalization()

        # [b,h,w,c]-->[b,-1]
        self.flatten=keras.layers.Flatten()
        # [b,-1]-->[b,1]
        self.fc=keras.layers.Dense(1)
    def call(self, inputs, training=None, mask=None):
        x=tf.nn.leaky_relu(self.conv1(inputs))
        x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
        x=tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))
        x=self.flatten(x)
        logits=self.fc(x)
        return logits

训练GAN
定义相关数据,包括epoch,lr等等
这些数据可以自定义,笔者这里就不改动了

	 z_dim = 100
    epochs = 50000
    batch_size = 512
    learning_rate = 0.0002
    is_training = True

加载数据

	img_path=glob.glob(r'E:\python_pro\TF2.0\GAN\faces\*.jpg')
    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)

可以打印查看数据集信息:

(512, 64, 64, 3), (64, 64, 3)
(512, 64, 64, 3) ,1.0, -1.0

定义优化器,注意我们在开始训练时,需要新建训练GAN图片的文件,为查看数据提供持久化依据

    for epoch in range(epochs):

        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # train D
        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))


        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss))

            z = tf.random.uniform([100, z_dim])
            fake_image = generator(z, training=False)
            img_path = os.path.join('GAN_IMAGE', 'gan%d.png'%epoch)
            save_result(fake_image.numpy(), 10, img_path, color_mode='P')

训练结果

接下来我们来看看,训练的效果图,注意,GAN的训练过程是非常非常非常慢的,大概训练十几个小时,才能有个比较好的效果,有的数据集甚至会训练几天之久,这个随数据集的大小和对最终效果的要求来定的。笔者这个数据集比较的简单,只是给大家做演示,好了,废话就不过多的说了,上图
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
上述分别是训练了100epoch、500、1500、4000的效果图,可以看到随着训练的次数增加,效果因为越来越好了

总结

大家在训练GAN时,还是需要一个好一些的GPU显卡才行,这样可以体验GPU给我们带来的加速效果。这样会使得训练的速度大大加快。
笔者水平有限,如有表述不准确的地方还请谅解,有错误的地方欢迎大家批评指正。
最后还是希望大家动手实践实践,共同进步。
最终的代码链接:https://github.com/huzixuan1/TF_2.0/tree/master/GAN

posted @ 2022-05-11 16:35  陶陶Name  阅读(192)  评论(0编辑  收藏  举报