StackGAN: Principles and Implementation (text to image: synthesizing realistic images from text, implemented with TensorFlow 2.x)

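How StackGAN Works

Introduction to StackGAN

StackGAN stacks two GANs to form a network that can generate high-resolution images. It works in two stages, Stage-I and Stage-II. The Stage-I network, conditioned on a text embedding, generates a low-resolution image with basic colors and a rough sketch. The Stage-II network takes the image produced by Stage-I and generates a high-resolution image, again conditioned on the text embedding. In essence, the second network corrects defects and adds detail, yielding a more realistic high-resolution image.
In StackGAN, Stage-I is concerned with drawing basic shapes, while Stage-II corrects defects in the image generated by Stage-I and adds detail so that the image looks more realistic. The generators in both stages are conditional generative adversarial networks (CGANs). The first GAN is conditioned on the text description, while the second is conditioned on both the text description and the image generated by the first GAN.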

StackGAN Architecture

StackGAN is a two-stage network built from several sub-networks:
●Stage-I GAN: text encoder, Conditioning Augmentation network, generator network, discriminator network, embedding compressor network
●Stage-II GAN: text encoder, Conditioning Augmentation network, generator network, discriminator network, embedding compressor network
The first stage generates images of size 64x64. The second stage then takes these low-resolution images and generates high-resolution images of size 256x256.

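Text encoder network

The sole purpose of the text encoder network is to convert a text description ($t$) into a text embedding ($\phi_t$). The text encoder is not trained in this article; pretrained text embeddings are used instead. The text encoder encodes a sentence into a 1024-dimensional text embedding and is shared by both stages.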

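Conditioning augmentation network

The conditioning augmentation (CA) network samples a random latent variable $\hat c$ from the distribution $\mathcal N(\mu(\phi_t),\Sigma(\phi_t))$. Adding the CA network has several advantages:
●It adds randomness to the network.
●It makes the generator more robust by capturing objects in a variety of poses and appearances.
●It produces more image-text pairs. With a large number of image-text pairs, a network that is robust to perturbations can be trained.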

Obtaining the conditioning augmentation variable

Once the text embedding ($\phi_t$) has been obtained from the text encoder, it is passed through a fully connected layer that produces a mean $\mu_0$ and a standard deviation $\sigma_0$. These are used to build a diagonal covariance matrix by placing $\sigma_0$ on the diagonal of $\Sigma_0(\phi_t)$. Finally, $\mu_0$ and $\Sigma_0$ define a Gaussian distribution:
$$\mathcal N(\mu_0(\phi_t),\Sigma_0(\phi_t))\qquad(1)$$
Then $\hat c_0$ is sampled from this Gaussian distribution:
$$\hat c_0 = \mu_0 + \sigma_0 \odot \epsilon,\quad \epsilon \sim \mathcal N(0,I)\qquad(2)$$
To sample $\hat c_0$, $\sigma_0$ is first multiplied element-wise by noise $\epsilon$ drawn from $\mathcal N(0,I)$, and the result is then added to $\mu_0$.
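
The sampling step in Equation (2) can be sketched in a few lines of NumPy. This is only an illustration: the function name `sample_condition` is hypothetical, the mean and log standard deviation are placeholder arrays rather than outputs of a trained layer, and the 128-dimensional size is assumed to match the implementation later in this article.

```python
import numpy as np

def sample_condition(mu, log_sigma, rng):
    """Draw c_hat = mu + sigma * eps with eps ~ N(0, I), as in Equation (2)."""
    sigma = np.exp(log_sigma)               # the layer predicts log(sigma), so exponentiate
    eps = rng.standard_normal(mu.shape)     # one independent noise value per element
    return mu + sigma * eps                 # element-wise scale, then shift by the mean

rng = np.random.default_rng(0)
mu = np.zeros((4, 128))                     # placeholder means for a batch of 4 embeddings
log_sigma = np.full((4, 128), -1.0)         # placeholder log standard deviations
c_hat = sample_condition(mu, log_sigma, rng)
print(c_hat.shape)  # (4, 128)
```

Because the noise is drawn outside the learned parameters, gradients can flow through `mu` and `log_sigma` during training, which is why the sample is written in this form.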

Stage-I

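Generator network

The Stage-I generator is a deep convolutional neural network with several upsampling layers. It is a CGAN conditioned on $\hat c_0$ and the random variable $z$: it takes the Gaussian conditioning variable $\hat c_0$ and the random noise variable $z$ and generates an image of size 64x64x3. $z$ is a random noise variable of dimension $N_z$ sampled from a Gaussian distribution $p_z$. The image produced by the generator can be written as $s_0 = G_0(z,\hat c_0)$.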
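
To make the shapes concrete, here is a small NumPy sketch of how the generator input is assembled and how the spatial size grows. The sizes (128 for the conditioning variable, 100 for the noise, a 4x4x1024 starting feature map, four 2x upsampling steps) are taken from the implementation later in this article; the arrays are zero-filled placeholders, not real activations.

```python
import numpy as np

batch, cond_dim, z_dim = 4, 128, 100
c_hat = np.zeros((batch, cond_dim))            # placeholder conditioning variable
z = np.zeros((batch, z_dim))                   # placeholder noise vector
gen_input = np.concatenate([c_hat, z], axis=1) # generator input: [c_hat, z]

# Stand-in for the output of the first dense layer, reshaped to a 4x4 feature map.
projected = np.zeros((batch, 4 * 4 * 1024))
feature_map = projected.reshape(-1, 4, 4, 1024)

# Each upsampling step doubles height and width.
sizes = [feature_map.shape[1]]
for _ in range(4):
    sizes.append(sizes[-1] * 2)

print(gen_input.shape)    # (4, 228)
print(feature_map.shape)  # (4, 4, 4, 1024)
print(sizes)              # [4, 8, 16, 32, 64]
```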

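Discriminator network

The discriminator is a deep convolutional neural network containing a series of downsampling convolutional layers. The downsampling layers produce feature maps from an image, whether it comes from the real data distribution $p_{data}$ or from the generator. The feature maps are then concatenated with the text embedding, which is first brought into a compatible shape by compression and spatial replication: a fully connected layer compresses the text embedding to an $N_d$-dimensional vector, which is then replicated spatially to form an $M_d\times M_d \times N_d$ tensor, where $M_d\times M_d$ is the spatial size of the feature map. The feature maps and the compressed, spatially replicated text embedding are then concatenated along the channel dimension. Finally, a fully connected layer with a single node performs binary classification.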
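
The compression and spatial replication step can be sketched with NumPy as follows. All values here are placeholders: `w` stands in for the weights of the fully connected layer, `n_d = 128` is the compressed size, and `m_d = 4` is the spatial size of the feature map; both sizes, and the 512 feature channels, are assumptions chosen to match the Stage-I discriminator later in this article.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, emb_dim, n_d, m_d = 2, 1024, 128, 4

phi_t = rng.standard_normal((batch, emb_dim))        # placeholder text embeddings
w = rng.standard_normal((emb_dim, n_d)) * 0.01       # stand-in for the dense layer weights
compressed = phi_t @ w                               # shape (batch, n_d)

# Copy the same n_d-vector to every spatial position of an m_d x m_d grid.
replicated = np.tile(compressed[:, None, None, :], (1, m_d, m_d, 1))

features = rng.standard_normal((batch, m_d, m_d, 512))   # placeholder image feature map
merged = np.concatenate([features, replicated], axis=-1) # join along the channel axis
print(merged.shape)  # (2, 4, 4, 640)
```

Every spatial location therefore sees the same text information next to its local image features.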

Loss functions

The discriminator loss is:
$$\mathcal L^{(D_0)}=\mathbb E_{(I_0,t) \sim p_{data}}[\log D_0(I_0,\phi_t)] + \mathbb E_{z \sim p_z,t\sim p_{data}}[\log(1-D_0(G_0(z,\hat c_0),\phi_t))]\qquad(3)$$
The generator loss is:
$$\mathcal L^{(G_0)}=\mathbb E_{z \sim p_z,t\sim p_{data}}[\log(1-D_0(G_0(z,\hat c_0),\phi_t))] + \lambda D_{KL}(\mathcal N(\mu_0(\phi_t),\Sigma_0(\phi_t)) \| \mathcal N(0,I))\qquad(4)$$
The discriminator $D_0$ is trained to maximize (3), and the generator $G_0$ is trained to minimize (4).
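
For a Gaussian with diagonal covariance, the KL term in Equation (4) has the closed form $\frac{1}{2}\sum_i(\sigma_i^2 + \mu_i^2 - 1 - 2\log\sigma_i)$. The NumPy sketch below evaluates it; the function name `kl_to_standard_normal` is hypothetical. Note that the `KL_loss` function later in this article uses the same per-element expression but averages it instead of summing, which only changes the scale.

```python
import numpy as np

def kl_to_standard_normal(mu, log_sigma):
    """KL(N(mu, diag(sigma^2)) || N(0, I)), summed over the last axis."""
    per_element = np.exp(2.0 * log_sigma) + mu ** 2 - 1.0 - 2.0 * log_sigma
    return 0.5 * np.sum(per_element, axis=-1)

# A standard normal has zero divergence from itself.
print(kl_to_standard_normal(np.zeros(128), np.zeros(128)))  # 0.0
# Shifting every mean by 1 costs 0.5 per dimension.
print(kl_to_standard_normal(np.ones(128), np.zeros(128)))   # 64.0
```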

Stage-II

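The main components of Stage-II are a generator network and a discriminator network. The generator is an encoder-decoder network. The random noise $z$ is not used in this stage, on the assumption that randomness is already preserved in $s_0$, the image generated by the Stage-I generator.
The Gaussian conditioning variable $\hat c$ is first generated from the same text embedding $\phi_t$ produced by the pretrained text encoder. The Stage-I and Stage-II conditioning augmentation networks have separate fully connected layers, which produce different means and standard deviations. In this way the Stage-II GAN learns to capture useful information in the text embedding that the Stage-I GAN ignored.
Images generated by the Stage-I GAN may lack vivid object parts, may contain shape distortions, and may omit details that are important for a realistic image. The Stage-II GAN builds on the output of the Stage-I GAN: it is conditioned on the low-resolution image generated by Stage-I and on the text description, and it produces a high-resolution image by correcting these defects.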

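Generator network

The generator is a deep convolutional neural network. The Stage-I result (the low-resolution image) is passed through several downsampling layers to produce image features. The image features and the text conditioning variable are then concatenated along the channel dimension. The concatenated tensor is fed into several residual blocks, which learn a multimodal representation across image and text features. Finally, the output is fed into a set of upsampling layers that generate a high-resolution image of size 256x256x3.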
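
The encoder-decoder shape of this generator is easiest to see by tracing the spatial size of the feature map. The helper below is a hypothetical illustration, assuming two stride-2 downsampling steps and four 2x upsampling steps as in the implementation later in this article.

```python
def trace_stage2_sizes(input_size=64, num_down=2, num_up=4):
    """Return the feature-map side length after each down/upsampling step."""
    sizes = [input_size]
    for _ in range(num_down):
        sizes.append(sizes[-1] // 2)   # a stride-2 convolution halves the size
    for _ in range(num_up):
        sizes.append(sizes[-1] * 2)    # each upsampling layer doubles the size
    return sizes

print(trace_stage2_sizes())  # [64, 32, 16, 32, 64, 128, 256]
```

The residual blocks sit at the 16x16 bottleneck and leave the spatial size unchanged, so they do not appear in the trace.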

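Discriminator network

The discriminator is a deep convolutional neural network with additional downsampling layers, because its input images are larger than those seen by the Stage-I discriminator. It is a matching-aware discriminator, which lets it better judge whether an image matches its conditioning text. During training, the discriminator treats real images paired with their corresponding text descriptions as positive pairs. Negative pairs fall into two groups: real images with mismatched text embeddings, and synthetic images with their corresponding text embeddings.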
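
The three kinds of training pairs can be sketched with placeholder arrays. Mismatched pairs are built here by shifting the embeddings by one position within the batch, which is the approach used by the loss functions later in this article; it assumes that neighbouring samples in a shuffled batch describe different images. The array sizes are arbitrary and the dictionary layout is purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 4
real_images = rng.standard_normal((batch, 8, 8, 3))   # tiny placeholder "real" images
fake_images = rng.standard_normal((batch, 8, 8, 3))   # stand-in for generator output
embeddings = rng.standard_normal((batch, 16))         # placeholder text embeddings

# (images, embeddings, target label) for each kind of pair
pairs = {
    "real image, matching text": (real_images, embeddings, 1.0),
    "real image, mismatched text": (real_images[:-1], embeddings[1:], 0.0),
    "fake image, matching text": (fake_images, embeddings, 0.0),
}
for name, (imgs, embs, label) in pairs.items():
    print(name, imgs.shape[0], label)
```

The mismatched group has one pair fewer than the batch size, because the last image has no "next" embedding to borrow.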

Loss functions

As in Stage-I, the generator $G$ and the discriminator $D$ of the Stage-II GAN are trained by maximizing the discriminator loss and minimizing the generator loss.
The discriminator loss is:
$$\mathcal L^{(D)}=\mathbb E_{(I,t) \sim p_{data}}[\log D(I,\phi_t)] + \mathbb E_{s_0 \sim p_{G_0},t\sim p_{data}}[\log(1-D(G(s_0,\hat c),\phi_t))]\qquad(5)$$
Both generators are conditioned on the text embedding. The main difference is that the Stage-II generator takes $s_0$ and $\hat c$ as input, where $s_0$ is the image generated by Stage-I and $\hat c$ is the CA variable.
The generator loss is:
$$\mathcal L^{(G)}=\mathbb E_{s_0 \sim p_{G_0},t\sim p_{data}}[\log(1-D(G(s_0,\hat c),\phi_t))] + \lambda D_{KL}(\mathcal N(\mu(\phi_t),\Sigma(\phi_t)) \| \mathcal N(0,I))\qquad(6)$$

Dataset

The CUB dataset (http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) is an image dataset of birds, containing 11,788 images of 200 bird species. The char-CNN-RNN text embeddings are pretrained text embeddings (https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE).

StackGAN Implementation

Importing libraries

import os
import pickle
import time
import random
import PIL
import numpy as np
import pandas as pd
import tensorflow as tf

from PIL import Image
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt

Stage-I

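def generator_c(x):
    # Split the CA output into mean and log standard deviation (128 values each)
    mean = x[:,:128]
    log_sigma = x[:,128:]
    stddev = tf.exp(log_sigma)
    # The noise must be floating point; tf.random.normal does not accept integer dtypes
    epsilon = tf.random.normal((mean.shape[1],),dtype=tf.float32)
    c = stddev * epsilon + mean
    return c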

class CA(keras.Model):
    """
    Get conditioning augmentation model.
    Takes an embedding of shape (1024,) and returns a tensor of shape (256,)
    """
    def __init__(self):
        super(CA,self).__init__()
        self.fc = layers.Dense(256)
        self.activation  = layers.LeakyReLU(alpha=0.2)
    def call(self,inputs,training=False):
        x = self.activation(self.fc(inputs))
        return x

class Embedding_Compressor(keras.Model):
    """
    Build embedding compressor model
    """
    def __init__(self):
        super(Embedding_Compressor,self).__init__()
        self.fc = layers.Dense(128)
        self.activation = layers.ReLU()
    def call(self,inputs,training=False):
        x = self.activation(self.fc(inputs))
        return x

class Generator_stage1(keras.Model):
    """
    Builds a generator model used in Stage-I
    """
    def __init__(self):
        super(Generator_stage1,self).__init__()
        self.ca_fc = layers.Dense(256)
        self.ca_activation = layers.LeakyReLU(alpha=0.2)
        #self.lambda1 = layers.Lambda(generator_c)
        #self.mean1 = layers.Dense(128)
        #self.log_sigma1 = layers.Dense(128)
        self.fc1 = layers.Dense(128 * 8 * 4 * 4,use_bias=False)
        self.activation = layers.ReLU()
        
        self.upsampling1 = layers.UpSampling2D(size=(2,2))
        self.conv1 = layers.Conv2D(512,kernel_size=3,strides=1,padding='same',use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.ac1 = layers.ReLU()
        
        self.upsampling2 = layers.UpSampling2D(size=(2,2))
        self.conv2 = layers.Conv2D(256,kernel_size=3,strides=1,padding='same',use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.ac2 = layers.ReLU()

        self.upsampling3 = layers.UpSampling2D(size=(2,2))
        self.conv3 = layers.Conv2D(128,kernel_size=3,strides=1,padding='same',use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.ac3 = layers.ReLU()

        self.upsampling4 = layers.UpSampling2D(size=(2,2))
        self.conv4 = layers.Conv2D(64,kernel_size=3,strides=1,padding='same',use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.ac4 = layers.ReLU()

        self.conv5 = layers.Conv2D(3,kernel_size=3,strides=1,padding='same',use_bias=False)

    def call(self,inputs,training=False):
        mean_logsigma = tf.split(self.ca_activation(self.ca_fc(inputs[0])),num_or_size_splits=2,axis=-1)
        #print(mean_logsigma.shape)
        #c = self.lambda1(mean_logsigma)
        #mean_logsigma_split = tf.split(mean_logsigma,num_or_size_splits=2,axis=-1)
        mean = mean_logsigma[0]
        log_sigma = mean_logsigma[1]
        stddev = tf.exp(log_sigma)
        c = stddev * inputs[2] + mean
        #print(c.shape)
        gen_inputs = tf.concat([c,inputs[1]],axis=1)
        #print(gen_inputs.shape)
        x = self.activation(self.fc1(gen_inputs))
        #print(x.shape)
        x = tf.reshape(x,shape=(-1,4,4,128*8))
        x = self.ac1(self.bn1(self.conv1(self.upsampling1(x)),training=training))
        x = self.ac2(self.bn2(self.conv2(self.upsampling2(x)),training=training))
        x = self.ac3(self.bn3(self.conv3(self.upsampling3(x)),training=training))
        x = self.ac4(self.bn4(self.conv4(self.upsampling4(x)),training=training))
        x = self.conv5(x)
        x = tf.tanh(x)
        #print(x.shape)
        return x,mean_logsigma

class Discriminator_stage1(keras.Model):
    def __init__(self):
        super(Discriminator_stage1,self).__init__()
        self.e_fc = layers.Dense(128)
        self.e_ac = layers.LeakyReLU(alpha=0.2)
        
        self.conv1 = layers.Conv2D(64,kernel_size=(4,4),padding='same',strides=2,use_bias=False)
        self.ac1 = layers.LeakyReLU(alpha=0.2)

        self.conv2 = layers.Conv2D(128,kernel_size=(4,4),padding='same',strides=2,use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.ac2 = layers.LeakyReLU(alpha=0.2)

        self.conv3 = layers.Conv2D(256,kernel_size=(4,4),padding='same',strides=2,use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.ac3 = layers.LeakyReLU(alpha=0.2)

        self.conv4 = layers.Conv2D(512,kernel_size=(4,4),padding='same',strides=2,use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.ac4 = layers.LeakyReLU(alpha=0.2)

        self.conv5 = layers.Conv2D(512,kernel_size=1,padding='same',strides=1)
        self.bn4 = layers.BatchNormalization()
        self.ac5 = layers.LeakyReLU(alpha=0.2)

        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)
    def call(self,inputs,training=False):
        x = self.ac1(self.conv1(inputs[0]))
        #print(x.shape)
        x = self.ac2(self.bn1(self.conv2(x),training=training))
        #print(x.shape)
        x = self.ac3(self.bn2(self.conv3(x),training=training))
        #print(x.shape)
        x = self.ac4(self.bn3(self.conv4(x),training=training))
        #print(x.shape)
        #print(x.shape)
        input_layer2 = self.e_ac(self.e_fc(inputs[1]))
        #print(input_layer2.shape)
        input_layer2 = tf.reshape(input_layer2,shape=(-1,1,1,128))
        #print(input_layer2.shape)
        input_layer2 = tf.tile(input_layer2,[1,4,4,1])
        #print(input_layer2.shape)
        x = tf.concat([x,input_layer2],axis=-1)
        #print(x.shape)
        x = self.ac5(self.bn4(self.conv5(x),training=training))
        #print(x.shape)
        x = self.flatten(x)
        x = self.fc(x)
        #print(x.shape)
        x = tf.sigmoid(x)
        return x

Stage-II

class Residual_block(layers.Layer):
    def __init__(self):
        super(Residual_block,self).__init__()
        self.conv1 = layers.Conv2D(128*4,kernel_size=(3,3),padding='same',strides=1)
        self.bn1 = layers.BatchNormalization()
        self.ac1 = layers.ReLU()
        self.conv2 = layers.Conv2D(128*4,kernel_size=(3,3),padding='same',strides=1)
        self.bn2 = layers.BatchNormalization()
        self.ac2 = layers.ReLU()

    def call(self,inputs,training=False):
        x = self.bn1(self.conv1(inputs),training=training)
        x = self.ac1(x)
        x = self.bn2(self.conv2(x),training=training)
        x = layers.add([x,inputs])
        x = self.ac2(x)
        return x

class Generator_stage2(keras.Model):
    def __init__(self):
        super(Generator_stage2,self).__init__()
        self.ca_fc = layers.Dense(256)
        self.ca_activation = layers.LeakyReLU(alpha=0.2)
        #self.mean1 = layers.Dense(128)
        #self.log_sigma1 = layers.Dense(128)

        self.conv1 = layers.Conv2D(128,kernel_size=(3,3),strides=1,padding='same',use_bias=False)
        self.ac1 = layers.ReLU()
        self.conv2 = layers.Conv2D(256,kernel_size=(4,4),strides=2,padding='same',use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.ac2 = layers.ReLU()
        self.conv3 = layers.Conv2D(512,kernel_size=(4,4),strides=2,padding='same',use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.ac3 = layers.ReLU()

        self.conv4 = layers.Conv2D(512,kernel_size=(3,3),strides=1,padding='same',use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.ac4 = layers.ReLU()

        self.rb1 = Residual_block()
        self.rb2 = Residual_block()
        self.rb3 = Residual_block()
        self.rb4 = Residual_block()
        
        self.upsampling1 = layers.UpSampling2D(size=(2,2))
        self.conv5 = layers.Conv2D(512,kernel_size=3,strides=1,padding='same',use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.ac5 = layers.ReLU()

        self.upsampling2 = layers.UpSampling2D(size=(2,2))
        self.conv6 = layers.Conv2D(256,kernel_size=3,strides=1,padding='same',use_bias=False)
        self.bn5 = layers.BatchNormalization()
        self.ac6 = layers.ReLU()

        self.upsampling3 = layers.UpSampling2D(size=(2,2))
        self.conv7 = layers.Conv2D(128,kernel_size=3,strides=1,padding='same',use_bias=False)
        self.bn6 = layers.BatchNormalization()
        self.ac7 = layers.ReLU()

        self.upsampling4 = layers.UpSampling2D(size=(2,2))
        self.conv8 = layers.Conv2D(64,kernel_size=3,strides=1,padding='same',use_bias=False)
        self.bn7 = layers.BatchNormalization()
        self.ac8 = layers.ReLU()

        self.conv9 = layers.Conv2D(3,kernel_size=3,strides=1,padding='same',use_bias=False)

    def call(self,inputs,training=False):
        #CA Network
        mean_logsigma = tf.split(self.ca_activation(self.ca_fc(inputs[0])),num_or_size_splits=2,axis=-1)
        #mean_logsigma = self.ca_activation(self.ca_fc(inputs[0]))
        mean = mean_logsigma[0]
        log_sigma = mean_logsigma[1]
        stddev = tf.exp(log_sigma)
        c = stddev * inputs[2] + mean
        #c = tf.concat([c,inputs[1]],axis=1)
        #Image Encoder
        x = self.ac1(self.conv1(inputs[1]))
        x = self.ac2(self.bn1(self.conv2(x),training=training))
        x = self.ac3(self.bn2(self.conv3(x),training=training))
        c = tf.expand_dims(c,axis=1)
        c = tf.expand_dims(c,axis=1)
        c = tf.tile(c,[1,16,16,1])
        #Concatenation
        c_code = tf.concat([c,x],axis=3)
        #Residual Block
        x = self.ac4(self.bn3(self.conv4(c_code),training=training))
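        # Forward the training flag so batch normalization inside the residual blocks behaves correctly
        x = self.rb1(x,training=training)
        x = self.rb2(x,training=training)
        x = self.rb3(x,training=training)
        x = self.rb4(x,training=training)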
        #Upsampling block
        x = self.ac5(self.bn4(self.conv5(self.upsampling1(x)),training=training))
        x = self.ac6(self.bn5(self.conv6(self.upsampling2(x)),training=training))
        x = self.ac7(self.bn6(self.conv7(self.upsampling3(x)),training=training))
        x = self.ac8(self.bn7(self.conv8(self.upsampling4(x)),training=training))
        x = self.conv9(x)
        x = tf.tanh(x)
        return x,mean_logsigma

class Discriminator_stage2(keras.Model):
    def __init__(self):
        super(Discriminator_stage2,self).__init__()
        self.e_fc = layers.Dense(128)
        self.e_ac = layers.LeakyReLU(alpha=0.2)
        
        self.conv1 = layers.Conv2D(64,kernel_size=(4,4),strides=2,padding='same',use_bias=False)
        self.ac1 = layers.LeakyReLU(alpha=0.2)
        
        self.conv2 = layers.Conv2D(128,kernel_size=(4,4),strides=2,padding='same',use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.ac2 = layers.LeakyReLU(alpha=0.2)

        self.conv3 = layers.Conv2D(256,kernel_size=(4,4),strides=2,padding='same',use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.ac3 = layers.LeakyReLU(alpha=0.2)

        self.conv4 = layers.Conv2D(512,kernel_size=(4,4),strides=2,padding='same',use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.ac4 = layers.LeakyReLU(alpha=0.2)

        self.conv5 = layers.Conv2D(1024,kernel_size=(4,4),strides=2,padding='same',use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.ac5 = layers.LeakyReLU(alpha=0.2)

        self.conv6 = layers.Conv2D(2048,kernel_size=(4,4),strides=2,padding='same',use_bias=False)
        self.bn5 = layers.BatchNormalization()
        self.ac6 = layers.LeakyReLU(alpha=0.2)

        self.conv7 = layers.Conv2D(1024,kernel_size=(1,1),strides=1,padding='same',use_bias=False)
        self.bn6 = layers.BatchNormalization()
        self.ac7 = layers.LeakyReLU(alpha=0.2)

        self.conv8 = layers.Conv2D(512,kernel_size=(1,1),strides=1,padding='same',use_bias=False)
        self.bn7 = layers.BatchNormalization()

        self.conv9 = layers.Conv2D(128,kernel_size=(1,1),strides=1,padding='same',use_bias=False)
        self.bn8 = layers.BatchNormalization()
        self.ac8 = layers.LeakyReLU(alpha=0.2)

        self.conv10 = layers.Conv2D(128,kernel_size=(3,3),strides=1,padding='same',use_bias=False)
        self.bn9 = layers.BatchNormalization()
        self.ac9 = layers.LeakyReLU(alpha=0.2)

        self.conv11 = layers.Conv2D(512,kernel_size=(3,3),strides=1,padding='same',use_bias=False)
        self.bn10 = layers.BatchNormalization()

        self.ac10 = layers.LeakyReLU(alpha=0.2)

        self.conv12 = layers.Conv2D(64*8,kernel_size=1,strides=1,padding='same')
        self.bn11 = layers.BatchNormalization()
        self.ac11 = layers.LeakyReLU(alpha=0.2)

        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self,inputs,training=False):
        x = self.ac1(self.conv1(inputs[0]))
        x = self.ac2(self.bn1(self.conv2(x),training=training))
        x = self.ac3(self.bn2(self.conv3(x),training=training))
        x = self.ac4(self.bn3(self.conv4(x),training=training))
        x = self.ac5(self.bn4(self.conv5(x),training=training))
        x = self.ac6(self.bn5(self.conv6(x),training=training))
        x = self.ac7(self.bn6(self.conv7(x),training=training))
        x = self.bn7(self.conv8(x),training=training)
        
        x2 = self.ac8(self.bn8(self.conv9(x),training=training))
        x2 = self.ac9(self.bn9(self.conv10(x2),training=training))
        x2 = self.bn10(self.conv11(x2),training=training)

        added_x = layers.add([x,x2])
        added_x = self.ac10(added_x)

        input_layer2 = self.e_ac(self.e_fc(inputs[1]))
        input_layer2 = tf.reshape(input_layer2,shape=(-1,1,1,128))
        input_layer2 = tf.tile(input_layer2,[1,4,4,1])
        x3 = tf.concat([added_x,input_layer2],axis=-1)

        x3 = self.ac11(self.bn11(self.conv12(x3),training=training))
        x3 = self.flatten(x3)
        x3 = self.fc(x3)
        x3 = tf.sigmoid(x3)
        return x3

Data preparation

def load_class_ids(class_info_file_path):
    """
    Load class ids from class_info.pickle file
    """
    with open(class_info_file_path,'rb') as f:
        class_ids = pickle.load(f,encoding='latin1')
    return class_ids

def load_embeddings(embeddings_file_path):
    """
    Load embeddings
    """
    with open(embeddings_file_path,'rb') as f:
        embeddings = pickle.load(f,encoding='latin1')
        embeddings = np.array(embeddings)
    return embeddings

def load_filenames(filenames_file_path):
    """
    Load filenames.pickle file and return a list of all file names
    """
    with open(filenames_file_path,'rb') as f:
        filenames = pickle.load(f,encoding='latin1')
    return filenames

def load_bounding_boxes(dataset_dir):
    """
    Load bounding boxes and return a dictionary of file names and corresponding bounding boxes
    """
    #Paths
    bounding_boxes_path = os.path.join(dataset_dir,'bounding_boxes.txt')
    file_paths_path = os.path.join(dataset_dir,'images.txt')
    #Read bounding_boxes.txt and images.txt file
    df_bounding_boxes = pd.read_csv(bounding_boxes_path,
                                    delim_whitespace=True,
                                    header=None).astype(int)
    df_file_names = pd.read_csv(file_paths_path,
                                delim_whitespace=True,
                                header=None)
    #create a list of file names
    file_names = df_file_names[1].tolist()
    #create a dictionary of file_names and bounding boxes
    filename_boundingbox_dict = {img_file[:-4]:[] for img_file in file_names}
    #Assign a bounding box to the corresponding image
    for i in range(0,len(file_names)):
        #Get the bounding box
        bounding_box = df_bounding_boxes.iloc[i][1:].tolist()
        key = file_names[i][:-4]
        filename_boundingbox_dict[key] = bounding_box
    return filename_boundingbox_dict

def get_img(img_path,bbox,image_size):
    """
    Load and resize image
    """
    img = Image.open(img_path).convert('RGB')
    width,height = img.size
    if bbox is not None:
        R = int(np.maximum(bbox[2],bbox[3]) * 0.75)
        center_x = int((2 * bbox[0] + bbox[2]) / 2)
        center_y = int((2 * bbox[1] + bbox[3]) / 2)
        y1 = np.maximum(0,center_y - R)
        y2 = np.minimum(height,center_y + R)
        x1 = np.maximum(0,center_x - R)
        x2 = np.minimum(width,center_x + R)
        img = img.crop([x1,y1,x2,y2])
    img = img.resize(image_size,PIL.Image.BILINEAR)
    return img

def load_dataset(filenames_file_path, cub_dataset_dir, embeddings_file_path, image_size):
    """
    Load dataset
    """
    filenames = load_filenames(filenames_file_path)
    #class_ids = load_class_ids(class_info_file_path)
    bounding_boxes = load_bounding_boxes(cub_dataset_dir)
    all_embeddings = load_embeddings(embeddings_file_path)
    print("Embeddings shape:",all_embeddings.shape)
    X,embeddings = [],[]
    for index,filename in enumerate(filenames):
        bounding_box = bounding_boxes[filename]
        #Load images
        img_name = '{}/images/{}.jpg'.format(cub_dataset_dir,filename)
        img = get_img(img_name,bounding_box,image_size)
        all_embeddings1 = all_embeddings[index,:,:]
        embedding_ix = random.randint(0,all_embeddings1.shape[0] - 1)
        embedding = all_embeddings1[embedding_ix,:]
        img = np.array(img,dtype=np.float32)
        img = img / 127.5 -1.
        X.append(img)
        embeddings.append(embedding)
    return tf.data.Dataset.from_tensor_slices((X,embeddings))

Loss functions


def celoss_zeros(logits):
    # Cross-entropy against label 0, with label smoothing (target 0.1)
    y = tf.ones_like(logits) * 0.1
    loss = keras.losses.binary_crossentropy(y,logits)
    return tf.reduce_mean(loss)

def celoss_ones(logits):
    # Cross-entropy against label 1, with label smoothing (target 0.9)
    y = tf.ones_like(logits) * 0.9
    loss = keras.losses.binary_crossentropy(y, logits)
    return tf.reduce_mean(loss)

def KL_loss(logits):
    mean = logits[0]
    logsigma = logits[1]
    loss = -logsigma + 0.5 * (-1 + tf.exp(2. * logsigma) + tf.square(mean))
    loss = tf.reduce_mean(loss)
    return loss

def d_loss_fn(batch_size,generator,discriminator,img_batch,embedding_batch,z_noise,condition_var,training):
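    # Generate fake images
    fake_images,_ = generator([embedding_batch,z_noise,condition_var],training)
    # Score the generated images paired with their matching embeddings
    d_fake_logits = discriminator([fake_images,embedding_batch], training)
    d_loss_fake = celoss_zeros(d_fake_logits)
    # Score the real images paired with their matching embeddings
    d_real_logits = discriminator([img_batch,embedding_batch], training)
    d_loss_real = celoss_ones(d_real_logits)
    # Score real images paired with mismatched embeddings (embeddings shifted by one)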
    d_wrong_logits = discriminator([img_batch[:(batch_size-1)],embedding_batch[1:]],training)
    d_loss_wrong = celoss_zeros(d_wrong_logits)
    loss = d_loss_fake + d_loss_real + d_loss_wrong
    return loss

def g_loss_fn(generator,discriminator,embedding_batch,z_noise,condition_var,training):
    fake_images,mean_logsigma = generator([embedding_batch,z_noise,condition_var],training)
    d_fake_logits = discriminator([fake_images,embedding_batch], training)
    d_loss_fake = celoss_ones(d_fake_logits)
    d_KL_fake = KL_loss(mean_logsigma)
    loss = d_loss_fake + 2.0 * d_KL_fake
    return loss

def d_loss_fn_stage2(batch_size=64,
                     gen_stage1=None,
                     gen_stage2=None,
                     dis_stage2=None,
                     image_batch=None,
                     embedding_batch=None,
                     z_noise=None,
                     condition_var=None,
                     training=False):
    lr_fake_images,_ = gen_stage1([embedding_batch,z_noise,condition_var])
    hr_fake_images,_ = gen_stage2([embedding_batch,lr_fake_images,condition_var],training)
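    # Score the generated images paired with their matching embeddings
    d_fake_logits = dis_stage2([hr_fake_images,embedding_batch], training)
    d_loss_fake = celoss_zeros(d_fake_logits)
    # Score the real images paired with their matching embeddings
    d_real_logits = dis_stage2([image_batch,embedding_batch], training)
    d_loss_real = celoss_ones(d_real_logits)
    # Score real images paired with mismatched embeddings (embeddings shifted by one)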
    d_wrong_logits = dis_stage2([image_batch[:(batch_size-1)],embedding_batch[1:]],training)
    d_loss_wrong = celoss_zeros(d_wrong_logits)
    loss = d_loss_fake + d_loss_real + d_loss_wrong
    return loss

def g_loss_fn_stage2(gen_stage1=None,
                     gen_stage2=None,
                     dis_stage2=None,
                     embedding_batch=None,
                     z_noise=None,
                     condition_var=None,
                     training=False):
    lr_fake_images,_ = gen_stage1([embedding_batch,z_noise,condition_var])
    hr_fake_images,mean_logsigma = gen_stage2([embedding_batch,lr_fake_images,condition_var],training)
    d_fake_logits = dis_stage2([hr_fake_images,embedding_batch], training)
    d_loss_fake = celoss_ones(d_fake_logits)
    d_KL_fake = KL_loss(mean_logsigma)
    loss = d_loss_fake + 2.0 * d_KL_fake
    return loss

Image saving function

def save_result(val_out,val_block_size,image_path,color_mode):
    def preprocessing(img):
        img = ((img + 1.0)*(255./2)).astype(np.uint8)
        return img

    preprocessed = preprocessing(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat image into a row
        if single_row.size == 0:
            single_row = preprocessed[b,:,:,:]
        else:
            single_row = np.concatenate((single_row,preprocessed[b,:,:,:]),axis=1)
        # concat image row to final_image
        if (b+1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)

            # reset single row
            single_row = np.array([])

    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image).save(image_path)

Model training

def main_stage1():
    data_dir = "./birds/"
    train_dir = data_dir + "/train"
    test_dir = data_dir + "/test"
    image_size = 64
    batch_size = 32
    z_dim = 100
    stage1_generator_lr = 0.0002
    stage1_discriminator_lr = 0.0002
    stage1_lr_decay_step = 600
    epochs = 10000
    condition_dim = 128
    training=True

    embeddings_file_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
    embeddings_file_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"

    filenames_file_path_train = train_dir + "/filenames.pickle"
    filenames_file_path_test = test_dir + "/filenames.pickle"

    #class_info_file_path_train = train_dir + "/class_info.pickle"
    #class_info_file_path_test = test_dir + "/class_info.pickle"

    cub_dataset_dir = "./CUB_200_2011/" + "/CUB_200_2011"

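    # `lr` is a deprecated alias; `learning_rate` is the supported argument name in TF 2.x
    d_optimizer = keras.optimizers.Adam(learning_rate=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)
    g_optimizer = keras.optimizers.Adam(learning_rate=stage1_generator_lr, beta_1=0.5, beta_2=0.999)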

    """
    X_test, embeddings_test = load_dataset(filenames_file_path=filenames_file_path_test,
                                                       cub_dataset_dir=cub_dataset_dir,
                                                       embeddings_file_path=embeddings_file_path_test,
                                                       image_size=(64, 64))
    """
    
    db_train = load_dataset(filenames_file_path=filenames_file_path_train,
                                                          cub_dataset_dir=cub_dataset_dir,
                                                          embeddings_file_path=embeddings_file_path_train,
                                                          image_size=(64, 64))
    db_train = db_train.shuffle(1000).batch(batch_size,drop_remainder=True)
    db_test = load_dataset(filenames_file_path=filenames_file_path_test,
                                                       cub_dataset_dir=cub_dataset_dir,
                                                       embeddings_file_path=embeddings_file_path_test,
                                                       image_size=(64, 64))
    #num_repeat = int(100 / batch_size) + 1
    db_test = iter(db_test.batch(64,drop_remainder=True).repeat(3))
    
    gen = Generator_stage1()
    gen.build([[4,1024],[4,100],[128]])

    dis = Discriminator_stage1()
    dis.build([[4,64,64,3],[4,1024]])

    #real_labels = np.ones((batch_size, 1), dtype=float) * 0.9
    #fake_labels = np.zeros((batch_size, 1), dtype=float) * 0.1
    for epoch in range(epochs):
        g_losses = []
        d_losses = []
        for index,(x,embedding) in enumerate(db_train):
            z_noise = tf.random.normal(shape=(batch_size,z_dim))
            condition_var = tf.random.normal(shape=(condition_dim,))
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn(batch_size,gen,dis,x,embedding,z_noise,condition_var,training)
            grads = tape.gradient(d_loss,dis.trainable_variables)
            d_optimizer.apply_gradients(zip(grads,dis.trainable_variables))
            z_noise = tf.random.normal(shape=(batch_size,z_dim))
            condition_var = tf.random.normal(shape=(condition_dim,))
            with tf.GradientTape() as tape:
                g_loss = g_loss_fn(gen,dis,embedding,z_noise,condition_var,training)
            grads = tape.gradient(g_loss,gen.trainable_variables)
            g_optimizer.apply_gradients(zip(grads,gen.trainable_variables))
        if epoch % 2 == 0:
            print(epoch,'d_loss:',float(d_loss),'g_loss:',float(g_loss))
            # Visualize: generate a grid of sample images from test embeddings
            z = tf.random.normal([64,z_dim])
            _,embeddings_test = next(db_test)
            condition_var = tf.random.normal(shape=(condition_dim,))
            fake_image,_ = gen([embeddings_test,z,condition_var],training=False)
            img_path = r'gan-{}.png'.format(epoch)
            save_result(fake_image.numpy(),8,img_path,color_mode='P')
            d_losses.append(float(d_loss))
            g_losses.append(float(g_loss))
    gen.save_weights("stage1_gen.h5")
    dis.save_weights("stage1_dis.h5")

def main_stage2():
    data_dir = "data/birds/"
    train_dir = data_dir + "/train"
    test_dir = data_dir + "/test"
    image_size = 256
    batch_size = 32
    z_dim = 100
    stage1_generator_lr = 0.0002
    stage1_discriminator_lr = 0.0002
    stage1_lr_decay_step = 600
    epochs = 10000
    condition_dim = 128
    training=True

    embeddings_file_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
    embeddings_file_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"

    filenames_file_path_train = train_dir + "/filenames.pickle"
    filenames_file_path_test = test_dir + "/filenames.pickle"

    #class_info_file_path_train = train_dir + "/class_info.pickle"
    #class_info_file_path_test = test_dir + "/class_info.pickle"

    cub_dataset_dir = data_dir + "/CUB_200_2011"

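    # `lr` is a deprecated alias; `learning_rate` is the supported argument name in TF 2.x
    d_optimizer = keras.optimizers.Adam(learning_rate=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)
    g_optimizer = keras.optimizers.Adam(learning_rate=stage1_generator_lr, beta_1=0.5, beta_2=0.999)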

    #Load dataset
    """
    db_lr_train = load_dataset(filenames_file_path=filenames_file_path_train,
                                                          cub_dataset_dir=cub_dataset_dir,
                                                          embeddings_file_path=embeddings_file_path_train,
                                                          image_size=(64, 64))
    db_lr_train = db_lr_train.shuffle(1000).batch(batch_size,drop_remainder=True)
    db_lr_test = load_dataset(filenames_file_path=filenames_file_path_test,
                                                       cub_dataset_dir=cub_dataset_dir,
                                                       embeddings_file_path=embeddings_file_path_test,
                                                       image_size=(64, 64))
    #num_repeat = int(100 / batch_size) + 1
    db_lr_test = iter(db_lr_test.batch(64,drop_remainder=True).repeat(3))
    """

    # Load the high-resolution (256x256) dataset
    db_hr_train = load_dataset(filenames_file_path=filenames_file_path_train,
                                                          cub_dataset_dir=cub_dataset_dir,
                                                          embeddings_file_path=embeddings_file_path_train,
                                                          image_size=(256,256))
    db_hr_train = db_hr_train.shuffle(1000).batch(batch_size,drop_remainder=True)
    db_hr_test = load_dataset(filenames_file_path=filenames_file_path_test,
                                                       cub_dataset_dir=cub_dataset_dir,
                                                       embeddings_file_path=embeddings_file_path_test,
                                                       image_size=(256,256))
    #num_repeat = int(100 / batch_size) + 1
    db_hr_test = iter(db_hr_test.batch(64,drop_remainder=True).repeat(3))

    gen_stage1 = Generator_stage1()
    gen_stage1.build([[4,1024],[4,100],[128]])
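    try:
        gen_stage1.load_weights("stage1_gen.h5")
    except Exception as e:
        # Without the Stage-I weights, Stage-II would refine untrained Stage-I output
        print("Could not load stage1_gen.h5:", e)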

    gen_stage2 = Generator_stage2()
    gen_stage2.build([[4,1024],[4,64,64,3],[128]])

    dis_stage2 = Discriminator_stage2()
    dis_stage2.build([[4,256,256,3],[4,1024]])

    # Create the loss histories once, so they are not reset at the start of every epoch
    g_losses = []
    d_losses = []
    for epoch in range(epochs):
        for index,(x,embedding) in enumerate(db_hr_train):
            z_noise = tf.random.normal(shape=(batch_size,z_dim))
            condition_var = tf.random.normal(shape=(condition_dim,))
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn_stage2(batch_size=batch_size,
                                          gen_stage1=gen_stage1,
                                          gen_stage2=gen_stage2,
                                          dis_stage2=dis_stage2,
                                          image_batch=x,
                                          embedding_batch=embedding,
                                          z_noise=z_noise,
                                          condition_var=condition_var,
                                          training=training)
            grads = tape.gradient(d_loss,dis_stage2.trainable_variables)
            d_optimizer.apply_gradients(zip(grads,dis_stage2.trainable_variables))
            z_noise = tf.random.normal(shape=(batch_size,z_dim))
            condition_var = tf.random.normal(shape=(condition_dim,))
            with tf.GradientTape() as tape:
                g_loss = g_loss_fn_stage2(gen_stage1=gen_stage1,
                                          gen_stage2=gen_stage2,
                                          dis_stage2=dis_stage2,
                                          embedding_batch=embedding,
                                          z_noise=z_noise,
                                          condition_var=condition_var,
                                          training=training)
            grads = tape.gradient(g_loss,gen_stage2.trainable_variables)
            g_optimizer.apply_gradients(zip(grads,gen_stage2.trainable_variables))
        if epoch % 100 == 0:
            print(epoch,'d_loss:',float(d_loss),'g_loss:',float(g_loss))
            # Visualization: run Stage-I then Stage-II on test embeddings and save the result
            z = tf.random.normal([64,z_dim])
            _,embeddings_test = next(db_hr_test)
            condition_var = tf.random.normal(shape=(condition_dim,))
            lr_fake_image,_ = gen_stage1([embeddings_test,z,condition_var],training=False)
            hr_fake_image,_ = gen_stage2([embeddings_test,lr_fake_image,condition_var],training=False)
            # Distinct file name so Stage-II samples do not overwrite the Stage-I 'gan-{}.png' files
            img_path = r'gan-stage2-{}.png'.format(epoch)
            save_result(hr_fake_image.numpy(),8,img_path,color_mode='P')
            d_losses.append(float(d_loss))
            g_losses.append(float(g_loss))
    gen_stage2.save_weights("stage2_gen.h5")
    dis_stage2.save_weights("stage2_dis.h5")
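
In main_stage2 above, stage1_lr_decay_step is set to 600 but is not used anywhere in the loop shown. One plausible use is a step schedule that lowers the learning rate every fixed number of epochs. The following is only a minimal sketch under that assumption: the helper name step_decay_lr and the halving factor of 0.5 are illustrative choices, not part of the original code.

```python
def step_decay_lr(base_lr, epoch, decay_step, factor=0.5):
    """Return base_lr multiplied by `factor` once for every `decay_step` completed epochs.

    Hypothetical helper: the original script defines a decay step
    but does not show how (or whether) it is applied.
    """
    if decay_step <= 0:
        raise ValueError("decay_step must be positive")
    return base_lr * (factor ** (epoch // decay_step))


# Possible use at the top of each epoch in main_stage2 (an assumption, not original code):
#     new_lr = step_decay_lr(stage1_generator_lr, epoch, stage1_lr_decay_step)
#     g_optimizer.learning_rate.assign(new_lr)
#     d_optimizer.learning_rate.assign(new_lr)

if __name__ == "__main__":
    for epoch in (0, 599, 600, 1200):
        print(epoch, step_decay_lr(0.0002, epoch, 600))
```

Keeping the schedule as a pure function makes it easy to check in isolation before wiring it into the optimizers.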

Run

if __name__ == "__main__":
    main_stage1()
    main_stage2()
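
Because main_stage2 loads the stage1_gen.h5 file written by main_stage1, the two stages could also be trained in separate runs. Here is a minimal sketch of that, assuming a hypothetical --stage command-line flag; the flag and the two helper functions are illustrative and not part of the original script.

```python
import argparse


def parse_stage(argv=None):
    """Parse a hypothetical --stage flag: '1', '2', or 'all' (the default)."""
    parser = argparse.ArgumentParser(description="Train StackGAN stages")
    parser.add_argument("--stage", choices=["1", "2", "all"], default="all")
    return parser.parse_args(argv).stage


def stages_to_run(stage):
    """Map the flag value to the ordered list of stage numbers to train."""
    return [1, 2] if stage == "all" else [int(stage)]


# Possible wiring in the entry point (an assumption, not original code):
#     for s in stages_to_run(parse_stage()):
#         {1: main_stage1, 2: main_stage2}[s]()
```

With this wiring, running the script with --stage 2 would skip Stage-I training and rely on a previously saved stage1_gen.h5.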
posted @ 2020-11-03 14:35  盼小辉丶