变分推断

变分推断的基本形式

变分推断是使\(q(z)\)逼近\(p(z\vert x)\)来求得隐变量\(z\)的后验分布\(p(z\vert x)\)。根据贝叶斯公式,有

\[\begin{align*} \underbrace{\log\left(p(x)\right)}_{\text{evidence}} &= \log\left(p\left(x, z\right)\right)-\log\left(p(z\vert x)\right)\\ &= \underbrace{\int_z q(z)\log\left(\frac{p(x, z)}{q(z)}\right)}_{\text{evidence low bound}}-\underbrace{\int_z q(z)\log\left(\frac{p(z\vert x)}{q(z)}\right)}_{\text{KL divergence}}\end{align*} \]

\(\log\left(p(x)\right)\)被称为Evidence的原因是因为它是来自我们观察到的,又因为KL-divergence不为负,为了使得\(q(z)\)逼近\(p(z\vert x)\),优化的目标就是上面的ELOB

Mean field

中场理论(Mean field)一般假设

\[q(z)=\prod_{i=1}^M q(z_i) \label{eq:mean field} \tag{1} \]

代入\(~\ref{eq:mean field}\)得到

\[\int_z \prod_{i=1}^M q(z_i) \log\left(p(x,z)\right)\,dz-\\ \int_{z_1}q(z_i)\int_{z_2}q(z_2)\cdots\int_{z_M}q(z_M) \sum_{i=1}^M\log\left(q(z_i)\right)\,dz_M\cdots dz_1 \]

\[\begin{align*}\int_z \prod_{i=1}^M q(z_i) \log\left(p(x,z)\right)\,dz-\sum_{i=1}^M\int_{z_1}q(z_1)\log\left(q(z_1)\right)\end{align*} \]

\[\log\left(\tilde{p}_j(x, z)\right)=E_{i\neq j}\left[\log\left(p(x,z_j)\right)\right] \]

针对第\(z_j\)ELOW

\[\int_{z_j} q(z_j) \log\left(\tilde{p}_j(x,)\right) - \int_{z_j} q(z_j) \log\left(q(z_j)\right) \]

因此当\(q(z_j)=\tilde{q}_j(x,z)\)时上式取得最小值\(0\)。因此通过迭代\(z_j\)可以求得逼近\(p(z\vert x)\)\(q(z)\)

指数函数变分推断例子

假设\(p(x),\,p(x\vert z)\)都来自某指数族分布,指数族分布形式如下

\[p(x\vert \eta)=h(x)\exp\left(\eta^TT(x)-A(\eta)\right) \]

且满足

\[\begin{align*} A'(\eta_{MLE}) &= \frac{1}{n}\sum_{i=1}^n T(x_i)\\ A'(\eta) &= E_{p(x\vert \eta)}\left[T(x)\right]\\ A''(\eta) &= Var\left[T(x)\right]\end{align*} \]

假设隐变量\(z\)可以分为两部分\(Z\)\(\beta\),那么ELOB可以写为

\[\int_{Z,\beta} q(Z,\beta)\log\left(p(x,Z,\beta)\right)-\int_{Z,\beta}q(Z,\beta)\log\left(q(Z,\beta)\right) \]

根据指数族的性质后验分布\(p(\beta\vert Z,x)\)\(p(Z\vert \beta, z)\)都属于指数族

\[\begin{align*}p(\beta\vert Z,x) &= h(\beta) \exp\left(T(\beta)^T\eta(Z,x)-A\left(\eta(Z,x)\right)\right)\\p(Z\vert \beta,x) &= h(Z) \exp\left(T(Z)^T\eta(\beta,x)-A\left(\eta(\beta,x)\right)\right)\end{align*} \]

这里只展示\(p(\beta\vert Z,x)\)的近似分布\(q(\beta\vert \lambda)\)求解,对于\(p(Z\vert \beta, x)\)的近似分布\(q(Z\vert \phi)\)也类似

\[q(\beta\vert \lambda)=h(\lambda)\exp\left(T(\beta)^T\lambda-A(\lambda)\right) \]

根据上一节的结果,ELOB是关于\(\lambda,\, \phi\)的函数

\[E_{q(Z,\beta)}\left[\log\left(p(\beta\vert Z,x)\right)\log\left(p( Z\vert x)\right)\log\left(p(x)\right)\right]-\\E_{q(Z,\beta)}\left[\log\left(q(Z)\right)\log\left(q(\beta)\right)\right] \]

固定\(\phi\),上式中与\(\lambda\)有关的项为

\[E_{q(Z,\beta)}\left[\log\left(q(\beta\vert Z,x)\right)\right]-E_{q(Z,\beta)}\left[\log\left(q(\beta)\right)\right] \]

\(\log\left(q(\beta\vert Z,x)\right)\)\(\log\left(q(\beta)\right)\)定义带入,得到与\(\lambda\)有关的项为

\[E_{q(\beta)}\left[T(\beta)\right]^TE_{q(Z)}\left[\eta(Z,x)\right]-\lambda^TE_{q(\beta)}\left[T(\beta)\right]+A(\lambda) \]

利用\(A'(\eta)=E_{p(x\vert \eta)}\left[T(x)\right]\)得到

\[\begin{align*}L(\lambda,\phi) &= A'(\lambda)^TE_{q(Z)}\left[\eta(Z,x)\right]-\lambda^T A'(\lambda) + A(\lambda)\\\frac{\partial L(\lambda, \phi)}{\partial \lambda}&=A''(\lambda)^TE_{q(Z)}\left[\eta(Z,x)\right]-A'(\lambda) - \lambda^T A''(\lambda) + A'(\lambda)\end{align*} \]

因为\(A''(\lambda)\neq0\),因此

\[\lambda=E_{q(Z\vert \phi)}\left[\eta(Z, x)\right] \]

同理

\[\phi = E_{q(\beta\vert \lambda)}\left[\eta(\beta, x)\right] \]

随机梯度变分推断

不同于mean field,随机梯度变分推断将分布\(q(z\vert \phi)\)看为关于\(\phi\)的分布,通过对\(\phi\)进行优化得到最优的分布。

\[\begin{align*}\nabla_{\phi}L&=\nabla_{\phi}E_{q(z\vert \phi)}\left[\log\left(p(x,z)\right)-\log\left(q(z\vert\phi)\right)\right]\\&=E_{q(z\vert \phi)}\left[\nabla_{\phi}\left[\log q(z\vert \phi\right]\left(\log p(x,z)-\log q(z\vert \phi)\right)\right]\end{align*} \]

随后用蒙塔卡罗就可以近似出梯度,虽然直接使用蒙特卡洛会造成方差较大,可以通过重参数技巧进行减小方差(在VAE中也有用到)。重参数后的计算参见SGVI。

参考

  1. WallE-Chang SGVI repository

  2. ws13685555932 machine learning derivative repository

  3. shuhuai008 SGVI

posted @ 2020-04-25 00:00  Neo_DH  阅读(433)  评论(0编辑  收藏  举报