GAN生成对抗网络-text to image原理与基本实现-文字转图像-11
本质上这是 "word2vec词向量 + RNN文本编码" 与条件GAN的组合。
首先用word2vec把每个词转换为向量,再用一个RNN网络把整句话编码成一个文本向量,然后将这个文本向量同时输入到G和D网络中。
与普通GAN不同的是,这里多了一种错误情况:图片看上去很真实,但与对应的文字描述不符,这种情况也要给予惩罚。
如果不把文本向量输入D,那么D所能获得的信息仅仅是图片本身,也就失去了判断图片与描述是否相符的能力。
为什么还需要噪声输入?
这是因为通常一句话只描述内容(花的样子),而不会描述style(style主要包括背景和姿态)。
在这种情况下,我们希望噪声能起到补充style的作用,从而生成更加真实、多样的图片。
另外,通过特征可视化可以看到,z能够为生成结果加入特定的style,
从而弥补文本描述本身不涉及style的问题:
随机采样的z可以加入不同的style,从而增加生成样本的真实性与多样性。
三个重点部分:
一、文本的处理:如何提取文本信息,作为生成器的条件?
首先将文本向量化(用word2vec得到词向量),
然后用RNN提取整句话的文本信息。
二、图片的处理:需要加入负样本训练,
即输入的文本和图片不对应时,要给出惩罚。
判别器的输入组合为:
正确的图片 + 正确的文本(判为真)
错误(不匹配)的图片 + 正确的文本(判为假)
G生成的图片 + 正确的文本(判为假)
三、创建输入队列:
保证文本和图片一一对应。
import tensorflow as tf
from gensim.models import word2vec
from gensim.models import Word2Vec
import pandas as pd
import glob
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline
from IPython import display
# Show what the prepared-artifact dataset (word2vec model, corpus, checkpoints) contains.
os.listdir('../input/gan-text-to-image-102flowers-rieyuguanghua')

# ---- hyper-parameters ----
n_input = 100       # word-vector width; must equal the word2vec `size` and pad()'s width
n_hidden = 128      # GRU state size, i.e. the width of the text condition
# read_image() resizes real images to 256 x 256 and the generator emits
# 256 x 256 x 3, so the declared image size is 256 (it previously said 64).
image_height = 256
image_width = 256
image_depth = 3
noise_dim = 100
maxlength = 250     # placeholder only: recomputed from the caption files before use
NUM_EPOCHS = 100
batch_size = 64
# Extract the flower image archive once. The image list below globs
# ./102flowers/jpg/*.jpg, so the archive presumably unpacks a jpg/ folder
# into ./102flowers/ -- confirm if the archive layout changes.
if not os.path.exists('102flowers'):
    !mkdir 102flowers
    !tar zxvf ../input/102flowersdataset/102flowers.tgz -C ./102flowers/
# `tar v` prints every extracted file name; clear that noise from the notebook output.
display.clear_output()
# Pair every caption file with its image. Both lists are ordered by base name
# (image_00001.txt <-> image_00001.jpg), so index k refers to the same flower.
all_text_filename = glob.glob('../input/cvpr2016/cvpr2016_flowers/text_c10/class_*/image_*.txt')
all_text_filename.sort(key=lambda x: x.split('/')[-1])
all_image_filename = glob.glob('./102flowers/jpg/*.jpg')
all_image_filename.sort()

# Fail fast if the lists are not one-to-one (e.g. the archive was not extracted):
# a silent misalignment would train on captions that describe other flowers.
text_ids = [os.path.splitext(os.path.basename(p))[0] for p in all_text_filename]
image_ids = [os.path.splitext(os.path.basename(p))[0] for p in all_image_filename]
if text_ids != image_ids:
    raise ValueError('caption files and image files are not aligned: '
                     '%d captions vs %d images' % (len(text_ids), len(image_ids)))

all_text_filename = np.array(all_text_filename)
all_image_filename = np.array(all_image_filename)
# "Wrong" images for the mismatched-pair loss: a random permutation of the real ones.
wrong_image_filename = all_image_filename[np.random.permutation(len(all_image_filename))]
dataset_image = tf.data.Dataset.from_tensor_slices((all_image_filename, wrong_image_filename))
if not os.path.exists('../input/gan-text-to-image-102flowers-rieyuguanghua/all_text.txt'):
with open('all_text.txt', 'at') as f:
for a_text in all_text_filename:
f.write(open(a_text).read().replace('\n', '') + '\n')
if not os.path.exists('../input/gan-text-to-image-102flowers-rieyuguanghua/word_model'):
sentences = word2vec.Text8Corpus('all_text.txt')
model = word2vec.Word2Vec(sentences, size=100)
model.save('word_model')
else:
model = Word2Vec.load('../input/gan-text-to-image-102flowers-rieyuguanghua/word_model')
!cp ../input/gan-text-to-image-102flowers-rieyuguanghua/all_text.txt ./
!cp ../input/gan-text-to-image-102flowers-rieyuguanghua/word_model ./
word_vectors = model.wv


