Auto-Encoding Variational Bayes 公式推导及代码

变分自动编码器(VAE)用于生成模型,结合了深度模型以及静态推理。简单来说就是通过映射学习将一个高维数据,例如一幅图片映射到低维空间Z。与标准自动编码器不同的是,X和Z是随机变量。所以可以这么理解,尝试从P(X|Z)中去采样出x,所以利用这个可以生成人脸,数字以及语句的生成。

1.模型


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
graph LR A[Data] -->|DNN_1| B(mu,std) B --> C[z] D(eps) --> C C-->|DNN_2|E[gen_data]

以上为模型的代码和图示,我们构建的为DNN_1和DNN_2模型,Data-->z为编码器,z-->gen_data为解码器。

2.损失函数

2.1 L

假设x-->z的真实概率分布为p(z|x),我们的模型的概率分布为q(z|x),那么损失可用他们的KL散度表示【\(\sum_z q(z|x) \log \frac{q(z|x)}{p(z|x)}\)】恒大于0,其值越大表示相似度越低,即损失越大

\[ \sum_z q(z|x) \log \frac{q(z|x)}{p(z|x)} =\sum_z q(z|x) \log (p(x)\frac{q(z|x)}{p(z|x)*p(x)}) =\sum_z q(z|x) \log (p(x)\frac{q(z|x)}{p(z,x)}) =\log p(x)+\sum_z q(z|x) \log (\frac{q(z|x)}{p(z,x)}) \]

由于p(x)为真实分布,是个固定分布,故最小化KL(q(z|x)||p(z|x))即最大化\(\sum_z q(z|x) \log (\frac{p(z,x)}{q(z|x)})\),我们设其为L

\[L=\sum_z q(z|x) \log (\frac{p(z)*p(x|z)}{q(z|x)})=\sum_z q(z|x) \log (\frac{p(z)}{q(z|x)})+\sum_z q(z|x) \log (p(x|z))=L_1+L_2 \]

我们假设z的先验概率p(z)是N(0,1)分布,而我们的模型学到的q(z|x)是N(mu,std)分布

2.1.1 L1

\[L_1=E_{z\sim N(\mu,\sigma^2)} \log \frac{\frac{1}{\sqrt{2\pi}}e^{\frac{-z^2}{2}}}{\frac{1}{\sqrt{2\pi}\sigma}e^{\frac{(z-\mu)^2}{2\sigma^2}}}=E_{z\sim N(\mu,\sigma^2)} (\log \sigma-\frac{z^2}{2}+\frac{(z-\mu)^2}{2\sigma^2})= E_{z\sim N(\mu,\sigma^2)} (\log \sigma-\frac{z^2}{2}+\frac{z^2+\mu^2-2\mu z}{2\sigma^2}) \]

对于正态分布\(z\sim N(\mu,\sigma)\),\(E(z)=\mu,E(z^2)=D(z)+E(z)^2=\mu^2+\sigma^2\)
因此

\[L_1=\log \sigma -\frac{\mu^2+\sigma^2}{2}+\frac{\mu^2+\sigma^2+\mu^2-2\mu\mu}{2\sigma^2}=\log \sigma +\frac{1}{2}(1-\mu^2-\sigma^2) \]

2.1.2 L2

如果直接从 \(N(\mu,\sigma)\) 中采样,那么采样的结果是不可导的,我们通过重参数技巧来解决不能梯度下降的问题,即采样$ \epsilon \sim N(0,1) $,用 $ \epsilon * \sigma +\mu $ 来代替 $ z \sim N(\mu,\sigma)$。用蒙特卡洛方法来计算L2。
这个是对应解码器部分【交叉熵】

2.2 代码

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD
posted @ 2019-10-17 14:57  benda  阅读(1418)  评论(0编辑  收藏  举报