变分自编码器(VAE)原理与实现(tensorflow2.x)

VAE介绍

变分自编码器(VAE)属于生成模型家族。VAE的生成器能够利用连续潜在空间的矢量产生有意义的输出。通过潜在矢量探索解码器输出的可能属性。
在GAN中,重点在于如何得出近似输入分布的模型。 VAE尝试对可解耦的连续潜在空间中的输入分布进行建模。
在VAE中,重点在于潜编码的变分推理。因此,VAE为潜在变量的学习和有效贝叶斯推理提供了合适的框架。
在结构上,VAE与自编码器相似。它也由编码器(也称为识别或推理模型)和解码器(也称为生成模型)组成。 VAE和自编码器都试图在学习潜矢量的同时重建输入数据。但是,与自编码器不同,VAE的潜在空间是连续的,并且解码器本身被用作生成模型。

VAE原理

在生成模型中,使用神经网络来逼近输入的真实分布:
x ∼ P θ ( x ) ( 1 ) x \sim P_θ(x) \qquad(1) xPθ(x)(1)
θ表示模型参数。
在机器学习中,为了执行特定的推理,希望找到 P θ ( x , z ) P_θ(x,z) Pθ(x,z),这是输入 x x x和潜变量 z z z之间的联合分布。潜变量是对可从输入中观察到的某些属性进行编码。如在名人面孔中,这些可能是面部表情,发型,头发颜色,性别等。
P θ ( x , z ) P_θ(x,z) Pθ(x,z)实际上是输入数据及其属性的分布。 P θ ( x ) P_θ(x) Pθ(x)可以从边缘分布计算:
P θ ( x ) = ∫ P θ ( x , z ) d z ( 2 ) P_θ(x)=\int P_θ(x,z)dz \qquad(2) Pθ(x)=Pθ(x,z)dz(2)
换句话说,考虑所有可能的属性,最终得到描述输入的分布。在名人面孔中,利用包含面部表情,发型,头发颜色和性别在内的特征,可以恢复描述名人面孔的分布。
问题在于该方程式没有解析形式或有效的估计量。因此,通过神经网络进行优化是不可行的。
使用贝叶斯定理,可以找到方程式(2)的替代表达式:
P θ ( x ) = ∫ P θ ( x ∣ z ) P ( z ) d z ( 3 ) P_θ(x)=\int P_θ(x|z)P(z)dz \qquad(3) Pθ(x)=Pθ(xz)P(z)dz(3)
P ( z ) P(z) P(z) z z z的先验分布。它不以任何观察为条件。如果 z z z是离散的并且 P θ ( x ∣ z ) P_θ(x|z) Pθ(xz)是高斯分布,则 P θ ( x ) P_θ(x) Pθ(x)是高斯分布的混合。如果 z z z是连续的,则高斯分布 P θ ( x ) P_θ(x) Pθ(x)无法预估。
在实践中,如果尝试在没有合适的损失函数的情况下建立近似 P θ ( x ∣ z ) P_θ(x|z) Pθ(xz)的神经网络,它将忽略 z z z并得出平凡解, P θ ( x ∣ z ) = P θ ( x ) P_θ(x|z)=P_θ(x) Pθ(xz)=Pθ(x)。因此,公式(3)不能提供 P θ ( x ) P_θ(x) Pθ(x)的良好估计。公式(2)也可以表示为:
P θ ( x ) = ∫ P θ ( z ∣ x ) P ( x ) d z ( 4 ) P_θ(x)=\int P_θ(z|x)P(x)dz \qquad(4) Pθ(x)=Pθ(zx)P(x)dz(4)
但是, P θ ( z ∣ x ) P_θ(z|x) Pθ(zx)也难以求解。 VAE的目标是找到一个可估计的分布,该分布近似估计 P θ ( z ∣ x ) P_θ(z|x) Pθ(zx),即在给定输入 x x x的情况下对潜在编码 z z z的条件分布的估计。

变分推理