def _word_count(path):
    """Return the number of whitespace-separated words in a text file (file is closed)."""
    with open(path) as f:
        return len(f.read().split())


# Longest caption file, in words: every caption is zero-padded to this many
# time steps, which is also the number of RNN steps (n_steps).
maxlength = max(_word_count(a_text) for a_text in all_text_filename)
n_steps = maxlength
def pad(x, maxlength=200, dim=100):
    """Zero-pad a sequence of word vectors into a fixed-size (maxlength, dim) matrix.

    Args:
        x: sequence of `dim`-wide vectors (list of arrays or a 2-D array). It may be
            empty, e.g. when none of a caption's words is in the vocabulary.
        maxlength: number of rows in the result.
        dim: vector width; the default 100 matches this notebook's word2vec size.

    Returns:
        float64 array of shape (maxlength, dim). The vectors of `x` fill the first
        rows and the rest are zeros. Vectors beyond `maxlength` are dropped.
    """
    x1 = np.zeros((maxlength, dim))
    x = np.asarray(x, dtype=float)
    n = min(len(x), maxlength)
    # An empty `x` has shape (0,), which cannot be broadcast into (0, dim); skip it.
    if n:
        x1[:n] = x[:n]
    return x1
def text_vec(text_filenames):
    """Embed each caption file as a zero-padded matrix of word vectors.

    Each file is split on whitespace. Words missing from `word_vectors` are
    skipped, and the remaining vectors are padded with `pad` to `maxlength` rows.

    Args:
        text_filenames: iterable of caption file paths.

    Returns:
        Array of shape (len(text_filenames), maxlength, 100), in input order.
    """
    vec = []
    for a_text in text_filenames:
        # Close every caption file promptly instead of leaking the handle.
        with open(a_text) as f:
            all_word = f.read().split()
        vec.append([word_vectors[w] for w in all_word if w in word_vectors])
    if not vec:
        # np.stack cannot stack zero arrays; return an empty batch of the right shape.
        return np.zeros((0, maxlength, 100))
    return np.stack([pad(v, maxlength=maxlength) for v in vec])
# Pre-compute the padded word-vector matrix for every caption, in the same order
# as all_text_filename (which was sorted to line up with the image list).
data_text_emb = text_vec(text_filenames=all_text_filename)
def read_image(image_filename):
    """Load one JPEG as a 256 x 256 x 3 float32 image min-max scaled to [0, 1].

    The decoded image is centre-cropped or zero-padded to 512 x 512, resized to
    256 x 256, and then normalised per image: its smallest value maps to 0 and
    its largest to 1.
    """
    image = tf.read_file(image_filename)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize_image_with_crop_or_pad(image, 512, 512)
    image = tf.image.resize_images(image, (256, 256))
    lo = tf.reduce_min(image)
    hi = tf.reduce_max(image)
    # Guard the constant-image case (hi == lo), which would divide by zero and put
    # NaNs into training. Any image with a non-zero value range is unaffected.
    return (image - lo) / tf.maximum(hi - lo, 1e-8)
def _pre_func(real_image_name, wrong_image_name):
    """tf.data map function: decode a (matching image, mismatched image) name pair.

    Returns the decoded images in the same (real, wrong) order as the arguments.
    """
    decoded_wrong = read_image(wrong_image_name)
    decoded_real = read_image(real_image_name)
    return decoded_real, decoded_wrong
# Initial (unshuffled) image pipeline. Its element structure defines a
# re-initialisable iterator, which the training loop re-points at a newly
# shuffled dataset every epoch.
dataset_image = dataset_image.map(_pre_func).batch(batch_size)
iterator = tf.data.Iterator.from_structure(
    output_types=dataset_image.output_types,
    output_shapes=dataset_image.output_shapes)
real_image_batch, wrong_image_batch = iterator.get_next()

