VAE-变分推断
1.推荐材料
1.PRML 第十章节 变分推断
2.B站 白板推导 这部分讲解的很详细
https://www.bilibili.com/video/BV1aE411o7qd?p=70
https://www.bilibili.com/video/BV1aE411o7qd?p=71
https://www.bilibili.com/video/BV1aE411o7qd?p=72
https://www.bilibili.com/video/BV1aE411o7qd?p=73
https://www.bilibili.com/video/BV1aE411o7qd?p=74
3.鲁鹏老师 计算机视觉与深度学习 这个讲的比较浅显易懂
https://www.bilibili.com/video/BV1V54y1B7K3?p=14
4.知乎 - 有不少笔误,需要详细斟酌
https://zhuanlan.zhihu.com/p/94638850
5.邱锡鹏 - 蒲公英书 第13.2章节 变分自编码器
6.https://blog.csdn.net/lrt366/article/details/83154048
这篇文章解释了两个我一直困惑的点
- 1.为什么\(先验P(Z)\)要假设为标准正态分布\(\mathcal{N}(0,1)\)
- 2.为什么有的VAE的代码中\(Z\)前面的拟合变量是\(\mu,\log \sigma^2,而不是\sigma\)
2.VAE的目标
对隐藏变量\(Z\)的特征提取(鲁鹏视频)
正常数据很难真的达到高斯分布,一般都是由多个高斯分布叠加组成,俗称GMM(混合高斯模型)
为了达到上面的目的需要完成下列两个事情
1.隐变量\(Z\)的真实概率分布\(P(Z)\)
2.求出生成概率模型\(P(X|Z)\)的参数\(\theta\)
\(问题1通过神经网络学习 输入X,输出P(Z),神经网络参数为\phi,这是VAE的前半段,称为推断网络\)
\(问题2通过神经网络血虚,输入Z,输出\hat X,神经网络参数为\theta,这是VAE的后半段,称为生成网络\)
模型结构
注意
- 0.这里为什么放了两张图?
- 1.首先要明白,第一张图是我们的最终目标,有的时候学习了VAE的结构,忘记了我们到底要做什么,第一张图就是时刻提醒自己
VAE的最终目的是什么(最终目的就上面两个小点) - 2.那第二张图是做什么?因为想要求真实\(P(Z)\)是很难的,为了实现这一目标,需要借助样本\(X\)学习到\(P(Z)\)
- 3.q是什么?下面逐步阐释
- 4.隐变量\(Z\)是一连续参数
- 5.很重要的一个点这里的编码器\(Z\)不是确定值,而是一个概率分布,Z在这个分布上滑动,输出不同的数据(看下鲁鹏老师的满月,半月案例),通常我们令\(P(Z)\)服从标准正态分布
3.流程梳理
1.梳理的第一步-主要网络结构
回到第一张图,我们先不看\(X\to Z\)这根虚线,\(Z\to X\)这条线是很明确的,就是为了实现2.2 这个小目标(邱锡鹏蒲公英书)
由贝叶斯公式得到\(P(X)=\frac{P(X|Z)P(Z)}{P(Z|X)}\)
\(\ln P(X)=\ln P(Z,X) - \ln P(Z|X) = \ln \frac{P(Z,X)}{q(Z)} - \ln \frac{P(Z|X)}{q(Z)} -\color{red}{公式1}\)
这里引入了一个新的分布\(q(Z)\),为什么要引入它?因为真实分布\(P(Z)\)无法获得,引入\(q(Z)\),希望\(q(Z)\)可以无限逼近\(P(Z)\)
那么这个\(q(Z)\)是怎么算出来的呢?这里就用到了VAE的前半段,也就是推断网络学习出来的
至此先梳理清楚VAE就是通过一个推断网络+生成网络组成的
另外一个很重要的点,前半段的生成网络\(q(Z)\)是通过学习样本\(X\)获得的,所以应该写成q(Z|X)的形式,并且我们令推断网络的参数为\(\phi\),生成网络的参数为\(\theta\),避免搞混
所以最终梳理一下标记
\(由样本X通过推断网络f_{I}(X,\phi)学习得到近似隐变量真实分布P(Z)的近似分布q(Z|X;\phi)\)
\(通过q(Z|X;\phi) \color{red}{采样}后得到一个样本数据Z,然后通过生成网络f_{G}(Z,\theta)学习得到X经过\color{red}{数据降维}后的数据\hat X,并且\hat X的数据分布满足P(X|Z;\theta)\)
\(采样:通过采样能够学习到所有数据概率值情况下的数据,比如输入一张满月的照片,一张半月(1/2个月亮),则通过采样,就可以获取1/2月亮-全月亮所有可能形状的月亮(鲁鹏视频),这个在图形处理中很有用,可以用于图像增强\)-这也解释了上面的第5点,为什么P(Z)是一个分布,需要噪声才能学习
\(数据降维:一般来说Z是比X低维的数据,这样就能对主要特征进行抽取\)
2.梳理的第二步-重要的假设
回到这张图
还有这张
假设1 - 可以看到通常我们假设这些隐变量都是服从正态分布的,也就是GMM模型,现在假设\(P(Z)\)服从标准正态分布\(\sim \mathcal{N}(0,1)\)(\color{red}{为什么这么假设?先放一放,后面再解释,或者直接看我推荐的材料6}),可以这么理解,\(Z\)必须有一定的噪声,这样后半段的生成网络才能在一定区间内生成数据,参考上面的满月,半月案例
3.梳理的第三步-推导目标函数
假设2 - 假设我们已知真实分布\(P(Z)\),这样我们先处理VAE的后半段-生成模型,那么为了求解生成模型中的参数\(\theta\),我们用最常用的最大似然法,通过贝叶斯方法求解边缘分布\(P(X)\),使得对数似然函数\(\log \prod P(X;\theta)\)最大即可
\(\ln P(X;\theta)=\ln P(Z,X;\theta) - \ln P(Z|X;\phi) = \ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)} - \ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)} -\color{red}{公式1}\)
大家可以看到这个公式无意中已经把前半段的推断网络也牵扯进来了,不再是单纯的求生成网络
还有一点,根据模型结构中的图1,其实也就是一个概率图模型,\(\theta\)同时决定了\(Z\)和\(X\),所以P(Z,X)添加了一个解释说明的参数,P(Z,X;\theta)代表P(Z,X)这个模型中的参数也是\(\theta\)
继续推导上面的公式
\(\ln P(X;\theta)=\ln P(Z,X;\theta) - \ln P(Z|X;\phi) = \ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)} - \ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)}\)
\(两边对q(Z;\phi)求期望\)
\(左边=\int \ln P(X;\theta)q(Z|X;\phi)dZ=\ln P(X;\theta),不变\)
\(右边=\int[\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)} - \ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)}]q(Z|X;\phi)dZ\)
\(=\int\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ - \int\ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ\)
\(这个式子的前一半称为\color{red}{ELBO}(evidence\ lower\ bounds),后一半是KL散度,\color{red}{KL(q(Z|X;\phi) || P(Z|X;\theta))}\)
小结一下
\(\ln P(X;\theta) =\int\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ - \int\ln \frac{P(Z|X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ\)
\(=ELBO + KL(q(Z|X;\phi) || P(Z|X;\theta))\)
- 结论1,若想要对数似然函数最大,也就是为了求解生成网络的参数\(\theta\),那么就是使得上式最大化,后面的KL散度没法求(不知道真实\(P(Z|X;\theta)\)),所以目标就是尽可能的\(\max ELBO\)
- 结论2,将上式改写为\(KL(q(Z|X;\phi) || P(Z|X;\theta)) = \ln P(X;\theta) -ELBO\),若想要\(q\)和\(P\)尽量相似,那么也就是推断网络的\(q\)使得KL尽量小,接近于0,故这个式子也代表着,为了求得推断网络的参数\(\phi\),要尽量使得\(\max ELBO\)
- 结论3,所以推断网络和生成网络的目标函数是一致,可以理解为整个VAE模型的目标函数就是\(\color{red}{\max ELBO}\)
4.梳理第四步-详解ELBO
ELBO = \(\int\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)}q(Z|X;\phi)dZ\)
\(=\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln \frac{P(Z,X;\theta)}{q(Z|X,\phi)}]\)
\(=\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln \frac{P(Z|X;\theta)P(X;\theta)}{q(Z|X,\phi)}]\)
\(=\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln P(Z|X;\theta)]-KL(q(Z|X;\phi) || P(Z|X;\theta))\)
先处理后一半
\(其中 P(Z|X;\theta) 是先验,之前我们假设了是标准正态分布\)
\(q(Z|X;\phi)是后验,我们仍然假设是正态分布(GMM模型),不过参数未知\sim \mathcal{N}(\mu_I,\sigma_I^2 I),下标I代表是从推断网络得到的\)
\(两个正态分布的KL散度可以直接求出,不推导了,直接看邱锡鹏老师的蒲公英书 公式13.24\)
\(KL(q(Z|X;\theta) || P(Z;\theta))=\frac{1}{2}(tr(\sigma^2_II)+\mu_I^T\mu_I-d-\log(|\sigma_I^2I|))\)
\(d是维度\)
5.梳理第四步-重参数化技巧
第四步中的前一半\(\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln P(Z|X;\theta)]\)可以通过收集数据后取平均的方式,但是最大的问题是Z是通过采样得到的,没有确定的表达式,没办法求梯度
所以这里引入了重参数化技巧
重参数化
引入公式
\(Z=\mu_I +\sigma_I \times \epsilon\)-\(\color{red}{这里遗留了一个问题,为什么要用这个公式来表示?}\)
\(\epsilon \sim \mathcal{N}(0,I)\)
\(这样\mathbb{E}_{Z\sim q(Z|X;\phi)}[\ln P(Z|X;\theta)] 可以转化为 \mathbb{E}_{\epsilon\sim p(\epsilon}[\ln P(Z|g(\phi,\epsilon);\theta)]\)
这样我们就改写了网络结构
6.目标函数总结
\(目标函数最终定义为\)
\(L(\phi,\theta|X)=\sum_{n=1}^{N}(\frac{1}{M}\sum_{n=1}^{N}\log p(x^{(n)}|z^{(n,m)};\theta) - KL(q(z|x^{(n);\theta},N(z;0,I)))) - 蒲公英书 13.27\)
\(\color{red}{这一步的推导看起来很自然,但总感觉不知道怎么推导出来的}\)
\(\color{red}{另外还有一个点\mu_G书中介绍的是生产网络的输出,我看了一些代码的确也是这么写的,但我不太明白的是为什么用一个均值\mu_G的符号表示,且在蒲公英书的 13.18 公式中使用过\mu_G,明确写着这是一个均值符号,这有什么意义吗?}\)
\(L(\phi,\theta|X)=-\frac{1}{2}||x-\mu_G||^2 -\lambda KL(q(z|x^{(n);\theta},N(z;0,I)))) - 蒲公英书 13.27\)
\(今天重温了下概率论与数理统计,居然发现邱锡鹏老师的公式推导是完全正确的\)
\(有一个隐藏的等式要引入,因为在实际使用的正态分布的似然函数推导的时候用的是误差服从正态分布的推导,也就是\epsilon = x-\mu =x-\hat x,\epsilon\sim N(0,1),这样就完全串联起来了\)
4.其他变分方法
PRML书上介绍了除了本章节要讲的方法,变分方法包括有限元方法,最大熵方法
除了基于梯度的变分方法,PRML书中还着重说明了基于平均场理论的变分方法
5.其他
为什么要求先验\(P(Z)是标准正态分布?\)
https://blog.csdn.net/lrt366/article/details/83154048
1.VAE模型的输出是需要带有噪声的,否则就完全退化为AE了,但是添加了噪声后,经过学习的输出为了让损失函数减小必然会将方差缩小到0,那么谈何添加噪声呢?
2.为了保证第一点中保证有噪声,就必须让输出服从一定的概率分布,也就是\(P(Z|X;\phi)\)服从标准正态分布,\color{red}{只有让Z具有一定噪声,才能让后面的\hat X输出能拟合出所有分布的数据,参考上面的满月,半月的案例}
3.基于2的结论,再反推回\(P(Z)\)也是标准正态分布
4.损失函数KL散度的计算,就是为了让\(P(Z|X;\phi)\)能够尽量靠拢标准正态分布\(P(Z)\)