为了使 P θ ( z ∣ x ) P_θ(z|x) Pθ(zx)易于处理,VAE引入了变分推断模型(编码器):
Q ϕ ( z ∣ x ) ≈ P θ ( z ∣ x ) ( 5 ) Q_\phi (z|x) \approx P_θ(z|x) \qquad(5) Qϕ(zx)Pθ(zx)(5)
Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx)可很好地估计 P θ ( z ∣ x ) P_θ(z|x) Pθ(zx)。它既可以参数化又易于处理。 可以通过深度神经网络优化参数 φ φ φ来近似 Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx)。 通常,将 Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx)选择为多元高斯分布:
Q ϕ ( z ∣ x ) = N ( z ; μ ( x ) , d i a g ( σ ( x ) 2 ) ) ( 6 ) Q_\phi (z|x)=\mathcal N(z;\mu(x),diag(\sigma(x)^2)) \qquad(6) Qϕ(zx)=N(z;μ(x),diag(σ(x)2))(6)
均值 μ ( x ) \mu(x) μ(x)和标准差 σ ( x ) \sigma (x) σ(x)均由编码器神经网络使用输入数据计算得出。对角矩阵表示 z z z中的元素间是相互独立的。

VAE核心方程

推理模型 Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx)从输入 x x x生成潜矢量 z z z Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx)类似于自编码器模型中的编码器。另一方面, P θ ( x ∣ z ) P_θ(x|z) Pθ(xz)从潜码z重建输入。 P θ ( x ∣ z ) P_θ(x|z) Pθ(xz)的作用类似于自编码器模型中的解码器。要估算 P θ ( x ) P_θ(x) Pθ(x),必须确定其与 Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx) P θ ( x ∣ z ) P_θ(x|z) Pθ(xz)的关系。
如果 Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx) P θ ( z ∣ x ) P_θ(z|x) Pθ(zx)的估计值,则Kullback-Leibler(KL)散度确定这两个条件密度之间的距离:
D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ∣ x ) ) = E z ∼ Q [ l o g Q ϕ ( z ∣ x ) − l o g P θ ( z ∣ x ) ] ( 7 ) D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(z|x)] \qquad (7) DKL(Qϕ(zx)Pθ(zx))=EzQ[logQϕ(zx)logPθ(zx)](7)
使用贝叶斯定理:
P θ ( z ∣ x ) = P θ ( x ∣ z ) P θ ( z ) P θ ( x ) ( 8 ) P_θ(z|x)=\frac{P_θ(x|z)P_θ(z)}{P_θ(x)} \qquad(8) Pθ(zx)=Pθ(x)Pθ(xz)Pθ(z)(8)
通过公式(8)改写公式(7),同时由于 l o g P θ ( x ) logP_θ(x) logPθ(x)不依赖于 z ∼ Q z\sim Q zQ
D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ∣ x ) ) = E z ∼ Q [ l o g Q ϕ ( z ∣ x ) − l o g P θ ( x ∣ z ) − l o g P θ ( z ) ] + l o g P θ ( x ) ( 9 ) D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(x|z)-logP_θ(z)] + logP_θ(x)\qquad (9) DKL(Qϕ(zx)Pθ(zx))=EzQ[logQϕ(zx)logPθ(xz)logPθ(z)]+logPθ(x)(9)
重排上式并由:
E z ∼ Q [ l o g Q ϕ ( z ∣ x ) − l o g P θ ( z ) ] = D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ) ) ( 10 ) \mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(z)] = D_{KL}(Q_\phi (z|x) \| P_θ(z)) \qquad (10) EzQ[logQϕ(zx)logPθ(z)]=DKL(Qϕ(zx)Pθ(z))(10)
得到:
l o g P θ ( x ) − D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ∣ x ) ) = E z ∼ Q [ l o g P θ ( x ∣ z ) ] − D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ) ) ( 11 ) logP_θ(x)-D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logP_θ(x|z)] - D_{KL}(Q_\phi (z|x) \| P_θ(z))\qquad (11) logPθ(x)DKL(Qϕ(zx)Pθ(zx))=EzQ[logPθ(xz)]DKL(Qϕ(zx)Pθ(z))(11)
上式是VAE的核心。左侧项 P θ ( x ) P_θ(x) Pθ(x),它最大化地减少了 Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx)与真实 P θ ( z ∣ x ) P_θ(z|x) Pθ(zx)之间距离的差距。对数不会改变最大值(或最小值)的位置。给定一个可以很好地估计 P θ ( z ∣ x ) P_θ(z|x) Pθ(zx)的推断模型, D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ∣ x ) ) D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) DKL(Qϕ(zx)Pθ(zx))约为零。
右边的第一项 P θ ( z ∣ x ) ) P_θ(z|x)) Pθ(zx))类似于解码器,该解码器从推理模型中提取样本以重建输入。
第二项是 Q ϕ ( z ∣ x ) Q_\phi (z|x) Qϕ(zx) P θ ( z ) P_θ(z) Pθ(z)间的KL距离。公式的左侧也称为变化下界(evidence lower bound, ELBO)。由于KL始终为正,因此ELBO是 l o g P θ ( x ) logP_θ(x) logPθ(x)的下限。通过优化神经网络的参数 φ φ φ θ θ θ来最大化ELBO意味着:
1. D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ∣ x ) ) → 0 D_{KL}(Q_\phi (z|x) \| P_θ(z|x))\to 0 DKL(Qϕ(zx)Pθ(zx))0或在 z z z中对属性 x x x进行编码的推理模型得到优化。
2.右侧的 l o g P θ ( x ∣ z ) logP_θ(x|z) logPθ(xz)最大化,或者从潜在矢量 z z z重构 x x x时,解码器模型得到优化。

