变分推断的基本形式
变分推断是使\(q(z)\)逼近\(p(z\vert x)\)来求得隐变量\(z\)的后验分布\(p(z\vert x)\)。根据贝叶斯公式,有
\[\begin{align*} \underbrace{\log\left(p(x)\right)}_{\text{evidence}} &= \log\left(p\left(x, z\right)\right)-\log\left(p(z\vert x)\right)\\ &= \underbrace{\int_z q(z)\log\left(\frac{p(x, z)}{q(z)}\right)}_{\text{evidence low bound}}-\underbrace{\int_z q(z)\log\left(\frac{p(z\vert x)}{q(z)}\right)}_{\text{KL divergence}}\end{align*}
\]
\(\log\left(p(x)\right)\)被称为Evidence
的原因是因为它是来自我们观察到的,又因为KL-divergence
不为负,为了使得\(q(z)\)逼近\(p(z\vert x)\),优化的目标就是上面的ELOB
。
Mean field
中场理论(Mean field)一般假设
\[q(z)=\prod_{i=1}^M q(z_i) \label{eq:mean field} \tag{1}
\]
代入\(~\ref{eq:mean field}\)得到
\[\int_z \prod_{i=1}^M q(z_i) \log\left(p(x,z)\right)\,dz-\\ \int_{z_1}q(z_i)\int_{z_2}q(z_2)\cdots\int_{z_M}q(z_M) \sum_{i=1}^M\log\left(q(z_i)\right)\,dz_M\cdots dz_1
\]
即
\[\begin{align*}\int_z \prod_{i=1}^M q(z_i) \log\left(p(x,z)\right)\,dz-\sum_{i=1}^M\int_{z_1}q(z_1)\log\left(q(z_1)\right)\end{align*}
\]
令
\[\log\left(\tilde{p}_j(x, z)\right)=E_{i\neq j}\left[\log\left(p(x,z_j)\right)\right]
\]
针对第\(z_j\),ELOW
为
\[\int_{z_j} q(z_j) \log\left(\tilde{p}_j(x,)\right) - \int_{z_j} q(z_j) \log\left(q(z_j)\right)
\]
因此当\(q(z_j)=\tilde{q}_j(x,z)\)时上式取得最小值\(0\)。因此通过迭代\(z_j\)可以求得逼近\(p(z\vert x)\)的\(q(z)\)
指数函数变分推断例子
假设\(p(x),\,p(x\vert z)\)都来自某指数族分布,指数族分布形式如下
\[p(x\vert \eta)=h(x)\exp\left(\eta^TT(x)-A(\eta)\right)
\]
且满足
\[\begin{align*} A'(\eta_{MLE}) &= \frac{1}{n}\sum_{i=1}^n T(x_i)\\ A'(\eta) &= E_{p(x\vert \eta)}\left[T(x)\right]\\ A''(\eta) &= Var\left[T(x)\right]\end{align*}
\]
假设隐变量\(z\)可以分为两部分\(Z\)和\(\beta\),那么ELOB
可以写为
\[\int_{Z,\beta} q(Z,\beta)\log\left(p(x,Z,\beta)\right)-\int_{Z,\beta}q(Z,\beta)\log\left(q(Z,\beta)\right)
\]
根据指数族的性质后验分布\(p(\beta\vert Z,x)\)和\(p(Z\vert \beta, z)\)都属于指数族
\[\begin{align*}p(\beta\vert Z,x) &= h(\beta) \exp\left(T(\beta)^T\eta(Z,x)-A\left(\eta(Z,x)\right)\right)\\p(Z\vert \beta,x) &= h(Z) \exp\left(T(Z)^T\eta(\beta,x)-A\left(\eta(\beta,x)\right)\right)\end{align*}
\]
这里只展示\(p(\beta\vert Z,x)\)的近似分布\(q(\beta\vert \lambda)\)求解,对于\(p(Z\vert \beta, x)\)的近似分布\(q(Z\vert \phi)\)也类似
\[q(\beta\vert \lambda)=h(\lambda)\exp\left(T(\beta)^T\lambda-A(\lambda)\right)
\]
根据上一节的结果,ELOB
是关于\(\lambda,\, \phi\)的函数
\[E_{q(Z,\beta)}\left[\log\left(p(\beta\vert Z,x)\right)\log\left(p( Z\vert x)\right)\log\left(p(x)\right)\right]-\\E_{q(Z,\beta)}\left[\log\left(q(Z)\right)\log\left(q(\beta)\right)\right]
\]
固定\(\phi\),上式中与\(\lambda\)有关的项为
\[E_{q(Z,\beta)}\left[\log\left(q(\beta\vert Z,x)\right)\right]-E_{q(Z,\beta)}\left[\log\left(q(\beta)\right)\right]
\]
将\(\log\left(q(\beta\vert Z,x)\right)\)和\(\log\left(q(\beta)\right)\)定义带入,得到与\(\lambda\)有关的项为
\[E_{q(\beta)}\left[T(\beta)\right]^TE_{q(Z)}\left[\eta(Z,x)\right]-\lambda^TE_{q(\beta)}\left[T(\beta)\right]+A(\lambda)
\]
利用\(A'(\eta)=E_{p(x\vert \eta)}\left[T(x)\right]\)得到
\[\begin{align*}L(\lambda,\phi) &= A'(\lambda)^TE_{q(Z)}\left[\eta(Z,x)\right]-\lambda^T A'(\lambda) + A(\lambda)\\\frac{\partial L(\lambda, \phi)}{\partial \lambda}&=A''(\lambda)^TE_{q(Z)}\left[\eta(Z,x)\right]-A'(\lambda) - \lambda^T A''(\lambda) + A'(\lambda)\end{align*}
\]
因为\(A''(\lambda)\neq0\),因此
\[\lambda=E_{q(Z\vert \phi)}\left[\eta(Z, x)\right]
\]
同理
\[\phi = E_{q(\beta\vert \lambda)}\left[\eta(\beta, x)\right]
\]
随机梯度变分推断
不同于mean field,随机梯度变分推断将分布\(q(z\vert \phi)\)看为关于\(\phi\)的分布,通过对\(\phi\)进行优化得到最优的分布。
\[\begin{align*}\nabla_{\phi}L&=\nabla_{\phi}E_{q(z\vert \phi)}\left[\log\left(p(x,z)\right)-\log\left(q(z\vert\phi)\right)\right]\\&=E_{q(z\vert \phi)}\left[\nabla_{\phi}\left[\log q(z\vert \phi\right]\left(\log p(x,z)-\log q(z\vert \phi)\right)\right]\end{align*}
\]
随后用蒙塔卡罗就可以近似出梯度,虽然直接使用蒙特卡洛会造成方差较大,可以通过重参数技巧进行减小方差(在VAE中也有用到)。重参数后的计算参见SGVI。
参考
-
WallE-Chang SGVI repository
-
ws13685555932 machine learning derivative repository
-
shuhuai008 SGVI