生成对抗网络GAN详解与代码

1.GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:

  • G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。

  • D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”

最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。

这样我们的目的就达成了:我们得到了一个生成式的模型G,它可以用来生成图片。

以上只是大致说了一下GAN的核心原理,如何用数学语言描述呢?这里直接摘录论文里的公式:

(1)优化D:

优化第一项是真是样本x输入的时候,结果越大越好;对于噪声等的输入z,生成的假样本G(z)要越小越好

(2)优化G:

优化生成器时和真是样本没关系,故不需要考虑;这时候只有假样本,但生成器希望假样本越逼真越好(接近1),故D(G(z)越大越好,则最小化1-D(G(z))

 2.GAN的特点:

    (1)相比较传统的模型,他存在两个不同的网络,而不是单一的网络,并且训练方式采用的是对抗训练方式

    (2)GAN中G的梯度更新信息来自判别器D,而不是来自数据样本

3. GAN 的优点:

    (1) GAN是一种生成式模型,相比较其他生成模型(玻尔兹曼机和GSNs)只用到了反向传播,而不需要复杂的马尔科夫链

    (2)相比其他所有模型, GAN可以产生更加清晰,真实的样本

    (3)GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域

    (4)相比于变分自编码器, GANs没有引入任何决定性偏置( deterministic bias),变分方法引入决定性偏置,因为他们优化对数似然的下界,而不是似然度本身,这看起来导致了VAEs生成的实例比GANs更模糊

    (5)相比VAE, GANs没有变分下界,如果鉴别器训练良好,那么生成器可以完美的学习到训练样本的分布.换句话说,GANs是渐进一致的,但是VAE是有偏差的

    (6)GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,不管三七二十一,只要有一个的基准,直接上判别器,剩下的就交给对抗训练了。

 4. GAN的缺点:

    (1)训练GAN需要达到纳什均衡,有时候可以用梯度下降法做到,有时候做不到.我们还没有找到很好的达到纳什均衡的方法,所以训练GAN相比VAE或者PixelRNN是不稳定的,但我认为在实践中它还是比训练玻尔兹曼机稳定的多

    (2)GAN不适合处理离散形式的数据,比如文本

    (3)GAN存在训练不稳定、梯度消失、模式崩溃的问题(目前已解决)

5.为什么GAN中的优化器不常用SGD

    (1)SGD容易震荡,容易使GAN训练不稳定,

    (2)GAN的目的是在高维非凸的参数空间中找到纳什均衡点,GAN的纳什均衡点是一个鞍点,但是SGD只会找到局部极小值,因为SGD解决的是一个寻找最小值的问题,GAN是一个博弈问题。

6.训练GAN的一些技巧

(1). 输入规范化到(-1,1)之间,最后一层的激活函数使用tanh(BEGAN除外)

(2). 使用wassertein GAN的损失函数

(3). 如果有标签数据的话,尽量使用标签,也有人提出使用反转标签效果很好,另外使用标签平滑,单边标签平滑或者双边标签平滑

(4). 使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm

(5). 避免使用RELU和pooling层,减少稀疏梯度的可能性,可以使用leakrelu激活函数

(6). 优化器尽量选择ADAM,学习率不要设置太大,初始1e-4可以参考,另外可以随着训练进行不断缩小学习率,

(7). 给D的网络层增加高斯噪声,相当于是一种正则

7.GAN实战

import tensorflow as tf #导入tensorflow
from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
import numpy as np #导入numpy
import matplotlib.pyplot as plt #plt是绘图工具,在训练过程中用于输出可视化结果
import matplotlib.gridspec as gridspec #gridspec是图片排列工具,在训练过程中用于输出可视化结果
import os #导入os
 
    
def xavier_init(size): #初始化参数时使用的xavier_init函数
    in_dim = size[0] 
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.) #初始化标准差
    return tf.random_normal(shape=size, stddev=xavier_stddev) #返回初始化的结果

X = tf.placeholder(tf.float32, shape=[None, 784]) #X表示真的样本(即真实的手写数字)

D_W1 = tf.Variable(xavier_init([784, 128])) #表示使用xavier方式初始化的判别器的D_W1参数,是一个784行128列的矩阵
D_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的判别器的D_1参数,是一个长度为128的向量 
D_W2 = tf.Variable(xavier_init([128, 1])) #表示使用xavier方式初始化的判别器的D_W2参数,是一个128行1列的矩阵
D_b2 = tf.Variable(tf.zeros(shape=[1])) ##表示全零方式初始化的判别器的D_1参数,是一个长度为1的向量
theta_D = [D_W1, D_W2, D_b1, D_b2] #theta_D表示判别器的可训练参数集合