优化方式

公式的右侧具有有关VAE损失函数的两个重要信息。解码器项 E z ∼ Q [ l o g P θ ( x ∣ z ) ] \mathbb E_{z\sim Q}[logP_θ(x|z)] EzQ[logPθ(xz)]表示生成器从推理模型的输出中获取 z z z个样本以重构输入。最大化该项意味着将重建损失 L R \mathcal L_R LR最小化。如果图像(数据)分布假定为高斯分布,则可以使用MSE。
如果每个像素(数据)都被认为是伯努利分布,那么损失函数就是一个二元交叉熵。
第二项 − D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ) ) - D_{KL}(Q_\phi (z|x) \| P_θ(z)) DKL(Qϕ(zx)Pθ(z)),由于 Q ϕ Q_\phi Qϕ是高斯分布。通常 P θ ( z ) = P ( z ) = N ( 0 , 1 ) P_θ(z)=P(z)=\mathcal N(0,1) Pθ(z)=P(z)=N(0,1),也是均值为0且标准偏差等于1.0的高斯分布。KL项可以简化为:
− D K L ( Q ϕ ( z ∣ x ) ∥ P θ ( z ) ) = 1 2 ∑ j = 0 J ( 1 + l o g ( σ j ) 2 − ( μ j ) 2 − ( σ j ) 2 ) ( 12 ) - D_{KL}(Q_\phi (z|x) \| P_θ(z))=\frac{1}{2} \sum_{j=0}^J (1+log(\sigma_j)^2-(\mu_j)^2-(\sigma_j)^2)\qquad(12) DKL(Qϕ(zx)Pθ(z))=21j=0J(1+log(σj)2(μj)2(σj)2)(12)
其中 J J J z z z的维数。和都是通过推理模型计算得到的关于 x x x的函数。要最大化 − D K L -D_{KL} DKL:则 σ j → 1 \sigma_j \to 1 σj1 μ j → 0 \mu_j \to 0 μj0 P ( z ) = N ( 0 , 1 ) P(z)=\mathcal N(0,1) P(z)=N(0,1)的选择是由于各向同性单位高斯分布的性质,可以给定适当的函数将其变形为任意分布。
根据公式(12),KL损失 L K L \mathcal L_{KL} LKL D K L D_{KL} DKL。 综上,VAE损失函数定义为:
L V A E = L R + L K L ( 13 ) \mathcal L_{VAE}=\mathcal L_R + \mathcal L_{KL}\qquad (13) LVAE=LR+LKL(13)
给定编码器和解码器模型的情况下,在构建和训练VAE之前,还有一个问题需要解决。

