生成模型应用——使用变分自编码器(VAE)控制人脸属性生成人脸图片

使用VAE生成人脸图片

变分自编码器(VAE)的基础知识参考博文变分自编码器(VAE)原理与实现(tensorflow2.x)。作为VAE的应用,我们将使用VAE生成一些可控制属性的人脸图片。可用的人脸数据集包括:

  1. Celeb A: 这是在学术界很流行的数据集,因为它包含面部特征的注释,不能用于商业用途。
  2. Flickr-Faces-HQ Dataset, FFHQ: 该数据集可免费用于商业用途,并包含高分辨率图像。

网络架构

网络的输入形状为(112,112,3),可以在数据预处理时调整图片尺寸。由于数据集较复杂,可以增加滤波器数量以增加网络容量。因此,编码器中的卷积层如下:

a) Conv2D(filters = 32, kernel_size=(3,3), strides = 2)
b) Conv2D(filters = 32, kernel_size=(3,3), strides = 2)
c) Conv2D(filters = 64, kernel_size=(3,3), strides = 2)
d) Conv2D(filters = 64, kernel_size=(3,3), strides = 2)

人脸重建

让我们先看一下VAE的重构图像效果:
重建图片
尽管重建的图片并不完美,但它们至少看起来不错。 VAE设法从输入图像中学习了一些特征,并使用它们来绘制新的面孔。可以看出,VAE可以更好地重建女性面孔。这是由于Celeb_A数据集中女性的比例较高。这也就是为什么男性的肤色更趋向年轻、女性化。
观察图像背景,由于图像背景的多样性,因此编码器无法将每个细节编码至低维度,因此我们可以看到VAE对背景颜色进行编码,而解码器则基于这些颜色创建模糊的背景。

生成新面孔

为了生成新图像,我们从标准的高斯分布中采样随机数,并将传递给解码器:

z_samples = np.random.normal(loc=0., scale=1, size=(image_num, z_dim))
images = vae.decoder(z_samples.astype(np.float32))

但,某些生成的面孔看起来太恐怖了!
生成新面孔我们可以使用采样技巧来提高图像保真度。

采样技巧

可以看到,训练后的VAE可以很好地重建人脸。但,随机抽样潜变量生成的图像中存在问题。为了调试该问题,将数据集中图像输入到VAE解码器中,以获取潜在空间的均值和方差。然后,绘制了每个潜在空间变量的均值:
分布情况
从理论上讲,它们应该以0为均值且方差为1,但随机采样的样本并不总是与解码器期望的分布匹配。这是采样技巧技巧的地方,收集潜在变量的平均标准差(一个标量值),该标准差用于生成正态分布的样本(200维)。然后,在其中添加了平均均值(200个维度)。
yep,现在生成的图片看起来好多了!
生成图片

接下来,将介绍如何进行面部属性编辑,而不是生成随机的面孔。

控制人脸属性

潜在空间

本质上,潜在空间意味着潜在变量的每个可能值。在我们的VAE中,它是200个维度的向量(或者称200个变量)。我们希望每个变量都包含独特的语义,例如z[0]代表眼睛,z[1]代表眼睛的颜色,依此类推,事情从来没有那么简单。假设信息是在所有潜在向量中编码的,就可以使用向量算术探索潜在空间。

属性控制

使用一个二维示例解释属性控制的原理。假设现在在地图上的(0,0)点,而目的地位于(x, y)。因此,朝目的地的方向是(x-0, y-0)除以(x, y)的L2范数,可以将方向表示为(x_dot, y_dot)。因此,每次移动(x_dot, y_dot)时,都在朝着目的地移动。每次移动(-2 * x_dot, -2 * y_dot)时,将以两倍的步幅远离目的地。
类似的,如果我们知道了微笑属性的方向向量,则可以将其添加到潜在变量中以使人脸附加微笑属性:

new_z_samples = z_samples +  smiling_magnitude*smiling_vector

smile_magnitude是我们设置的标量值,因此下一步是找出获取smile_vector的方法。

查找属性向量

Celeb A数据集附带每个图像的面部属性注释。标签是二进制的,指示图像中是否存在特定属性。我们将使用标签和编码的潜在变量来找到我们的方向向量:

  1. 使用测试数据集或训练数据集中的样本,并使用VAE解码器生成潜矢量。
  2. 将潜在向量分为两组:具有(正向量)或不具有(负向量)的我们感兴趣的一个属性。
  3. 分别计算正向量和负向量的平均值。
  4. 通过从平均正向量中减去平均负向量来获取属性方向向量。
    在预处理函数中,返回我们感兴趣的属性的标签。然后,使用lambda函数映射到数据管道:
