ACGAN(Auxiliary Classifier GAN)原理与实现(tensorflow2.x实现)

ACGAN原理

ACGAN的原理与条件GAN(CGAN)相似。对于CGAN和ACGAN,生成器的输入均为潜在向量及其标签,输出是属于输入类别标签的伪造图像。对于CGAN,判别器的输入是图像(伪造的或真实的图像)及其标签,输出是图像属于真实图像的概率。对于ACGAN,判别器的输入仅是一幅图像,输出是该图像属于真实图像的概率以及其类别概率。
ACGAN架构:本质上,在CGAN中,标签作为输入提供给网络;而在ACGAN中,使用辅助解码器网络(辅助分类器)重建辅助信息(即类别标签)。ACGAN理论认为,强制网络执行额外的任务可以提高原始任务的性能。在这种情况下,辅助任务是图像分类,原始任务是生成伪造图像。
判别器目标函数(其中 $y$ 为输入生成器的类别标签,$p(c|\cdot)$ 为辅助分类器预测出正确类别 $c$ 的概率):
$$\mathcal L^{(D)} = -\mathbb E_{x\sim p_{data}}\log D(x)-\mathbb E_z\log\left[1-D(G(z|y))\right]-\mathbb E_{x\sim p_{data}}\log p(c|x)-\mathbb E_z\log p(c|G(z|y))$$
生成器目标函数:
$$\mathcal L^{(G)} = -\mathbb E_z\log D(G(z|y))-\mathbb E_z\log p(c|G(z|y))$$

ACGAN实现

模块导入

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import os
import math
from PIL import Image

生成器

def generator(inputs,image_size,activation='sigmoid',labels=None):
    """Build the generator network.

    The latent vector (optionally concatenated with a label vector) is
    projected by a Dense layer to an
    ``(image_size // 4, image_size // 4, 128)`` feature map and then passed
    through BatchNorm + ReLU + Conv2DTranspose layers. The first two of these
    upsample by 2, so the output side is ``4 * (image_size // 4)``; this
    presumably assumes ``image_size`` is divisible by 4 (true for MNIST's 28).

    Arguments:
        inputs (layer): latent-vector input layer
        image_size (int): side length of the square output image
        activation (string): activation applied to the output layer;
            ``None`` applies no activation
        labels (tensor): optional label input (one-hot in this file); when
            given it is concatenated with ``inputs`` and the model takes
            ``[inputs, labels]`` as inputs
    Returns:
        Model: the generator model
    """
    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]
    if labels is not None:
        # Conditional generator (CGAN/ACGAN): condition on the label vector.
        inputs = [inputs, labels]
        x = keras.layers.concatenate(inputs, axis=1)
    else:
        # Unconditional generator: concatenating with None would fail.
        x = inputs

    x = keras.layers.Dense(image_resize * image_resize * layer_filters[0])(x)
    x = keras.layers.Reshape((image_resize, image_resize, layer_filters[0]))(x)
    for filters in layer_filters:
        # The two widest layers (128, 64) upsample by 2; the rest keep the size.
        strides = 2 if filters > layer_filters[-2] else 1
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2DTranspose(filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding='same')(x)
    if activation is not None:
        x = keras.layers.Activation(activation)(x)
    return keras.Model(inputs, x, name='generator')

鉴别器