重参数化技巧(Reparameterization trick)

下图左侧显示了VAE网络。编码器获取输入 x x x,并估计潜矢量z的多元高斯分布的均值 μ μ μ和标准差 σ σ σ。 解码器从潜矢量 z z z采样,以将输入重构为 x x x
VAE但是反向传播梯度不会通过随机采样块。虽然可以为神经网络提供随机输入,但梯度不可能穿过随机层。
解决此问题的方法是将“采样”过程作为输入,如图右侧所示。 采样计算为:
S a m p l e = μ + ε σ ( 14 ) Sample=\mu + εσ\qquad(14) Sample=μ+εσ(14)
如果 ε ε ε σ σ σ以矢量形式表示,则 ε σ εσ εσ是逐元素乘法。 使用公式(14),令采样好像直接来自于潜空间。 这项技术被称为重参数化技巧。
之后在输入端进行采样,可以使用熟悉的优化算法(例如SGD,Adam或RMSProp)来训练VAE网络。

VAE实现

为了便于可视化潜在编码,将 z z z的维度设置为2。编码器仅是两层MLP,第二层生成均值和对数方差。对数方差的使用是为了简化KL损耗和重新参数化技巧的计算。编码器的第三个输出是使用重参数化技巧进行的 z z z采样。在采样函数中, e 0.5 l o g σ 2 = σ 2 = σ e^{0.5log\sigma^2}=\sqrt{\sigma^2}=\sigma e0.5logσ2=σ2 =σ,因为 σ > 0 σ> 0 σ>0是高斯分布的标准偏差。
解码器也是两层MLP,它对 z z z的样本进行采样以近似输入。
VAE网络只是将编码器和解码器连接在一起。损失函数是重建损失和KL损失之和。使用Adam优化器。

导入库

from tensorflow import keras
import tensorflow as tf
import numpy as np
import os
import argparse
from matplotlib import pyplot as plt

重参数技巧

#reparameterization trick
#z = z_mean + sqrt(var) * eps
def sampling(args):
    """Reparameterization trick by sampling
    Reparameterization trick by sampling fr an isotropic unit Gaussian.
    #Arguments:
        args (tensor): mean and log of variance of Q(z|x)
    #Returns:
        z (tensor): sampled latent vector
    """
    z_mean,z_log_var = args
    batch = keras.backend.shape(z_mean)[0]
    dim = keras.backend.shape(z_mean)[1]

    epsilon = keras.backend.random_normal(shape=(batch,dim))
    return z_mean + keras.backend.exp(0.5 * z_log_var) * epsilon

绘制测试图片函数