# Caption word vectors (batch, n_steps, n_input) and generator noise, fed per step.
input_text = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_input])
inputs_noise = tf.placeholder(dtype=tf.float32, shape=[None, noise_dim], name='inputs_noise')
def length(shuru):
    """Count the non-padding time steps of each sequence in a batch.

    `shuru` is indexed as (batch, time, features). A time step counts as real
    when any of its features is non-zero. The result is a float tensor of shape
    (batch,).
    """
    step_is_used = tf.sign(tf.reduce_max(tf.abs(shuru), axis=2))
    return tf.reduce_sum(step_is_used, axis=1)
def text_rnn(input_text, batch_size=64, reuse=tf.AUTO_REUSE):
    """Encode padded caption embeddings with a GRU; return each caption's last output.

    Args:
        input_text: (batch, time, n_input) word vectors, zero-padded at the end.
        batch_size: kept only for call compatibility. The batch size is now read
            from `input_text` at run time, so a short final batch works too.
        reuse: reuse mode for the GRU cell's variables.

    Returns:
        (batch, n_hidden) tensor: the GRU output at each caption's last
        non-padding step. Captions with no real steps yield the step-0 output.

    NOTE(review): the GRU variables are created in tf.nn.dynamic_rnn's default
    "rnn" scope, which get_optimizer's default var lists ("generator" /
    "discriminator") do not include, so they are not updated during training.
    """
    cell = tf.contrib.rnn.GRUCell(n_hidden,
                                  kernel_initializer=tf.truncated_normal_initializer(stddev=0.0001),
                                  bias_initializer=tf.truncated_normal_initializer(stddev=0.0001),
                                  reuse=reuse)
    seq_len = tf.cast(length(input_text), tf.int32)
    output, _ = tf.nn.dynamic_rnn(
        cell,
        input_text,
        dtype=tf.float32,
        sequence_length=seq_len
    )
    # Flatten (batch, time, hidden) to (batch*time, hidden). Row b*time + (len_b - 1)
    # is caption b's last real step; clamp so an empty caption doesn't index -1.
    dynamic_batch = tf.shape(output)[0]
    time_steps = tf.shape(output)[1]
    index = tf.range(0, dynamic_batch) * time_steps + tf.maximum(seq_len - 1, 0)
    flat = tf.reshape(output, [-1, int(output.get_shape()[2])])
    return tf.gather(flat, index)
def get_generator(noise_img, image_depth, condition_label, is_train=True, alpha=0.2):
    """Conditional generator: (noise, text condition) -> 256 x 256 x image_depth image.

    Args:
        noise_img: (batch, noise_dim) noise, tensor or array (cast to float32).
        image_depth: number of output channels.
        condition_label: (batch, width) text embedding, e.g. text_rnn's output.
        is_train: Python bool. True: batch norm uses batch statistics and dropout
            is active. False (inference): batch norm uses its moving averages
            and dropout is disabled.
        alpha: leaky-ReLU slope used on the noise projection.

    Returns:
        Images in [0, 1] with shape (batch, 256, 256, image_depth).

    Variables are created or reused under the "generator" scope. Their
    auto-generated names follow the order of the layer calls below, so that
    order must not change or saved checkpoints will no longer restore.
    """
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        # Dropout regularises training only. At inference keep every unit,
        # consistent with batch norm switching to moving averages.
        keep_prob = 0.8 if is_train else 1.0

        # Noise -> n_hidden (leaky ReLU), then concatenate the text condition.
        noise_img = tf.to_float(noise_img)
        noise_img = tf.layers.dense(noise_img, n_hidden)
        noise_img = tf.maximum(alpha * noise_img, noise_img)
        noise_img_ = tf.concat([noise_img, condition_label], 1)

        # Fully connected projection to a 4 x 4 x 512 feature map.
        layer1 = tf.layers.dense(noise_img_, 4*4*512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        layer1 = tf.layers.batch_normalization(layer1, training=is_train)
        layer1 = tf.nn.relu(layer1)
        layer1 = tf.nn.dropout(layer1, keep_prob=keep_prob)

        # 4 x 4 x 512 -> 8 x 8 x 256
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=is_train)
        layer2 = tf.nn.relu(layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=keep_prob)

        # 8 x 8 x 256 -> 16 x 16 x 128
        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=is_train)
        layer3 = tf.nn.relu(layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=keep_prob)

        # 16 x 16 x 128 -> 32 x 32 x 64 (no dropout from here on)
        layer4 = tf.layers.conv2d_transpose(layer3, 64, 3, strides=2, padding='same')
        layer4 = tf.layers.batch_normalization(layer4, training=is_train)
        layer4 = tf.nn.relu(layer4)

        # 32 x 32 x 64 -> 64 x 64 x 32
        layer5 = tf.layers.conv2d_transpose(layer4, 32, 3, strides=2, padding='same')
        layer5 = tf.layers.batch_normalization(layer5, training=is_train)
        layer5 = tf.nn.relu(layer5)

        # 64 x 64 x 32 -> 128 x 128 x 16
        layer6 = tf.layers.conv2d_transpose(layer5, 16, 3, strides=2, padding='same')
        layer6 = tf.layers.batch_normalization(layer6, training=is_train)
        layer6 = tf.nn.relu(layer6)

        # 128 x 128 x 16 -> 256 x 256 x image_depth
        logits = tf.layers.conv2d_transpose(layer6, image_depth, 3, strides=2, padding='same')
        # Map tanh's (-1, 1) to (0, 1), the range read_image produces for real images.
        outputs = tf.tanh(logits)
        outputs = (outputs/2) + 0.5
        outputs = tf.clip_by_value(outputs, 0.0, 1.0)
        return outputs