def preprocess_attrib(sample, attribute):
    image = sample['image']
    image = tf.image.resize(image, [112,112])
    image = tf.cast(image, tf.float32)/255.
    return image, sample['attributes'][attribute]
ds = ds.map(lambda x: preprocess_attrib(x, attribute))

人脸属性编辑

提取属性向量后,进行以下操作:

  1. 首先,我们从数据集中获取图像,将其放在首位,作为对比。
  2. 将人脸编码为潜在变量,然后对其进行解码以生成新人脸,并将其放置在中间。
  3. 然后,我们向右逐渐增加属性向量。
  4. 同样,我们向左逐渐减少属性向量。

下图显示了通过内插潜在向量生成的图像:
生成图片
生成图片
生成图片
接下来,我们可以尝试一起更改多个面部属性。在下图中,左侧的图像是随机生成的,并用作基准。右侧是经过一些潜在空间运算后的新图像:
小部件修改多个面部属性
该小部件可在Jupyter notebook中使用。

完整代码

# vae_faces.ipynb
import tensorflow as tf
from tensorflow_probability import distributions as tfd
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Layer, Input, Conv2D, Dense, Flatten, Reshape, Lambda, Dropout
from tensorflow.keras.layers import Conv2DTranspose, MaxPooling2D, UpSampling2D, LeakyReLU, BatchNormalization
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import tensorflow_datasets as tfds

import cv2
import numpy as np
import matplotlib.pyplot as plt
import datetime, os
import warnings
warnings.filterwarnings('ignore')
print("Tensorflow", tf.__version__)

strategy = tf.distribute.MirroredStrategy()
num_devices = strategy.num_replicas_in_sync
print('Number of devidex: {}'.format(num_devices))

(ds_train, ds_test_), ds_info = tfds.load(
    'celeb_a',
    split=['train', 'test'],
    shuffle_files=True,
    with_info=True)
fig = tfds.show_examples(ds_train, ds_info)

batch_size = 32 * num_devices

def preprocess(sample):
    image = sample['image']
    image = tf.image.resize(image, [112,112])
    image = tf.cast(image, tf.float32) / 255.
    return image, image

ds_train = ds_train.map(preprocess)
ds_train = ds_train.shuffle(128)
ds_train = ds_train.batch(batch_size, drop_remainder=True).prefetch(batch_size)

ds_test = ds_test_.map(preprocess).batch(batch_size, drop_remainder=True).prefetch(batch_size)

train_num = ds_info.splits['train'].num_examples
test_num = ds_info.splits['test'].num_examples

class GaussianSampling(Layer):
    def call(self, inputs):
        means, logvar = inputs
        epsilon = tf.random.normal(shape=tf.shape(means), mean=0., stddev=1.)
        samples = means + tf.exp(0.5 * logvar) * epsilon

        return samples

class DownConvBlock(Layer):
    count = 0
    def __init__(self, filters, kernel_size=(3,3), strides=1, padding='same'):
        super(DownConvBlock, self).__init__(name=f"DownConvBlock_{DownConvBlock.count}")
        DownConvBlock.count += 1
        self.forward = Sequential([
            Conv2D(filters, kernel_size, strides, padding),
            BatchNormalization(),
            LeakyReLU(0.2)
        ])
    
    def call(self, inputs):
        return self.forward(inputs)

class UpConvBlock(Layer):
    count = 0
    def __init__(self, filters, kernel_size=(3,3), padding='same'):
        super(UpConvBlock, self).__init__(name=f"UpConvBlock_{UpConvBlock.count}")
        UpConvBlock.count += 1
        self.forward = Sequential([
            Conv2D(filters, kernel_size, 1, padding),
            LeakyReLU(0.2),
            UpSampling2D((2,2))
        ])
    
    def call(self, inputs):
        return self.forward(inputs)

class Encoder(Layer):
    def __init__(self, z_dim, name='encoder'):
        super(Encoder, self).__init__(name=name)

        self.features_extract = Sequential([
            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),
            Flatten()
        ])

        self.dense_mean = Dense(z_dim, name='mean')
        self.dense_logvar = Dense(z_dim, name='logvar')
        self.sampler = GaussianSampling()
    
    def call(self, inputs):
        x = self.features_extract(inputs)
        mean = self.dense_mean(x)
        logvar = self.dense_logvar(x)
        z = self.sampler([mean, logvar])
        return z, mean, logvar

