Variational Autoencoders

Diederik Kingma and Max Welling introduced another important category of autoencoders in 2013, and it quickly became one of the most popular types of autoencoders: the variational autoencoder (VAE).

Variational autoencoders are quite different from all the autoencoders we have looked at so far, in the following ways:

  • They are probabilistic autoencoders, meaning that their outputs are partly determined by chance, even after training (as opposed to denoising autoencoders, which use randomness only during training)
  • They are generative autoencoders, meaning that they can generate new instances that look like they were sampled from the training set

These two properties make them rather similar to RBMs, but they are easier to train, and the sampling process is much faster (with RBMs you need to wait for the network to stabilize into a "thermal equilibrium" before you can sample a new instance). Variational autoencoders perform variational Bayesian inference, which is an efficient way to carry out approximate Bayesian inference.

Instead of directly producing a coding for a given input, the encoder of a variational autoencoder produces a mean coding \(\mu\) and a standard deviation \(\sigma\). The actual coding is then sampled randomly from a Gaussian distribution with mean \(\mu\) and standard deviation \(\sigma\), and the decoder decodes the sampled coding normally.
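Equivalently, the sampling step can be written as $$\mathbf{z}=\mu+\sigma\odot\varepsilon,\qquad\varepsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})$$ where \(\varepsilon\) is a vector of standard Gaussian noise; this is exactly the form implemented by the Sampling layer defined later in this post.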

During training, the cost function pushes the codings to gradually migrate within the coding space, ending up looking like a cloud of Gaussian points. One great consequence is that after training a variational autoencoder, it is very easy to generate a new instance: just sample a random coding from the Gaussian distribution, decode it, and there is your new instance.

The cost function is composed of two parts. The first is the usual reconstruction loss, which pushes the autoencoder to reproduce its inputs. The second is the latent loss, which pushes the autoencoder's codings to look as though they were sampled from a simple Gaussian distribution: it is the KL divergence between the target distribution (the Gaussian distribution) and the actual distribution of the codings. The math is a bit more involved than for the sparse autoencoder, in particular because of the Gaussian noise, which limits the amount of information that can be transmitted to the coding layer (thereby pushing the autoencoder to learn useful features).

The latent loss of a variational autoencoder is \(\mathcal{L}=-\frac12\sum_{i=1}^{K}\left[1+\log(\sigma_i^2)-\sigma_i^2-\mu_i^2\right]\)
In this equation, \(\mathcal{L}\) is the latent loss, \(K\) is the dimensionality of the codings, and \(\mu_i\) and \(\sigma_i\) are the mean and standard deviation of the \(i\)th component of the coding. The vectors \(\mu\) and \(\sigma\) (containing all the \(\mu_i\) and \(\sigma_i\)) are output by the encoder.
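For reference, this expression is simply the KL divergence between each coding dimension's distribution \(\mathcal{N}(\mu_i,\sigma_i^2)\) and the standard Gaussian prior \(\mathcal{N}(0,1)\), summed over the \(K\) dimensions: $$D_{\mathrm{KL}}\left(\mathcal{N}(\mu_i,\sigma_i^2)\,\|\,\mathcal{N}(0,1)\right)=\frac12\left(\sigma_i^2+\mu_i^2-1-\log(\sigma_i^2)\right)$$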

A common tweak to the variational autoencoder's architecture is to make the encoder output \(\gamma=\log(\sigma^2)\) rather than \(\sigma\). The latent loss can then be computed as shown below (since \(\sigma_i^2=\exp(\gamma_i)\), this is just a substitution into the previous formula). This approach is more numerically stable and speeds up training: $$\mathcal{L}=-\frac12\sum_{i=1}^{K}\left[1+\gamma_i-\exp(\gamma_i)-\mu_i^2\right]$$
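As a quick sanity check (an illustrative snippet, not part of the original code), the two formulas give the same value when \(\gamma=\log(\sigma^2)\):

import numpy as np

# Illustrative check: with gamma = log(sigma^2), both latent-loss formulas agree
rng = np.random.default_rng(42)
mu = rng.normal(size=10)
sigma = rng.uniform(0.5, 2.0, size=10)
gamma = np.log(sigma ** 2)

loss_sigma = -0.5 * np.sum(1 + np.log(sigma ** 2) - sigma ** 2 - mu ** 2)
loss_gamma = -0.5 * np.sum(1 + gamma - np.exp(gamma) - mu ** 2)
print(np.isclose(loss_sigma, loss_gamma))  # True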

Let's build a variational autoencoder for Fashion MNIST, using the \(\gamma\) tweak. First, we need a custom layer to sample the codings, given \(\mu\) and \(\gamma\):

from tensorflow import keras
import tensorflow as tf

K = keras.backend


class Sampling(keras.layers.Layer):
    def call(self, inputs):
        mean, log_var = inputs
        # Sample epsilon ~ N(0, I), scale it by sigma = exp(log_var / 2), then shift by the mean
        return K.random_normal(tf.shape(log_var)) * K.exp(log_var / 2) + mean

This Sampling layer takes two inputs: mean (\(\mu\)) and log_var (\(\gamma\)). It uses the K.random_normal() function to sample a random vector (of the same shape as \(\gamma\)) from the Normal distribution with mean 0 and standard deviation 1, multiplies it by \(\exp(\gamma/2)\) (which equals \(\sigma\)), adds \(\mu\), and returns the result. This samples a codings vector from the Normal distribution with mean \(\mu\) and standard deviation \(\sigma\).
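For example (an illustrative call, not in the original post), feeding a toy mean and log_var pair through the layer yields one sampled coding per row:

mean = tf.constant([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
log_var = tf.zeros_like(mean)            # sigma = exp(0 / 2) = 1 for every component
sample = Sampling()([mean, log_var])
print(sample.shape)                      # (2, 3): one 3-dimensional coding per input row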

Next, let's create the encoder using the Functional API, because the model is not entirely sequential:

codings_size = 10

inputs = keras.layers.Input(shape=[28, 28])
z = keras.layers.Flatten()(inputs)
z = keras.layers.Dense(150, activation='gelu')(z)
z = keras.layers.Dense(100, activation='gelu')(z)
codings_mean = keras.layers.Dense(codings_size)(z)      # mu
codings_log_var = keras.layers.Dense(codings_size)(z)   # gamma = log(sigma^2)
codings = Sampling()([codings_mean, codings_log_var])
variational_encoder = keras.Model(inputs=[inputs], outputs=[codings_mean, codings_log_var, codings])

The Dense layers that output codings_mean (\(\mu\)) and codings_log_var (\(\gamma\)) take the same input (the output of the second Dense layer). Both codings_mean and codings_log_var are then passed to the Sampling layer. Finally, the variational_encoder model has three outputs, in case you want to inspect the values of codings_mean and codings_log_var; the only output we will actually use is the last one, codings.
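To verify the shapes (an illustrative check, not in the original post), you can run a dummy batch through the encoder:

dummy_batch = tf.random.uniform([2, 28, 28])             # two fake 28x28 images in [0, 1]
mean, log_var, sampled = variational_encoder(dummy_batch)
print(mean.shape, log_var.shape, sampled.shape)          # each is (2, 10)

Now let's build the decoder: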

decoder_inputs = keras.layers.Input(shape=[codings_size])
x = keras.layers.Dense(100, activation='gelu')(decoder_inputs)
x = keras.layers.Dense(150, activation='gelu')(x)
x = keras.layers.Dense(28 * 28, activation='sigmoid')(x)
outputs = keras.layers.Reshape([28, 28])(x)
variational_decoder = keras.Model(inputs=[decoder_inputs], outputs=[outputs])

For this decoder, we could have used the Sequential API instead of the Functional API, since it really is just a simple stack of layers (a minimal sketch of that equivalent version is shown below).
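Assuming the same layer sizes as above, the Sequential version could look roughly like this (the name sequential_decoder is hypothetical; the Functional version is the one used in the rest of this post):

sequential_decoder = keras.models.Sequential([
    keras.layers.Dense(100, activation='gelu', input_shape=[codings_size]),
    keras.layers.Dense(150, activation='gelu'),
    keras.layers.Dense(28 * 28, activation='sigmoid'),
    keras.layers.Reshape([28, 28]),
])

Finally, let's build the variational autoencoder model: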

_, _, codings = variational_encoder(inputs)
reconstructions = variational_decoder(codings)
variational_ae = keras.Model(inputs=[inputs], outputs=[reconstructions])

Lastly, we must add the latent loss and the reconstruction loss:

latent_loss = -0.5 * K.sum(
    1 + codings_log_var - K.exp(codings_log_var) - K.square(codings_mean), axis=-1
)
variational_ae.add_loss(K.mean(latent_loss) / 784.)
variational_ae.compile(loss='binary_crossentropy', optimizer='rmsprop')

We first apply the formula to compute the latent loss for each instance in the batch (summing over the last axis). Then we compute the mean loss over all the instances in the batch, and we divide the result by 784 to ensure it has the appropriate scale compared to the reconstruction loss. Indeed, the variational autoencoder's reconstruction loss is supposed to be the sum of the pixel reconstruction errors, but when Keras computes the "binary_crossentropy" loss, it computes the mean over all 784 pixels rather than the sum. So the reconstruction loss is 784 times smaller than we need it to be. We could define a custom loss that computes the sum rather than the mean, but it is simpler to divide the latent loss by 784 (the final loss will be 784 times smaller than it should be, which just means we should use a slightly larger learning rate).
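If you prefer the other route, a custom reconstruction loss that sums over the pixels could look roughly like this (a sketch, not used in this post; with it, the latent loss would no longer need to be divided by 784):

def sum_binary_crossentropy(y_true, y_pred):
    # binary_crossentropy averages over the 784 pixels; multiply by 784 to recover the sum
    per_pixel_mean = keras.losses.binary_crossentropy(
        K.batch_flatten(y_true), K.batch_flatten(y_pred))
    return 784. * per_pixel_mean

# variational_ae.add_loss(K.mean(latent_loss))                      # no division by 784
# variational_ae.compile(loss=sum_binary_crossentropy, optimizer='rmsprop')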

We use the RMSprop optimizer here, which works well in this case. Now let's train the autoencoder:

fashion_mnist = keras.datasets.fashion_mnist
(X_train_all, y_train_all), (X_test, y_test) = fashion_mnist.load_data()
X_valid, X_train = X_train_all[:5000] / 255., X_train_all[5000:] / 255.
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
history = variational_ae.fit(X_train, X_train, epochs=50, batch_size=32, validation_data=(X_valid, X_valid))
Epoch 1/50
1719/1719 [==============================] - 13s 7ms/step - loss: 0.4348 - val_loss: 0.3956
Epoch 2/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3915 - val_loss: 0.3833
Epoch 3/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3822 - val_loss: 0.3752
Epoch 4/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3746 - val_loss: 0.3687
Epoch 5/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3699 - val_loss: 0.3657
Epoch 6/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3672 - val_loss: 0.3646
Epoch 7/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3652 - val_loss: 0.3626
Epoch 8/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3638 - val_loss: 0.3607
Epoch 9/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3627 - val_loss: 0.3598
Epoch 10/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3619 - val_loss: 0.3584
Epoch 11/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3610 - val_loss: 0.3578
Epoch 12/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3603 - val_loss: 0.3577
Epoch 13/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3597 - val_loss: 0.3578
Epoch 14/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3592 - val_loss: 0.3555
Epoch 15/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3587 - val_loss: 0.3564
Epoch 16/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3584 - val_loss: 0.3558
Epoch 17/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3581 - val_loss: 0.3564
Epoch 18/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3581 - val_loss: 0.3561
Epoch 19/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3578 - val_loss: 0.3560
Epoch 20/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3576 - val_loss: 0.3542
Epoch 21/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3574 - val_loss: 0.3548
Epoch 22/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3571 - val_loss: 0.3552
Epoch 23/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3569 - val_loss: 0.3545
Epoch 24/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3567 - val_loss: 0.3555
Epoch 25/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3566 - val_loss: 0.3543
Epoch 26/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3562 - val_loss: 0.3536
Epoch 27/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3561 - val_loss: 0.3563
Epoch 28/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3561 - val_loss: 0.3550
Epoch 29/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3560 - val_loss: 0.3539
Epoch 30/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3562 - val_loss: 0.3539
Epoch 31/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3559 - val_loss: 0.3537
Epoch 32/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3555 - val_loss: 0.3533
Epoch 33/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3555 - val_loss: 0.3524
Epoch 34/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3552 - val_loss: 0.3541
Epoch 35/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3552 - val_loss: 0.3524
Epoch 36/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3552 - val_loss: 0.3531
Epoch 37/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3550 - val_loss: 0.3525
Epoch 38/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3550 - val_loss: 0.3531
Epoch 39/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3550 - val_loss: 0.3532
Epoch 40/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3550 - val_loss: 0.3517
Epoch 41/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3546 - val_loss: 0.3525
Epoch 42/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3546 - val_loss: 0.3523
Epoch 43/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3547 - val_loss: 0.3510
Epoch 44/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3545 - val_loss: 0.3543
Epoch 45/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3546 - val_loss: 0.3532
Epoch 46/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3547 - val_loss: 0.3513
Epoch 47/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3549 - val_loss: 0.3523
Epoch 48/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3546 - val_loss: 0.3519
Epoch 49/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3549 - val_loss: 0.3510
Epoch 50/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3547 - val_loss: 0.3517

Generating Fashion MNIST Images

Let's use this variational autoencoder to generate images that look like fashion items. All we need to do is sample random codings from a Gaussian distribution and decode them:

import matplotlib.pyplot as plt

# Sample 12 random codings from a standard Gaussian prior and decode them into images
coding = tf.random.normal(shape=[12, codings_size])
images = variational_decoder(coding).numpy()


def plot_image(image):
    plt.imshow(image, cmap='binary')
    plt.axis('off')


fig = plt.figure(figsize=(12 * 1.5, 3))
for image_index in range(12):
    plt.subplot(3, 4, image_index + 1)
    plot_image(images[image_index])


Variational autoencoders make it possible to perform semantic interpolation: instead of interpolating between two images at the pixel level (which would look as if the two images were simply overlaid), we can interpolate at the codings level. We first run both images through the encoder, then interpolate between the two codings we get, and finally decode the interpolated codings to obtain the final image. It looks like a regular Fashion MNIST image, but it is an intermediate between the original images. In the following code example, we take the 12 codings we just generated, organize them in a \(3\times4\) grid, and use TensorFlow's tf.image.resize() function to resize this grid to \(5\times7\). By default, the resize() function performs bilinear interpolation, so every other row and column will contain interpolated codings. We then use the decoder to produce all the images:

codings_grid = tf.reshape(coding, [1, 3, 4, codings_size])
larger_grid = tf.image.resize(codings_grid, size=[5, 7])
interpolated_codings = tf.reshape(larger_grid, [-1, codings_size])
images = variational_decoder(interpolated_codings).numpy()
fig = plt.figure(figsize=(6 * 1.5, 6))
for image_index in range(35):
    plt.subplot(5, 7, image_index + 1)
    plot_image(images[image_index])

