深度变分信息瓶颈——Deep Variational Information Bottleneck

3020编辑收藏

  Deep Variational Information Bottleneck (VIB) 变分信息瓶颈 论文阅读笔记。本文利用变分推断将信息瓶颈框架适应到深度学习模型中,可视为一种正则化方法。

1  变分信息瓶颈#

  假设数据输入输出对为(X,Y),假设判别模型fθ()有关于X的中间表示Z,本文旨在优化θ以最小化互信息I(Z;X) ,同时最大化互信息I(Z;Y),即:

maxθI(Z;Y|θ)βI(Z;X|θ)

  其中β>0为平衡系数。直觉理解,上式期望Z能保留更少X信息的同时能较好用于预测Y。那么如何构造相应的深度学习模型以及相应的优化方案?下面推导上式的下界,使其下界变大,上式即可变大。为了简化,下面去掉θ进行推导。

1.1  上界1#

  I(Z;X)展开为:

I(Z;X)=p(x,z)logp(z|x)p(z)dxdz

  其中p(z|x)为是原始模型关于x对中间表示z的推理分布。对于其中的p(z),作者用另一个变分估计r(z)来拟合。由于有

KL(p(Z),r(Z))0p(z)logp(z)dzp(z)logr(z)dzp(x,z)logp(z)dxdzp(x,z)logr(z)dxdz

  则有

I(Z;X)=p(x,z)logp(z|x)p(x,z)logp(z)dxdzp(x,z)logp(z|x)p(x,z)logr(z)dxdz=p(x)p(z|x)logp(z|x)r(z)dxdz

1.2  下界2#

  I(Z;Y)展开为:

I(Z;Y)=p(y,z)logp(y|z)p(y)dydz

  其中p(y)是数据的标签分布,已知。未知而需要进行处理的是其中的p(y,z)p(y|z),也就是模型需要拟合的分布。对于p(y|z),可以用一个解码器q(y|z)来拟合,即文中所谓的变分估计。利用KL散度的大于零性质,有以下不等式:

KL(p(Y|Z),q(Y|Z))0p(y|z)logp(y|z)q(y|z)dy0p(y,z)p(z)logp(y|z)q(y|z)dy0p(y,z)logp(y|z)dyp(y,z)logq(y|z)dy

  注意最后一步去掉p(z)是由于它没有在积分中,是常数。则有

I(Z;Y)=p(y,z)logp(y|z)p(y,z)logp(y)dydzp(y,z)logq(y|z)p(y,z)logp(y)dydz=p(y,z)logq(y|z)dydzp(y)logp(y)dy=p(y,z)logq(y|z)dydz+H(Y)

  对于其中的p(y,z),本文基于马尔科夫假设:YXZ。这个假设表明,YZX的条件下独立(那在优化时呢?Z是关于XY的联合分布进行更新的)。有:

p(y,z)=p(x,y,z)dx=p(x,y)p(z|x,y)dx=p(x,y)p(z|x)dx

  此外,由于H(Y)已知且固定,可忽略。则有

I(Z;Y)p(x,y)p(z|x)logq(y|z)dxdydz

  其中,p(x,y)是真实数据分布,p(z|x)是原始模型关于x对中间表示z的推理分布。

1.3  总体下界和优化#

  结合下界1和上界2,有:

I(Z;Y)βI(Z;X)p(x,y)p(z|x)logq(y|z)dxdydzβp(x)p(z|x)logp(z|x)r(z)dxdz=L

  针对上式,用经验分布来代替真实分布。即用1Nn=1Nδxn(x)代替p(x),用1Nn=1Nδyn(y)代替p(y),用1Nn=1Nδ(xn,yn)(x,y)代替p(x,y)。其中δxn(x)表示狄拉克函数,其空间内积分为1,且仅在xn上非零。假设经验分布有N各样本{(xn,yn)}n=1N。实际上直接把概率积分改成离散样本的求和取平均即可。则上式可被估计为:

L1Nn=1Np(z|xn)logq(yn|z)βp(z|xn)logp(z|xn)r(z)dz

  文中将z视为隐变量,利用VAE的重参数技巧将p(z|xn)实现为一个关于xn的正态分布N(feμ(xn),feΣ(xn)),其中feμ(xn),feΣ(xn)分别为基于xn生成的均值和协方差矩阵。将z抽样表示为f(xn,ϵ)=feΣ(xn))ϵ+feμ(xn),其中ϵN(0,1)。则最大化L可表示为最小化:

JIB=1Nn=1NEϵN(0,1)[logq(yn|f(xn,ϵ))]+βKL[p(Z|xn);r(z)]

  其中r(z)利用某一特定分布实现,文中使用标准正态分布实现。

2  直觉理解#

  直觉上理解:模型要把每个xn分别映射到特定的分布,这些分布既不能偏离标准正态分布太远,又需要让模型后续能根据这些分布的抽样来预测xn的标签。那么这种做法为什么能从xn中抽取对预测yn有效的关键信息而忽略无关信息呢(即信息瓶颈)?我的理解是,模型被惩罚以使不同xn得到的zn分布靠近同一分布,但为了有效预测yn,又必须产生一定的不一致。不同xn对应的z分布越一致,通过z而流向y的差异性信息将越少,导致q更难利用采样的z预测y,从而促使模型忽略x中的冗余信息而保留预测y所需的关键信息。β则用于控制z保留x信息的程度,越大保留信息越少。

  相较于一般的判别模型:当不把z视为隐变量,而变成关于x唯一确定的中间表示时,就是一般的判别模型。这种方式隐式地假定了表示的连续性,然而无法确保所有z都不是被离散地分散在表示空间中。最坏的过拟合情况下,每个(xn,yn)都孤立地确定了一个中间表示zn来实现一一映射,导致无泛化。而对于使用了信息瓶颈z的判别模型,由于x仅仅确定z的生成分布,不同的xi,xj依然可能抽样出同一个z,这种模式强制这个抽样出的z必须共享这两个样本的相似信息并忽略不同的信息,从而表示语义的相似性被强制由线性距离控制,实现表示语义的连续性,从而显式地确定了模型的泛化。

3  实验#

  表1:信息瓶颈加成的模型和各种正则化后模型的对比。

  图1:不同βz维度K下VIB模型在MNIST上的错误率,以及两个互信息的平衡。

  图2:z维度K=2时,1000张图片的z分布的可视化。

  后续是一些对抗鲁棒的实验,不记录

相关博文:
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
很高兴能帮到你~
点赞
more_horiz
keyboard_arrow_up dark_mode palette
选择主题
menu
点击右上角即可分享
微信分享提示