Z = tf.placeholder(tf.float32, shape=[None, 100]) #Z表示生成器的输入(在这里是噪声),是一个N列100行的矩阵
 
G_W1 = tf.Variable(xavier_init([100, 128])) #表示使用xavier方式初始化的生成器的G_W1参数,是一个100行128列的矩阵
G_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的生成器的G_b1参数,是一个长度为128的向量 
G_W2 = tf.Variable(xavier_init([128, 784])) #表示使用xavier方式初始化的生成器的G_W2参数,是一个128行784列的矩阵
G_b2 = tf.Variable(tf.zeros(shape=[784])) #表示全零方式初始化的生成器的G_b2参数,是一个长度为784的向量
theta_G = [G_W1, G_W2, G_b1, G_b2] #theta_G表示生成器的可训练参数集合

def sample_Z(m, n): #生成维度为[m, n]的随机噪声作为生成器G的输入
    return np.random.uniform(-1., 1., size=[m, n])

def generator(z): #生成器,z的维度为[N, 100]
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #输入的随机噪声乘以G_W1矩阵加上偏置G_b1,G_h1维度为[N, 128]
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G_h1乘以G_W2矩阵加上偏置G_b2,G_log_prob维度为[N, 784]
    G_prob = tf.nn.sigmoid(G_log_prob) #G_log_prob经过一个sigmoid函数,G_prob维度为[N, 784] 
    return G_prob #返回G_prob

def discriminator(x): #判别器,x的维度为[N, 784]
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) #输入乘以D_W1矩阵加上偏置D_b1,D_h1维度为[N, 128]
    D_logit = tf.matmul(D_h1, D_W2) + D_b2 #D_h1乘以D_W2矩阵加上偏置D_b2,D_logit维度为[N, 1]
    D_prob = tf.nn.sigmoid(D_logit) #D_logit经过一个sigmoid函数,D_prob维度为[N, 1]
    return D_prob, D_logit #返回D_prob, D_logit

G_sample = generator(Z) #取得生成器的生成结果
D_real, D_logit_real = discriminator(X) #取得判别器判别的真实手写数字的结果
D_fake, D_logit_fake = discriminator(G_sample) #取得判别器判别的生成的手写数字的结果

#对判别器对真实样本的判别结果计算误差(将结果与1比较)
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, targets=tf.ones_like(D_logit_real))) 
#对判别器对虚假样本(即生成器生成的手写数字)的判别结果计算误差(将结果与0比较)
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.zeros_like(D_logit_fake))) 
#判别器的误差
D_loss = D_loss_real + D_loss_fake 
#生成器的误差(将判别器返回的对虚假样本的判别结果与1比较)
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.ones_like(D_logit_fake))) 

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集

D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) #判别器的训练器
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #生成器的训练器

mb_size = 128 #训练的batch_size
Z_dim = 100 #生成器输入的随机噪声的列的维度
  
sess = tf.Session() #会话层
sess.run(tf.initialize_all_variables()) #初始化所有可训练参数

def plot(samples): #保存图片时使用的plot函数
    fig = plt.figure(figsize=(4, 4)) #初始化一个4行4列包含16张子图像的图片
    gs = gridspec.GridSpec(4, 4) #调整子图的位置
    gs.update(wspace=0.05, hspace=0.05) #置子图间的间距
    for i, sample in enumerate(samples): #依次将16张子图填充进需要保存的图像
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r') 
    return fig


path = '/data/User/zcc/' #保存可视化结果的路径
i = 0 #训练过程中保存的可视化结果的索引 
for it in range(1000000): #训练100万次
    if it % 1000 == 0: #每训练1000次就保存一下结果
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})
        fig = plot(samples) #通过plot函数生成可视化结果
        plt.savefig(path+'out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight') #保存可视化结果
        i += 1
        plt.close(fig)
 
    X_mb, _ = mnist.train.next_batch(mb_size) #得到训练一个batch所需的真实手写数字(作为判别器的输入)
 
    #下面是得到训练一次的结果,通过sess来run出来
    _, D_loss_curr, D_loss_real, D_loss_fake, D_loss = sess.run([D_solver, D_loss, D_loss_real, D_loss_fake, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
 
    if it % 1000 == 0: #每训练1000次输出一下结果
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

 参考博客:

https://blog.csdn.net/m0_37407756/article/details/75309670

https://blog.csdn.net/jiongnima/article/details/80033169

posted @ 2019-07-24 11:21  USTC丶ZCC  阅读(9489)  评论(0编辑  收藏  举报