变分自编码器(VAE)原理与实现(tensorflow2.x)
变分自编码器(VAE)原理与实现(tensorflow2.x)
VAE介绍
变分自编码器(VAE)属于生成模型家族。VAE的生成器能够利用连续潜在空间的矢量产生有意义的输出。通过潜在矢量探索解码器输出的可能属性。
在GAN中,重点在于如何得出近似输入分布的模型。 VAE尝试对可解耦的连续潜在空间中的输入分布进行建模。
在VAE中,重点在于潜编码的变分推理。因此,VAE为潜在变量的学习和有效贝叶斯推理提供了合适的框架。
在结构上,VAE与自编码器相似。它也由编码器(也称为识别或推理模型)和解码器(也称为生成模型)组成。 VAE和自编码器都试图在学习潜矢量的同时重建输入数据。但是,与自编码器不同,VAE的潜在空间是连续的,并且解码器本身被用作生成模型。
VAE原理
在生成模型中,使用神经网络来逼近输入的真实分布:
x
∼
P
θ
(
x
)
(
1
)
x \sim P_θ(x) \qquad(1)
x∼Pθ(x)(1)
θ表示模型参数。
在机器学习中,为了执行特定的推理,希望找到
P
θ
(
x
,
z
)
P_θ(x,z)
Pθ(x,z),这是输入
x
x
x和潜变量
z
z
z之间的联合分布。潜变量是对可从输入中观察到的某些属性进行编码。如在名人面孔中,这些可能是面部表情,发型,头发颜色,性别等。
P
θ
(
x
,
z
)
P_θ(x,z)
Pθ(x,z)实际上是输入数据及其属性的分布。
P
θ
(
x
)
P_θ(x)
Pθ(x)可以从边缘分布计算:
P
θ
(
x
)
=
∫
P
θ
(
x
,
z
)
d
z
(
2
)
P_θ(x)=\int P_θ(x,z)dz \qquad(2)
Pθ(x)=∫Pθ(x,z)dz(2)
换句话说,考虑所有可能的属性,最终得到描述输入的分布。在名人面孔中,利用包含面部表情,发型,头发颜色和性别在内的特征,可以恢复描述名人面孔的分布。
问题在于该方程式没有解析形式或有效的估计量。因此,通过神经网络进行优化是不可行的。
使用贝叶斯定理,可以找到方程式(2)的替代表达式:
P
θ
(
x
)
=
∫
P
θ
(
x
∣
z
)
P
(
z
)
d
z
(
3
)
P_θ(x)=\int P_θ(x|z)P(z)dz \qquad(3)
Pθ(x)=∫Pθ(x∣z)P(z)dz(3)
P
(
z
)
P(z)
P(z)是
z
z
z的先验分布。它不以任何观察为条件。如果
z
z
z是离散的并且
P
θ
(
x
∣
z
)
P_θ(x|z)
Pθ(x∣z)是高斯分布,则
P
θ
(
x
)
P_θ(x)
Pθ(x)是高斯分布的混合。如果
z
z
z是连续的,则高斯分布
P
θ
(
x
)
P_θ(x)
Pθ(x)无法预估。
在实践中,如果尝试在没有合适的损失函数的情况下建立近似
P
θ
(
x
∣
z
)
P_θ(x|z)
Pθ(x∣z)的神经网络,它将忽略
z
z
z并得出平凡解,
P
θ
(
x
∣
z
)
=
P
θ
(
x
)
P_θ(x|z)=P_θ(x)
Pθ(x∣z)=Pθ(x)。因此,公式(3)不能提供
P
θ
(
x
)
P_θ(x)
Pθ(x)的良好估计。公式(2)也可以表示为:
P
θ
(
x
)
=
∫
P
θ
(
z
∣
x
)
P
(
x
)
d
z
(
4
)
P_θ(x)=\int P_θ(z|x)P(x)dz \qquad(4)
Pθ(x)=∫Pθ(z∣x)P(x)dz(4)
但是,
P
θ
(
z
∣
x
)
P_θ(z|x)
Pθ(z∣x)也难以求解。 VAE的目标是找到一个可估计的分布,该分布近似估计
P
θ
(
z
∣
x
)
P_θ(z|x)
Pθ(z∣x),即在给定输入
x
x
x的情况下对潜在编码
z
z
z的条件分布的估计。
变分推理
为了使
P
θ
(
z
∣
x
)
P_θ(z|x)
Pθ(z∣x)易于处理,VAE引入了变分推断模型(编码器):
Q
ϕ
(
z
∣
x
)
≈
P
θ
(
z
∣
x
)
(
5
)
Q_\phi (z|x) \approx P_θ(z|x) \qquad(5)
Qϕ(z∣x)≈Pθ(z∣x)(5)
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)可很好地估计
P
θ
(
z
∣
x
)
P_θ(z|x)
Pθ(z∣x)。它既可以参数化又易于处理。 可以通过深度神经网络优化参数
φ
φ
φ来近似
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)。 通常,将
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)选择为多元高斯分布:
Q
ϕ
(
z
∣
x
)
=
N
(
z
;
μ
(
x
)
,
d
i
a
g
(
σ
(
x
)
2
)
)
(
6
)
Q_\phi (z|x)=\mathcal N(z;\mu(x),diag(\sigma(x)^2)) \qquad(6)
Qϕ(z∣x)=N(z;μ(x),diag(σ(x)2))(6)
均值
μ
(
x
)
\mu(x)
μ(x)和标准差
σ
(
x
)
\sigma (x)
σ(x)均由编码器神经网络使用输入数据计算得出。对角矩阵表示
z
z
z中的元素间是相互独立的。
VAE核心方程
推理模型
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)从输入
x
x
x生成潜矢量
z
z
z。
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)类似于自编码器模型中的编码器。另一方面,
P
θ
(
x
∣
z
)
P_θ(x|z)
Pθ(x∣z)从潜码z重建输入。
P
θ
(
x
∣
z
)
P_θ(x|z)
Pθ(x∣z)的作用类似于自编码器模型中的解码器。要估算
P
θ
(
x
)
P_θ(x)
Pθ(x),必须确定其与
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)和
P
θ
(
x
∣
z
)
P_θ(x|z)
Pθ(x∣z)的关系。
如果
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)是
P
θ
(
z
∣
x
)
P_θ(z|x)
Pθ(z∣x)的估计值,则Kullback-Leibler(KL)散度确定这两个条件密度之间的距离:
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
∣
x
)
)
=
E
z
∼
Q
[
l
o
g
Q
ϕ
(
z
∣
x
)
−
l
o
g
P
θ
(
z
∣
x
)
]
(
7
)
D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(z|x)] \qquad (7)
DKL(Qϕ(z∣x)∥Pθ(z∣x))=Ez∼Q[logQϕ(z∣x)−logPθ(z∣x)](7)
使用贝叶斯定理:
P
θ
(
z
∣
x
)
=
P
θ
(
x
∣
z
)
P
θ
(
z
)
P
θ
(
x
)
(
8
)
P_θ(z|x)=\frac{P_θ(x|z)P_θ(z)}{P_θ(x)} \qquad(8)
Pθ(z∣x)=Pθ(x)Pθ(x∣z)Pθ(z)(8)
通过公式(8)改写公式(7),同时由于
l
o
g
P
θ
(
x
)
logP_θ(x)
logPθ(x)不依赖于
z
∼
Q
z\sim Q
z∼Q:
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
∣
x
)
)
=
E
z
∼
Q
[
l
o
g
Q
ϕ
(
z
∣
x
)
−
l
o
g
P
θ
(
x
∣
z
)
−
l
o
g
P
θ
(
z
)
]
+
l
o
g
P
θ
(
x
)
(
9
)
D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(x|z)-logP_θ(z)] + logP_θ(x)\qquad (9)
DKL(Qϕ(z∣x)∥Pθ(z∣x))=Ez∼Q[logQϕ(z∣x)−logPθ(x∣z)−logPθ(z)]+logPθ(x)(9)
重排上式并由:
E
z
∼
Q
[
l
o
g
Q
ϕ
(
z
∣
x
)
−
l
o
g
P
θ
(
z
)
]
=
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
)
)
(
10
)
\mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(z)] = D_{KL}(Q_\phi (z|x) \| P_θ(z)) \qquad (10)
Ez∼Q[logQϕ(z∣x)−logPθ(z)]=DKL(Qϕ(z∣x)∥Pθ(z))(10)
得到:
l
o
g
P
θ
(
x
)
−
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
∣
x
)
)
=
E
z
∼
Q
[
l
o
g
P
θ
(
x
∣
z
)
]
−
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
)
)
(
11
)
logP_θ(x)-D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logP_θ(x|z)] - D_{KL}(Q_\phi (z|x) \| P_θ(z))\qquad (11)
logPθ(x)−DKL(Qϕ(z∣x)∥Pθ(z∣x))=Ez∼Q[logPθ(x∣z)]−DKL(Qϕ(z∣x)∥Pθ(z))(11)
上式是VAE的核心。左侧项
P
θ
(
x
)
P_θ(x)
Pθ(x),它最大化地减少了
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)与真实
P
θ
(
z
∣
x
)
P_θ(z|x)
Pθ(z∣x)之间距离的差距。对数不会改变最大值(或最小值)的位置。给定一个可以很好地估计
P
θ
(
z
∣
x
)
P_θ(z|x)
Pθ(z∣x)的推断模型,
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
∣
x
)
)
D_{KL}(Q_\phi (z|x) \| P_θ(z|x))
DKL(Qϕ(z∣x)∥Pθ(z∣x))约为零。
右边的第一项
P
θ
(
z
∣
x
)
)
P_θ(z|x))
Pθ(z∣x))类似于解码器,该解码器从推理模型中提取样本以重建输入。
第二项是
Q
ϕ
(
z
∣
x
)
Q_\phi (z|x)
Qϕ(z∣x)与
P
θ
(
z
)
P_θ(z)
Pθ(z)间的KL距离。公式的左侧也称为变化下界(evidence lower bound, ELBO)。由于KL始终为正,因此ELBO是
l
o
g
P
θ
(
x
)
logP_θ(x)
logPθ(x)的下限。通过优化神经网络的参数
φ
φ
φ和
θ
θ
θ来最大化ELBO意味着:
1.
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
∣
x
)
)
→
0
D_{KL}(Q_\phi (z|x) \| P_θ(z|x))\to 0
DKL(Qϕ(z∣x)∥Pθ(z∣x))→0或在
z
z
z中对属性
x
x
x进行编码的推理模型得到优化。
2.右侧的
l
o
g
P
θ
(
x
∣
z
)
logP_θ(x|z)
logPθ(x∣z)最大化,或者从潜在矢量
z
z
z重构
x
x
x时,解码器模型得到优化。
优化方式
公式的右侧具有有关VAE损失函数的两个重要信息。解码器项
E
z
∼
Q
[
l
o
g
P
θ
(
x
∣
z
)
]
\mathbb E_{z\sim Q}[logP_θ(x|z)]
Ez∼Q[logPθ(x∣z)]表示生成器从推理模型的输出中获取
z
z
z个样本以重构输入。最大化该项意味着将重建损失
L
R
\mathcal L_R
LR最小化。如果图像(数据)分布假定为高斯分布,则可以使用MSE。
如果每个像素(数据)都被认为是伯努利分布,那么损失函数就是一个二元交叉熵。
第二项
−
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
)
)
- D_{KL}(Q_\phi (z|x) \| P_θ(z))
−DKL(Qϕ(z∣x)∥Pθ(z)),由于
Q
ϕ
Q_\phi
Qϕ是高斯分布。通常
P
θ
(
z
)
=
P
(
z
)
=
N
(
0
,
1
)
P_θ(z)=P(z)=\mathcal N(0,1)
Pθ(z)=P(z)=N(0,1),也是均值为0且标准偏差等于1.0的高斯分布。KL项可以简化为:
−
D
K
L
(
Q
ϕ
(
z
∣
x
)
∥
P
θ
(
z
)
)
=
1
2
∑
j
=
0
J
(
1
+
l
o
g
(
σ
j
)
2
−
(
μ
j
)
2
−
(
σ
j
)
2
)
(
12
)
- D_{KL}(Q_\phi (z|x) \| P_θ(z))=\frac{1}{2} \sum_{j=0}^J (1+log(\sigma_j)^2-(\mu_j)^2-(\sigma_j)^2)\qquad(12)
−DKL(Qϕ(z∣x)∥Pθ(z))=21j=0∑J(1+log(σj)2−(μj)2−(σj)2)(12)
其中
J
J
J是
z
z
z的维数。和都是通过推理模型计算得到的关于
x
x
x的函数。要最大化
−
D
K
L
-D_{KL}
−DKL:则
σ
j
→
1
\sigma_j \to 1
σj→1,
μ
j
→
0
\mu_j \to 0
μj→0。
P
(
z
)
=
N
(
0
,
1
)
P(z)=\mathcal N(0,1)
P(z)=N(0,1)的选择是由于各向同性单位高斯分布的性质,可以给定适当的函数将其变形为任意分布。
根据公式(12),KL损失
L
K
L
\mathcal L_{KL}
LKL为
D
K
L
D_{KL}
DKL。 综上,VAE损失函数定义为:
L
V
A
E
=
L
R
+
L
K
L
(
13
)
\mathcal L_{VAE}=\mathcal L_R + \mathcal L_{KL}\qquad (13)
LVAE=LR+LKL(13)
给定编码器和解码器模型的情况下,在构建和训练VAE之前,还有一个问题需要解决。
重参数化技巧(Reparameterization trick)
下图左侧显示了VAE网络。编码器获取输入
x
x
x,并估计潜矢量z的多元高斯分布的均值
μ
μ
μ和标准差
σ
σ
σ。 解码器从潜矢量
z
z
z采样,以将输入重构为
x
x
x。
但是反向传播梯度不会通过随机采样块。虽然可以为神经网络提供随机输入,但梯度不可能穿过随机层。
解决此问题的方法是将“采样”过程作为输入,如图右侧所示。 采样计算为:
S
a
m
p
l
e
=
μ
+
ε
σ
(
14
)
Sample=\mu + εσ\qquad(14)
Sample=μ+εσ(14)
如果
ε
ε
ε和
σ
σ
σ以矢量形式表示,则
ε
σ
εσ
εσ是逐元素乘法。 使用公式(14),令采样好像直接来自于潜空间。 这项技术被称为重参数化技巧。
之后在输入端进行采样,可以使用熟悉的优化算法(例如SGD,Adam或RMSProp)来训练VAE网络。
VAE实现
为了便于可视化潜在编码,将
z
z
z的维度设置为2。编码器仅是两层MLP,第二层生成均值和对数方差。对数方差的使用是为了简化KL损耗和重新参数化技巧的计算。编码器的第三个输出是使用重参数化技巧进行的
z
z
z采样。在采样函数中,
e
0.5
l
o
g
σ
2
=
σ
2
=
σ
e^{0.5log\sigma^2}=\sqrt{\sigma^2}=\sigma
e0.5logσ2=σ2=σ,因为
σ
>
0
σ> 0
σ>0是高斯分布的标准偏差。
解码器也是两层MLP,它对
z
z
z的样本进行采样以近似输入。
VAE网络只是将编码器和解码器连接在一起。损失函数是重建损失和KL损失之和。使用Adam优化器。
导入库
from tensorflow import keras
import tensorflow as tf
import numpy as np
import os
import argparse
from matplotlib import pyplot as plt
重参数技巧
#reparameterization trick
#z = z_mean + sqrt(var) * eps
def sampling(args):
"""Reparameterization trick by sampling
Reparameterization trick by sampling fr an isotropic unit Gaussian.
#Arguments:
args (tensor): mean and log of variance of Q(z|x)
#Returns:
z (tensor): sampled latent vector
"""
z_mean,z_log_var = args
batch = keras.backend.shape(z_mean)[0]
dim = keras.backend.shape(z_mean)[1]
epsilon = keras.backend.random_normal(shape=(batch,dim))
return z_mean + keras.backend.exp(0.5 * z_log_var) * epsilon
绘制测试图片函数
def plot_results(models,
data,
batch_size=128,
model_name='vae_mnist'):
"""Plots labels and MNIST digits as function of 2 dim latent vector
Arguments:
models (tuple): encoder and decoder models
data (tuple): test data and label
batch_size (int): prediction batch size
model_name (string): which model is using this function
"""
encoder,decoder = models
x_test,y_test = data
xmin = ymin = -4
xmax = ymax = +4
os.makedirs(model_name,exist_ok=True)
filename = os.path.join(model_name,'vae_mean.png')
#display a 2D plot of the digit classes in the latent space
z,_,_ = encoder.predict(x_test,batch_size=batch_size)
plt.figure(figsize=(12,10))
#axes x and y ranges
axes = plt.gca()
axes.set_xlim([xmin,xmax])
axes.set_ylim([ymin,ymax])
# subsampling to reduce density of points on the plot
z = z[0::2]
y_test = y_test[0::2]
plt.scatter(z[:,0],z[:,1],marker='')
for i,digit in enumerate(y_test):
axes.annotate(digit,(z[i,0],z[i,1]))
plt.xlabel('z[0]')
plt.ylabel('z[1]')
plt.savefig(filename)
plt.show()
filename = os.path.join(model_name,'digits_over_latent.png')
#display a 30*30 2D mainfold of digits
n = 30
digit_size = 28
figure = np.zeros((digit_size * n,digit_size * n))
#linearly spaced coordinates corresponding to the 2D plot of digit classes in the latent space
#线性间隔的坐标,对应于潜在空间中数字类的二维图
grid_x = np.linspace(-4,4,n)
grid_y = np.linspace(-4,4,n)[::-1]
for i,yi in enumerate(grid_x):
for j,xi in enumerate(grid_y):
z_sample = np.array([[xi,yi]])
x_decoded = decoder.predict(z_sample)
digit = x_decoded[0].reshape(digit_size,digit_size)
figure[i * digit_size:(i+1)*digit_size,j*digit_size:(j+1)*digit_size] = digit
plt.figure(figsize=(10, 10))
start_range = digit_size // 2
end_range = (n-1) * digit_size + start_range + 1
pixel_range = np.arange(start_range, end_range, digit_size)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap='Greys_r')
plt.savefig(filename)
plt.show()
加载数据与超参数
# MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
#超参数
input_shape = (original_dim,)
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50
VAE模型
#VAE model
#encoder
inputs = keras.layers.Input(shape=input_shape,name='encoder_input')
x = keras.layers.Dense(intermediate_dim,activation='relu')(inputs)
z_mean = keras.layers.Dense(latent_dim,name='z_mean')(x)
z_log_var = keras.layers.Dense(latent_dim,name='z_log_var')(x)
z = keras.layers.Lambda(sampling,output_shape=(latent_dim,),name='z')([z_mean,z_log_var])
encoder = keras.Model(inputs,[z_mean,z_log_var,z],name='encoder')
encoder.summary()
keras.utils.plot_model(encoder,to_file='vae_mlp_encoder.png',show_shapes=True)
#decoder
latent_inputs = keras.layers.Input(shape=(latent_dim,),name='z_sampling')
x = keras.layers.Dense(intermediate_dim,activation='relu')(latent_inputs)
outputs = keras.layers.Dense(original_dim,activation='sigmoid')(x)
decoder = keras.Model(latent_inputs,outputs,name='decoder')
decoder.summary()
keras.utils.plot_model(decoder,to_file='vae_mlp_decoder.png',show_shapes=True)
outputs = decoder(encoder(inputs)[2])
vae = keras.Model(inputs,outputs,name='vae_mpl')
模型训练
if __name__ == '__main__':
parser = argparse.ArgumentParser()
help_ = "Load tf model trained weights"
parser.add_argument("-w", "--weights", help=help_)
help_ = "Use binary cross entropy instead of mse (default)"
parser.add_argument("--bce", help=help_, action='store_true')
args = parser.parse_args()
models = (encoder, decoder)
data = (x_test, y_test)
#VAE loss = mse_loss or xent_loss + kl_loss
if args.bce:
reconstruction_loss = keras.losses.binary_crossentropy(inputs,outputs)
else:
reconstruction_loss = keras.losses.mse(inputs,outputs)
reconstruction_loss *= original_dim
kl_loss = 1 + z_log_var - keras.backend.square(z_mean) - keras.backend.exp(z_log_var)
kl_loss = keras.backend.sum(kl_loss,axis=-1)
kl_loss *= -0.5
vae_loss = keras.backend.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
vae.summary()
keras.utils.plot_model(vae,to_file='vae_mlp.png',show_shapes=True)
save_dir = 'vae_mlp_weights'
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
if args.weights:
filepath = os.path.join(save_dir,args.weights)
vae = vae.load_weights(filepath)
else:
#train
vae.fit(x_train,
epochs=epochs,
batch_size=batch_size,
validation_data=(x_test,None))
filepath = os.path.join(save_dir,'vae_mlp.mnist.tf')
vae.save_weights(filepath)
plot_results(models,data,batch_size=batch_size,model_name='vae_mlp')
测试经过训练的解码器
在训练了VAE网络之后,可以丢弃推理模型。为了生成新的有意义的输出,从用于生成 ε ε ε的高斯分布中抽取样本:
效果展示
潜矢量可视化
图片生成