class Decoder(Layer):
    def __init__(self, z_dim, name='decoder'):
        super(Decoder, self).__init__(name=name)

        self.forward = Sequential([
            Dense(7*7*64, activation='relu'),
            Reshape((7,7,64)),
            UpConvBlock(filters=64, kernel_size=(3,3)),
            UpConvBlock(filters=64, kernel_size=(3,3)),
            UpConvBlock(filters=32, kernel_size=(3,3)),
            UpConvBlock(filters=32, kernel_size=(3,3)),
            Conv2D(filters=3, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')
        ])
    
    def call(self, inputs):
        return self.forward(inputs)

class VAE(Model):
    def __init__(self, z_dim, name='VAE'):
        super(VAE, self).__init__(name=name)
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)
        self.mean = None
        self.logvar = None
    
    def call(self, inputs):
        z, self.mean, self.logvar = self.encoder(inputs)
        out = self.decoder(z)
        return out

if num_devices > 1:
    with strategy.scope():
        vae = VAE(z_dim=200)
else:
    vae = VAE(z_dim=200)

def vae_kl_loss(y_true, y_pred):
    kl_loss = -0.5 * tf.reduce_mean(1 + vae.logvar - tf.square(vae.mean) - tf.exp(vae.logvar))
    return kl_loss

def vae_rc_loss(y_true, y_pred):
    rc_loss = tf.keras.losses.MSE(y_true, y_pred)
    return rc_loss

def vae_loss(y_true, y_pred):
    kl_loss = vae_kl_loss(y_true, y_pred)
    rc_loss = vae_rc_loss(y_true, y_pred)
    kl_weight_const = 0.01
    return kl_weight_const * kl_loss + rc_loss

model_path = "vae_faces_cele_a.h5"

checkpoint = ModelCheckpoint(
    model_path,
    monitor='vae_rc_loss',
    verbose=1,
    save_best_only=True,
    mode='auto',
    save_weights_only=True
)
early = EarlyStopping(
    monitor='vae_rc_loss',
    mode='auto',
    patience=3
)

callbacks_list = [checkpoint, early]

initial_learning_rate = 1e-3
steps_per_epoch = int(np.round(train_num/batch_size))
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=steps_per_epoch,
    decay_rate=0.96,
    staircase=True
)

vae.compile(
    loss=[vae_loss],
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=3e-3),
    metrics=[vae_kl_loss, vae_rc_loss]
)
history = vae.fit(ds_train, validation_data=ds_test,epochs=50,callbacks=callbacks_list)

images, labels = next(iter(ds_train))
vae.load_weights(model_path)
outputs = vae.predict(images)

# Display
grid_col = 8
grid_row = 2

f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2, grid_row*2))

i = 0
for row in range(0, grid_row, 2):
    for col in range(grid_col):
        axarr[row,col].imshow(images[i])
        axarr[row,col].axis('off')
        axarr[row+1,col].imshow(outputs[i])
        axarr[row+1,col].axis('off')        
        i += 1
f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)        
plt.show()

avg_z_mean = []
avg_z_std = []
for i in range(steps_per_epoch):
    images, labels = next(iter(ds_train))
    z, z_mean, z_logvar = vae.encoder(images)
    avg_z_mean.append(np.mean(z_mean, axis=0))
    avg_z_std.append(np.mean(np.exp(0.5*z_logvar),axis=0))
avg_z_mean = np.mean(avg_z_mean, axis=0)
avg_z_std = np.mean(avg_z_std, axis=0)

plt.plot(avg_z_mean)
plt.ylabel("Average z mean")
plt.xlabel("z dimension")

grid_col = 10
grid_row = 10

f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col, 1.5*grid_row))

i = 0
for row in range(grid_row):
    for col in range(grid_col):
        axarr[row, col].hist(z[:,i], bins=20)
        # axarr[row, col].axis('off')
        i += 1

#f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)
plt.show()

z_dim = 200
z_samples = np.random.normal(loc=0, scale=1, size=(25, z_dim))
images = vae.decoder(z_samples.astype(np.float32))
grid_col = 7
grid_row = 2

f, axarr = plt.subplots(grid_row, grid_col, figsize=(2*grid_col,2*grid_row))

i = 0
for row in range(grid_row):
    for col in range(grid_col):
        axarr[row, col].imshow(images[i])
        axarr[row, col].axis('off')
        i += 1
f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)
plt.show()

# 采样技巧
z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(25, z_dim))
z_samples += avg_z_mean

images = vae.decoder(z_samples.astype(np.float32))
grid_col = 7
grid_row = 2

f, axarr = plt.subplots(grid_row, grid_col, figsize=(2*grid_col, 2*grid_row))

i = 0
for row in range(grid_row):
    for col in range(grid_col):
        axarr[row,col].imshow(images[i])
        axarr[row,col].axis('off')
        i += 1