def plot_results(models,
        data,
        batch_size=128,
        model_name='vae_mnist'):
    """Plots labels and MNIST digits as function of 2 dim latent vector
    Arguments:
        models (tuple): encoder and decoder models
        data (tuple): test data and label
        batch_size (int): prediction batch size
        model_name (string): which model is using this function
    """
    encoder,decoder = models
    x_test,y_test = data
    xmin = ymin = -4
    xmax = ymax = +4
    os.makedirs(model_name,exist_ok=True)

    filename = os.path.join(model_name,'vae_mean.png')
    #display a 2D plot of the digit classes in the latent space
    z,_,_ = encoder.predict(x_test,batch_size=batch_size)
    plt.figure(figsize=(12,10))

    #axes x and y ranges
    axes = plt.gca()
    axes.set_xlim([xmin,xmax])
    axes.set_ylim([ymin,ymax])

    # subsampling to reduce density of points on the plot
    z = z[0::2]
    y_test = y_test[0::2]
    plt.scatter(z[:,0],z[:,1],marker='')
    for i,digit in enumerate(y_test):
        axes.annotate(digit,(z[i,0],z[i,1]))
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.savefig(filename)
    plt.show()

    filename = os.path.join(model_name,'digits_over_latent.png')
    #display a 30*30 2D mainfold of digits
    n = 30
    digit_size = 28
    figure = np.zeros((digit_size * n,digit_size * n))
    #linearly spaced coordinates corresponding to the 2D plot of digit classes in the latent space
    #线性间隔的坐标,对应于潜在空间中数字类的二维图
    grid_x = np.linspace(-4,4,n)
    grid_y = np.linspace(-4,4,n)[::-1]

    for i,yi in enumerate(grid_x):
        for j,xi in enumerate(grid_y):
            z_sample = np.array([[xi,yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size,digit_size)
            figure[i * digit_size:(i+1)*digit_size,j*digit_size:(j+1)*digit_size] = digit
    
    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = (n-1) * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()

加载数据与超参数

# MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

#超参数
input_shape = (original_dim,)
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50

VAE模型

#VAE model
#encoder
inputs = keras.layers.Input(shape=input_shape,name='encoder_input')
x = keras.layers.Dense(intermediate_dim,activation='relu')(inputs)
z_mean = keras.layers.Dense(latent_dim,name='z_mean')(x)
z_log_var = keras.layers.Dense(latent_dim,name='z_log_var')(x)

z = keras.layers.Lambda(sampling,output_shape=(latent_dim,),name='z')([z_mean,z_log_var])

encoder = keras.Model(inputs,[z_mean,z_log_var,z],name='encoder')
encoder.summary()
keras.utils.plot_model(encoder,to_file='vae_mlp_encoder.png',show_shapes=True)

#decoder
latent_inputs = keras.layers.Input(shape=(latent_dim,),name='z_sampling')
x = keras.layers.Dense(intermediate_dim,activation='relu')(latent_inputs)
outputs = keras.layers.Dense(original_dim,activation='sigmoid')(x)
decoder = keras.Model(latent_inputs,outputs,name='decoder')
decoder.summary()
keras.utils.plot_model(decoder,to_file='vae_mlp_decoder.png',show_shapes=True)

outputs = decoder(encoder(inputs)[2])
vae = keras.Model(inputs,outputs,name='vae_mpl')

模型训练

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load tf model trained weights"
    parser.add_argument("-w", "--weights", help=help_)
    help_ = "Use binary cross entropy instead of mse (default)"
    parser.add_argument("--bce", help=help_, action='store_true')
    args = parser.parse_args()
    models = (encoder, decoder)
    data = (x_test, y_test)
    
    #VAE loss = mse_loss or xent_loss + kl_loss
    if args.bce:
        reconstruction_loss = keras.losses.binary_crossentropy(inputs,outputs)
    else:
        reconstruction_loss = keras.losses.mse(inputs,outputs)
    
    reconstruction_loss *= original_dim
    kl_loss = 1 + z_log_var - keras.backend.square(z_mean) - keras.backend.exp(z_log_var)
    kl_loss = keras.backend.sum(kl_loss,axis=-1)
    kl_loss *= -0.5
    vae_loss = keras.backend.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
    vae.summary()
    keras.utils.plot_model(vae,to_file='vae_mlp.png',show_shapes=True)
    save_dir = 'vae_mlp_weights'
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    if args.weights:
        filepath = os.path.join(save_dir,args.weights)
        vae = vae.load_weights(filepath)
    else:
        #train
        vae.fit(x_train,
                epochs=epochs,
                batch_size=batch_size,
                validation_data=(x_test,None))
        filepath = os.path.join(save_dir,'vae_mlp.mnist.tf')
        vae.save_weights(filepath)
    plot_results(models,data,batch_size=batch_size,model_name='vae_mlp')

测试经过训练的解码器

在训练了VAE网络之后,可以丢弃推理模型。为了生成新的有意义的输出,从用于生成 ε ε ε的高斯分布中抽取样本:

解码器

效果展示

潜矢量可视化

潜矢量可视化

图片生成

图片生成

posted @ 2020-10-24 20:32  盼小辉丶  阅读(879)  评论(0编辑  收藏  举报