变分自编码器VAE的由来和简单实现(PyTorch)

​ 之前经常遇到变分自编码器的概念(\(VAE\)),但是自己对于这个概念总是模模糊糊,今天就系统的对\(VAE\)进行一些整理和回顾。

VAE的由来

​ 假设有一个目标数据\(X=\{X_1,X_2,\cdots,X_n\}\),我们想生成一些数据,即生成\(\hat{X}=\{\hat{X_1},\hat{X_2},\cdots,\hat{X_n}\}\),其分布与\(X\)相同。

​ 但是实际上,这样存在一些问题,第一是我们如何将生成的\(\hat{X}\)\(X\)一一对应,这就需要我们采用更为精巧的度量方式,即如何度量两个分布之间的距离;第二是我们如何生成新的\(\hat{X}\),按照朴素的想法,我们可以构造一个函数\(G\),使得\(\hat{X}=G(Z)\) ,如果能构造出这个\(G\),我们就可以通过一个任意的\(Z\),来生成\(\hat{X}\) ,而这里的\(Z\),可以取一个已知的分布,比如正态分布。

目前的问题

​ 目前的问题转化为了如何构造\(G\),以及如何检验我们生成的\(\hat{X}\)是否和\(X\)具有同分布。在\(GAN\)中,这里的\(G\)和分布的相似度衡量都用神经网络搞定了,一个叫做\(generator\),一个叫做\(discriminator\),这二者互相拮抗,最终使得分布越来接近。

​ 而在我们目前的问题中,\(VAE\)提供了另外一种思路,沿着AutoEncoder的想法,AutoEncoder是通过\(encoder\)把image \(a\)编码为vector,叫做\(latent{\ }represention\) ,再通过\(decoder\)\(latent{\ }space\)转为\(\hat{a}\) ,\(\hat{a}\)\(a\)的重建图像。

​ 但是AE针对每张图片生成的\(latent{\ }code\)并没有可解释性,即sample两个\(latent{\ } code\)之间的点输入\(decoder\),得到的结果并不一定具有跟这两个\(latent code\)相关的特征。为了解决这个问题,提出了VAE:不再采用vector来建模一个\(latent{\ }code\),而是利用一个带有noise的高斯分布来表示。直观的理解,在加入noise之后,就有机会将训练时候train的\(latent{\ }code\)在其latent space下赋予一定的变化能力,使latent space变得更加连续,从而可以在其中采样从而生成新的图片。

​ 我们之前生成的\(Z=\{Z_1,Z_2,\cdots,Z_n\}\),现在不再单单生成一个\(Z\),而是生成两个vector,分别记为\(M=\{{\mu_1},{\mu_2},\cdots\,{\mu_n}\}\),\(\Sigma=\{ {\sigma_1},{\sigma_2},\cdots,\sigma_n\}\),分别代表新生成latent code的高斯分布的均值和方差。在sample的时候就只需要根据从标准正态分布\(\mathcal{N}(0,1)\)中采样一个\(e_i\),\(e_i\)来自于\(E=\{e_1,e_2,\cdots,e_n\}\),然后利用\(c_i=e_i*exp({\sigma_i})+\mu_i\)(\(reparameterization{\ }trick\)),就得到了我们所需的\(c_i\)\(c_i\)即组成我们需要的\(Z\)=\(\{c_1,c_2,\cdots,c_n\}\)

​ 这里一方面希望\(VAE\)能够生成尽可能丰富的数据,因此训练的时候希望在高斯分布中含有噪声。另一方面优化的过程中会趋向于使图像质量更好,因此当噪声为0的时候退化为普通的\(AutoEncoder\),这种情况我们是不希望出现的。为了平衡这种trade-off,这里希望每个\(p(Z|X)\)能够接近标准正态分布,但是另一方面网络又趋于使输入和输出图像更为接近,因此会使正态分布的方差向0的方向优化。经过这种对抗过程,最终就能产生具有一定可解释性的\(decoder\),同时最终得到的\(Z\)的分布也会趋向于\(\mathcal{N}(0,1)\),可以表示为:

​ $$p(Z)=\sum_{X} p(Z \mid X) p(X)=\sum_{X} \mathcal{N}(0, 1) p(X)=\mathcal{N}(0, I) \sum_{X} p(X)=\mathcal{N}(0, 1)$$

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
    def loss_function_original(recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

​ 这里的loss由两部分组成,一部分是重建loss,一部分是使各个高斯分布趋近于标准高斯分布的loss(由KL散度推导得到)。

posted on 2021-11-01 19:49  何莫道  阅读(517)  评论(0编辑  收藏  举报