f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)
plt.show()

(ds_train, ds_test), ds_info = tfds.load(
    'celeb_a',
    split=['train', 'test'],
    shuffle_files=True,
    with_info=True)
test_num = ds_info.splits['test'].num_examples

def preprocess_attrib(sample, attribute):
    image = sample['image']
    image = tf.image.resize(image, [112, 112])
    image = tf.cast(image, tf.float32) / 255.
    return image, sample['attributes'][attribute]

def extract_attrib_vector(attribute, ds):
    batch_size = 32 * num_devices
    ds = ds.map(lambda x: preprocess_attrib(x, attribute))
    ds = ds.batch(batch_size)

    steps_per_epoch = int(np.round(test_num / batch_size))

    pos_z = []
    pos_z_num = []
    neg_z = []
    neg_z_num = []

    for i in range(steps_per_epoch):
        images, labels = next(iter(ds))
        z, z_mean, z_logvar = vae.encoder(images)
        z = z.numpy()
        step_pos_z = z[labels==True]
        pos_z.append(np.mean(step_pos_z, axis=0))
        pos_z_num.append(step_pos_z.shape[0])

        step_neg_z = z[labels==False]
        neg_z.append(np.mean(step_neg_z, axis=0))
        neg_z_num.append(step_neg_z.shape[0])
    
    avg_pos_z = np.average(pos_z, axis=(0), weights=pos_z_num)
    avg_neg_z = np.average(neg_z, axis=(0), weights=neg_z_num)
    attrib_vector = avg_pos_z - avg_neg_z
    return attrib_vector

attributes = list(ds_info.features['attributes'].keys())
attribs_vectors = {}
for attrib in attributes:
    print(attrib)
    attribs_vectors[attrib] = extract_attrib_vector(attrib, ds_test)

def explore_latent_variable(image, attrib):
    grid_col = 8
    grid_row = 1

    z_samples,_,_ = vae.encoder(tf.expand_dims(image,0))
    f, axarr = plt.subplots(grid_row, grid_col, figsize=(2*grid_col, 2*grid_row))

    i = 0
    row = 0
    step = -3

    axarr[0].imshow(image)
    axarr[0].axis('off')
    for col in range(1, grid_col):
        new_z_samples = z_samples + step * attribs_vectors[attrib]
        reconstructed_image = vae.decoder(new_z_samples)

        step += 1
        axarr[col].imshow(reconstructed_image[0])
        axarr[col].axis('off')
        i += 1
    
    f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)
    plt.show()

ds_test1 = ds_test.map(preprocess).batch(100)
images, labels = next(iter(ds_test1))

# 控制属性向量生成人脸图片
explore_latent_variable(images[34], 'Male')
explore_latent_variable(images[20], 'Eyeglasses')
explore_latent_variable(images[38], "Chubby")

fname = ""
if fname:
    # using existing image from file
    image = cv2.imread(fname)
    image = image[:,:,::-1]

    # crop
    min_dim = min(h, w)
    h_gap = (h-min_dim) // 2
    w_gap = (w-min_dim) // 2
    image = image[h_gap:h-h_gap, w_gap,w-w_gap, :]

    image = cv2.resize(image, (112,112))
    plt.imshow(image)

    # encode
    input_tensor = np.expand_dims(image, 0)
    input_tensor = input_tensor.astype(np.float32) / 255.
    z_samples = vae.encoder(input_tensor)
else:
    # start with random image
    z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(1, 200))

import ipywidgets as widgets
from ipywidgets import interact, interact_manual

@interact
def explore_latent_variable(Male = (-5,5,0.1),
                            Eyeglasses = (-5,5,0.1),
                            Young = (-5,5,0.1),
                            Smiling = (-5,5,0.1),
                            Blond_Hair = (-5,5,0.1),
                            Pale_Skin = (-5,5,0.1),
                            Mustache = (-5,5,0.1)):
    new_z_samples = z_samples + \
                    Male*attribs_vectors['Male'] + \
                    Eyeglasses*attribs_vectors['Eyeglasses'] +\
                    Young*attribs_vectors['Young'] +\
                    Smiling*attribs_vectors['Smiling']+\
                    Blond_Hair*attribs_vectors['Blond_Hair'] +\
                    Pale_Skin*attribs_vectors['Pale_Skin'] +\
                    Mustache*attribs_vectors['Mustache']
    images = vae.decoder(new_z_samples)
    plt.figure(figsize=(4,4))
    plt.axis('off')
    plt.imshow(images[0])
posted @ 2021-05-23 22:10  盼小辉丶  阅读(977)  评论(0编辑  收藏  举报