def get_discriminator(inputs_img, condition_label, reuse= tf.AUTO_REUSE, alpha=0.2):
    """Conditional discriminator: score how real an image looks given its text.

    Args:
        inputs_img: image batch. The six stride-2 convolutions and the final
            reshape to 4*4*512 imply 256 x 256 inputs.
        condition_label: text-embedding batch whose rows pair with inputs_img
            (they are concatenated per example), e.g. text_rnn's output.
        reuse: reuse mode for the "discriminator" variable scope.
        alpha: leaky-ReLU slope.

    Returns:
        (logits, sigmoid(logits)), each of shape (batch, 1).

    Batch norm always runs in training mode and dropout is always applied.
    Variable names are auto-generated from the order of the layer calls, so the
    order must not change or existing checkpoints will not restore.
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        # 256 x 256 x 3 to 128 x 128 x 16
        # No batch normalisation on the first layer.
        layer1 = tf.layers.conv2d(inputs_img, 16, 3, strides=2, padding='same')
        layer1 = tf.maximum(alpha * layer1, layer1)  # leaky ReLU
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        # 128 x 128 x 16 to 64 x 64 x 32
        layer2 = tf.layers.conv2d(layer1, 32, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = tf.maximum(alpha * layer2, layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
        # 32 x 32 x 64
        layer3 = tf.layers.conv2d(layer2, 64, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = tf.maximum(alpha * layer3, layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
        # 16*16*128
        layer4 = tf.layers.conv2d(layer3, 128, 3, strides=2, padding='same')
        layer4 = tf.layers.batch_normalization(layer4, training=True)
        layer4 = tf.maximum(alpha * layer4, layer4)
        # 8*8*256
        layer5 = tf.layers.conv2d(layer4, 256, 3, strides=2, padding='same')
        layer5 = tf.layers.batch_normalization(layer5, training=True)
        layer5 = tf.maximum(alpha * layer5, layer5)
        # 4*4*512
        layer6 = tf.layers.conv2d(layer5, 512, 3, strides=2, padding='same')
        layer6 = tf.layers.batch_normalization(layer6, training=True)
        layer6 = tf.maximum(alpha * layer6, layer6)
        # Text condition: project to 512 (leaky ReLU), then tile it over the 4 x 4
        # grid so it can be concatenated channel-wise with the image features.
        text_emb = tf.layers.dense(condition_label, 512)
        text_emb = tf.maximum(alpha * text_emb, text_emb)
        text_emb = tf.expand_dims(text_emb, 1)
        text_emb = tf.expand_dims(text_emb, 2)
        text_emb = tf.tile(text_emb, [1,4,4,1])
        layer_concat = tf.concat([layer6, text_emb], 3)
        # 1 x 1 convolution mixes image and text features at each position.
        layer7 = tf.layers.conv2d(layer_concat, 512, 1, strides=1, padding='same')
        layer7 = tf.layers.batch_normalization(layer7, training=True)
        layer7 = tf.maximum(alpha * layer7, layer7)
        flatten = tf.reshape(layer7, (-1, 4*4*512))
        logits = tf.layers.dense(flatten, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs
def get_loss(inputs_image, wrong_image, inputs_noise, condition_label, image_depth, smooth=0.1):
    """Build the matching-aware conditional-GAN losses.

    D should score (real image, its text) as real (target 1 - smooth). It should
    score (generated image, text) and (mismatched image, text) as fake (target
    `smooth`). G is trained so D scores its images as real (target 1 - smooth).
    The input and condition shapes are printed as a debugging aid.

    Returns:
        (g_loss, d_loss) scalar tensors. d_loss is the plain sum of its three terms.
    """
    def _bce(logits, target):
        # Mean sigmoid cross-entropy against a constant target probability.
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=tf.ones_like(logits) * target))

    # Generator first, then D on real, generated and mismatched images. This
    # creation order fixes the variable names that checkpoints rely on.
    fake_images = get_generator(inputs_noise, image_depth, condition_label, is_train=True)
    real_logits, _ = get_discriminator(inputs_image, condition_label)
    fake_logits, _ = get_discriminator(fake_images, condition_label, reuse=tf.AUTO_REUSE)
    wrong_logits, _ = get_discriminator(wrong_image, condition_label, reuse=tf.AUTO_REUSE)
    print(inputs_image.get_shape(), condition_label.get_shape())

    real_target = 1 - smooth
    fake_target = smooth
    g_loss = _bce(fake_logits, real_target)
    d_loss = (_bce(real_logits, real_target)
              + _bce(fake_logits, fake_target)
              + _bce(wrong_logits, fake_target))
    return g_loss, d_loss
def get_optimizer(g_loss, d_loss, beta1=0.4, learning_rate=0.001,
                  g_extra_prefixes=(), d_extra_prefixes=()):
    """Create one Adam training op for the generator and one for the discriminator.

    Args:
        g_loss, d_loss: scalar losses, e.g. from get_loss.
        beta1, learning_rate: Adam hyper-parameters shared by both optimisers.
        g_extra_prefixes, d_extra_prefixes: extra variable-name prefixes to
            optimise together with the "generator" / "discriminator" variables.
            By default only those two scopes are trained, so the text_rnn GRU
            (created in tf.nn.dynamic_rnn's default "rnn" scope) is never
            updated. Pass e.g. g_extra_prefixes=("rnn",) to train it with G.
            This adds Adam slot variables that older checkpoints do not contain.

    Returns:
        (g_opt, d_opt). Both run after the UPDATE_OPS (batch-norm moving-average
        updates) that exist when this function is called.
    """
    train_vars = tf.trainable_variables()
    g_prefixes = ("generator",) + tuple(g_extra_prefixes)
    d_prefixes = ("discriminator",) + tuple(d_extra_prefixes)
    g_vars = [var for var in train_vars if var.name.startswith(g_prefixes)]
    d_vars = [var for var in train_vars if var.name.startswith(d_prefixes)]
    # Batch-norm moving statistics only refresh if the train ops depend on their updates.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        g_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
    return g_opt, d_opt
def plot_images(samples):
    """Show up to ten images side by side in a single row without axes.

    Each sample is reshaped to 256 x 256 x 3 before display. Extra samples
    beyond ten are ignored.
    """
    fig, axes = plt.subplots(nrows=1, ncols=10, sharex=True, sharey=True, figsize=(20, 2))
    for ax, img in zip(axes, samples):
        ax.imshow(img.reshape((256, 256, 3)))
        for axis_getter in (ax.get_xaxis, ax.get_yaxis):
            axis_getter().set_visible(False)
    fig.tight_layout(pad=0)
def show_generator_output(sess, n_images, inputs_noise, output_dim, test_text_vec):
    """Run the generator in inference mode on the given noise and text condition.

    Each call adds a new generator sub-graph to the default graph; it reuses
    the existing "generator" variables. `n_images` is not used: the number of
    images follows the rows of `inputs_noise` / `test_text_vec`.
    """
    generator_out = get_generator(inputs_noise, output_dim, test_text_vec, is_train=False)
    return sess.run(generator_out)
# ---- training set-up ----
n_samples = 10          # images drawn for each progress plot
learning_rate = 0.0002
beta1 = 0.5

losses = []             # (d_loss, g_loss) history; not appended to by the training loop
step = 0                # global count of optimisation steps

# Text encoder -> conditional GAN losses -> one Adam optimiser per network.
last = text_rnn(input_text)
g_loss, d_loss = get_loss(
    real_image_batch,
    wrong_image_batch,
    inputs_noise,
    last,
    image_depth,
    smooth=0.1,
)
g_train_opt, d_train_opt = get_optimizer(
    g_loss, d_loss, beta1=beta1, learning_rate=learning_rate)
saver = tf.train.Saver()
# ---- training (resumes from the checkpoint shipped in the input dataset) ----

# Per-epoch input pipeline, built once. The shuffled file-name lists are fed
# through placeholders, so re-shuffling does not add new graph ops every epoch.
real_names_ph = tf.placeholder(tf.string, [None], name='real_image_names')
wrong_names_ph = tf.placeholder(tf.string, [None], name='wrong_image_names')
epoch_dataset = tf.data.Dataset.from_tensor_slices((real_names_ph, wrong_names_ph))
epoch_dataset = epoch_dataset.map(_pre_func).repeat(1).batch(batch_size)
epoch_init_op = iterator.make_initializer(epoch_dataset)

# Progress-plot graph, also built once: caption vectors -> text_rnn -> generator
# in inference mode. Both reuse the trained variables via AUTO_REUSE.
n_samples = 10
sample_text_vec = text_rnn(input_text, batch_size=n_samples, reuse=tf.AUTO_REUSE)
sample_images = get_generator(inputs_noise, image_depth, sample_text_vec, is_train=False)

with tf.Session() as sess:
    model_file = tf.train.latest_checkpoint('../input/gan-text-to-image-102flowers-rieyuguanghua')
    if model_file is None:
        # Nothing to resume from: start from fresh weights instead of failing in restore().
        sess.run(tf.global_variables_initializer())
    else:
        saver.restore(sess, model_file)

    for epoch in range(791, 831):
        # Shuffle captions and images with one permutation so pairs stay aligned.
        # The "wrong" images are an independent permutation of the real ones.
        index = np.random.permutation(len(all_image_filename))
        data_text_emb = data_text_emb[index]
        all_image_filename = all_image_filename[index]
        wrong_image_filename = all_image_filename[np.random.permutation(len(all_image_filename))]
        sess.run(epoch_init_op, feed_dict={real_names_ph: all_image_filename,
                                           wrong_names_ph: wrong_image_filename})

        # Full batches only. The noise batch always has batch_size rows, so a short
        # final batch cannot be combined with it; the epoch ends after the last
        # full batch.
        i = 0
        while i + batch_size <= len(data_text_emb):
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_dim))
            text_emb_batch = data_text_emb[i: i + batch_size]
            i = i + batch_size
            try:
                sess.run([g_train_opt, d_train_opt],
                         feed_dict={input_text: text_emb_batch, inputs_noise: batch_noise})
            except tf.errors.OutOfRangeError:
                # Image pipeline exhausted; it should coincide with the text index.
                break
            step += 1
        print('epoch', epoch, 'step', step)

        if epoch % 10 == 0:
            condition_text = data_text_emb[:n_samples]
            test_noise = np.random.uniform(-1, 1, size=[n_samples, noise_dim])
            samples = sess.run(sample_images, feed_dict={input_text: condition_text,
                                                         inputs_noise: test_noise})
            plot_images(samples)

    saver.save(sess, "./model11.ckpt")
# ---- inference: generate images for a custom caption in a fresh graph ----
tf.reset_default_graph()

n_samples = 10
test_word = """
the petals on this flower are yellow with a red center,the petals on this flower are yellow with a red center
"""
# Tokenise on whitespace exactly like the training captions. Iterating the
# string directly would look up single characters instead of words.
all_vec = [word_vectors[w] for w in test_word.split() if w in word_vectors]
sentence = pad(all_vec, maxlength=maxlength)
# One copy of the caption per sample: text_rnn is asked for n_samples rows and
# the generator pairs each row with its own noise vector.
test_text_vec = np.repeat(sentence[np.newaxis], n_samples, axis=0).astype(np.float32)
test_noise = np.random.uniform(-1, 1, size=[n_samples, noise_dim])

# Build the inference-mode generator BEFORE the Saver. A Saver only covers
# variables that exist when it is constructed.
last_test = text_rnn(test_text_vec, batch_size=n_samples, reuse=tf.AUTO_REUSE)
new_image = get_generator(test_noise, image_depth, last_test, is_train=False)
saver = tf.train.Saver()

with tf.Session() as sess:
    # Prefer the newest checkpoint in the working directory (where the training
    # cell saves model11.ckpt); fall back to the one in the input dataset.
    model_file = (tf.train.latest_checkpoint('./')
                  or tf.train.latest_checkpoint('../input/gan-text-to-image-102flowers-rieyuguanghua'))
    saver.restore(sess, model_file)
    samples = sess.run(new_image)
    plot_images(samples)