def discriminator(inputs,activation='sigmoid',num_labels=None):
    """Build the discriminator network.

    A stack of LeakyReLU + Conv2D layers, then a Dense head that scores
    whether the input image is real.

    Arguments:
        inputs (Layer): image input layer
        activation (string): activation applied to the real/fake output;
            ``None`` leaves the raw Dense output
        num_labels (int): number of classes; when truthy (ACGAN), a second
            softmax head named ``'label'`` predicts the image class
    Returns:
        Model: the discriminator. Its output is the real/fake score, or the
        list ``[real/fake score, class probabilities]`` when ``num_labels``
        is set.
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]
    x = inputs
    for filters in layer_filters:
        # Downsample by 2 in every conv layer except the last (widest) one.
        strides = 1 if filters == layer_filters[-1] else 2
        x = keras.layers.LeakyReLU(0.2)(x)
        x = keras.layers.Conv2D(filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding='same')(x)
    x = keras.layers.Flatten()(x)
    outputs = keras.layers.Dense(1)(x)
    if activation is not None:
        outputs = keras.layers.Activation(activation)(outputs)
    if num_labels:
        # ACGAN: auxiliary classifier head sharing the convolutional features.
        layer = keras.layers.Dense(layer_filters[-2])(x)
        labels = keras.layers.Dense(num_labels)(layer)
        labels = keras.layers.Activation('softmax', name='label')(labels)
        outputs = [outputs, labels]
    return keras.Model(inputs, outputs, name='discriminator')

模型构建

def build_and_train_models():
    """Load MNIST, build the ACGAN models and train them.

    Builds three models and hands them to ``train``:
      * discriminator: image -> [real/fake probability, class probabilities]
      * generator: [latent vector, one-hot label] -> image
      * adversarial: generator followed by the frozen discriminator, used to
        train the generator
    """
    # Load and preprocess data: add a channel axis and scale pixels to [0, 1].
    (x_train, y_train), _ = keras.datasets.mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255.
    num_labels = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train)

    # Hyper-parameters
    model_name = 'acgan-mnist'
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels,)

    # Discriminator.
    # NOTE: the built models must not be stored in locals named
    # ``discriminator``/``generator``: assigning to those names anywhere in
    # this function makes them local, so calling the module-level builder
    # functions would raise UnboundLocalError.
    inputs = keras.layers.Input(shape=input_shape, name='discriminator_input')
    disc = discriminator(inputs, num_labels=num_labels)
    # InverseTimeDecay with decay_steps=1 reproduces the legacy optimizer
    # ``decay`` argument: lr / (1 + decay * step). Newer Keras optimizers
    # reject ``decay`` and ignore ``lr``.
    optimizer = keras.optimizers.RMSprop(
        learning_rate=keras.optimizers.schedules.InverseTimeDecay(
            lr, decay_steps=1, decay_rate=decay))
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    disc.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    disc.summary()

    # Generator and adversarial model (generator trains at half the lr/decay).
    input_shape = (latent_size,)
    inputs = keras.layers.Input(shape=input_shape, name='z_input')
    labels = keras.layers.Input(shape=label_shape, name='labels')
    gen = generator(inputs, image_size, labels=labels)
    gen.summary()
    optimizer = keras.optimizers.RMSprop(
        learning_rate=keras.optimizers.schedules.InverseTimeDecay(
            lr * 0.5, decay_steps=1, decay_rate=decay * 0.5))
    # Freeze the discriminator inside the adversarial model only; it was
    # compiled above while trainable, so its own train_on_batch still updates it.
    disc.trainable = False
    adversarial = keras.Model([inputs, labels], disc(gen([inputs, labels])),
            name=model_name)
    adversarial.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    adversarial.summary()

    models = (gen, disc, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)

模型训练

def train(models,data,params):
    """Alternately train the discriminator and the adversarial network.

    Each step trains the discriminator on one batch of real images plus one
    batch of generated images. It then trains the generator, through the
    adversarial model, on a batch of fresh noise labelled as real. Every
    ``save_interval`` steps, ``plot_images`` saves a grid generated from a
    fixed set of noise vectors and labels. At the end the generator is saved
    to ``<model_name>.h5``.

    Arguments:
        models (tuple): (generator, discriminator, adversarial)
        data (tuple): (x_train, y_train); y_train is one-hot encoded
        params (tuple): (batch_size, latent_size, train_steps, num_labels,
            model_name)
    """
    generator, discriminator, adversarial = models
    x_train, y_train = data
    batch_size, latent_size, train_steps, num_labels, model_name = params
    save_interval = 500
    # Fixed noise and labels (0..num_labels-1, repeating) so that saved
    # snapshots are comparable across steps.
    noise_input = np.random.uniform(-1., 1., size=[16, latent_size])
    noise_label = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    train_size = x_train.shape[0]
    print(model_name, 'Labels for generated images: ', np.argmax(noise_label, axis=1))
    for i in range(train_steps):
        # --- Train the discriminator for one batch ---
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]
        # Generate fake images conditioned on random one-hot labels.
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        # predict_on_batch avoids predict()'s per-call setup and progress bar.
        fake_images = generator.predict_on_batch([noise, fake_labels])
        x = np.concatenate((real_images, fake_images))
        # Class targets for the auxiliary classifier head.
        labels = np.concatenate((real_labels, fake_labels))
        # Source targets: 1 for real images, 0 for generated ones.
        y = np.ones([2 * batch_size, 1])
        y[batch_size:, :] = 0.0
        # Presumably [total loss, src loss, label loss, src acc, label acc],
        # following the model's metrics_names order.
        metrics = discriminator.train_on_batch(x, [y, labels])
        fmt = '%d: [disc loss: %f, srcloss: %f, '
        fmt += 'lblloss: %f, srcacc: %f, lblacc: %f]'
        log = fmt % (i, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])

        # --- Train the generator (via the adversarial model) for one batch ---
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        # Label the fakes as real so the generator learns to fool the discriminator.
        y = np.ones([batch_size, 1])
        metrics = adversarial.train_on_batch([noise, fake_labels], [y, fake_labels])
        fmt = "%s [advr loss: %f, srcloss: %f, "
        fmt += "lblloss: %f, srcacc: %f, lblacc: %f]"
        log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])
        print(log)
        if (i + 1) % save_interval == 0:
            # Save a snapshot grid of generated images.
            plot_images(generator, noise_input=noise_input,
                    noise_label=noise_label, show=False,
                    step=(i + 1),
                    model_name=model_name)
    generator.save(model_name + ".h5")

虚假图像生成及绘制(plot_images函数)

def plot_images(generator,
                noise_input,
                noise_label=None,
                noise_codes=None,
                show=False,
                step=0,
                model_name="gan"):
    """Generate images with ``generator`` and save them as a grid.

    The figure is written to ``<model_name>/<step:05d>.png``; the directory
    is created if needed.

    # Arguments
        generator (Model): generator model
        noise_input (ndarray): latent vectors, one row per image
        noise_label (ndarray): optional labels; when given the generator is
            called with ``[noise_input, noise_label]``
        noise_codes (list): optional extra inputs appended after the labels
            (only used when ``noise_label`` is given)
        show (bool): show the figure instead of closing it
        step (int): training step, used in the file name
        model_name (string): model name, used as the output directory
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    if noise_label is not None:
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes

    images = generator.predict(noise_input)
    plt.figure(figsize=(2.2, 2.2))
    num_images = images.shape[0]
    image_size = images.shape[1]
    # Smallest near-square grid that holds every image; a plain
    # int(sqrt(n)) square would overflow when n is not a perfect square.
    cols = max(1, math.ceil(math.sqrt(num_images)))
    rows = max(1, math.ceil(num_images / cols))
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        # Single-channel square image -> 2-D array for imshow.
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')

训练结果

# Script entry point: build the ACGAN models and run training.
if __name__ == "__main__":
    build_and_train_models()
step=1000:

训练1000步后的生成结果

step=15000:

训练15000步后的生成结果

posted @ 2020-10-15 13:42  盼小辉丶  阅读(888)  评论(0编辑  收藏  举报