论文信息
论文标题:Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism
论文作者:Siqi Miao, Mia Liu, Pan Li
论文来源:2022,ICML
论文地址:download
论文代码:download
1 Introduction
1.1 引入
GNN的可解释性问题:通常旨在从原始的输入图中提取一个子图:人们希望提取的子图中仅包含最能帮助标签预测的信息。
例子:如下图,我们知道-OH
官能团能够使得一个分子具有水溶性。因此对于一个用来预测分子水溶性的GNN来说,人们希望给定下图的分子后,模型能够告诉我们对预测最重要的部分是-OH
官能团所代表的子图。这样一来,人们就能从模型中获取更多的关于数据关键特征的理解,从而指导进一步的研究。
GNN可解释性主要有两大类方法:
- 固有的可解释模型(Inherently Interpretable Models);
- 事后解释方法(Post-hoc Interpretation Methods);
第一类方法主要旨在设计自身即可提供解释性的GNN模型。这些模型的良好可解释性往往以其预测精度为代价,如注意力机制,多篇研究显示其无法为GNN带来值得信任的可解释结果;
第二类方法,这些工作通常假设人们会提供一个预先训练好的GNN。随后它们会将该GNN的参数固定,然后训练一个新的模型,即解释器(Explainer),来从输入图中找出一个子图。它们希望这些子图能够:1)尽可能小;2)尽可能保持原有预测分数。最后这些子图即被认为是GNN捕捉到的数据的关键特征。
最近,基于不变因果特征学习(Invariant Learning)的工作也逐渐被提出。这些工作认为训练数据中可能会存在数据偏见(Data Bias),使得模型最终学习到一些和标签具有伪相关性(Spurious Correlations)的特征,并因此遭受严重的泛化问题。下图展示了伪相关特征的一个例子。这些特征可能是收集或生成训练数据时的偏见造成的,它们实质上并不是真正决定样本类别的特征。而当测试集不存在这些伪相关的特征时,模型的效果将大打折扣。因此,这些工作引入了因果分析理论(Causality Analysis),希望迫使模型学习数据中不变的、与标签具有因果关系的特征(Invariant Causal Patterns),来解决上述 OOD 泛化问题(Out-of-distribution Generalization)。这类方法在寻找那些不变的因果特征时,也能提供一定程度的自身可解释性。不过也由于这些方法引入了因果分析,它们的架构往往十分复杂且需要大量的计算。
而在这篇工作中,作者们指出了事后解释方法的诸多问题,并同样专注于设计自身可解释模型。这篇工作提出了一种全新的随机注意力机制(Stochastic Attention Mechanism),该机制显示出了强大的可解释能力和泛化能力。对比过去的可解释工作,该机制在6个数据集上提升了至多20%、平均12%的可解释性能;在11个数据集上提升了平均3%的模型准确率,并且在OGBG-MolHiv榜单上达到SOTA(在不使用手工设计的专家特征的模型中)。
除此之外,该机制对可解释能力和泛化能力的提升同样具有理论保障。在一定假设下,该机制天然的不受伪相关特征的影响,从而能够抓取出真正重要的数据特征。在去除伪相关特征的能力上,该机制以远远更小的复杂度,对比基于因果分析的方法提升了平均12%的OOD泛化能力。
1.2 事后解释方法的问题
作者们在文中首先指出了事后可解释方法的四个问题,并认为这些事后解释方法擅长于检查预先训练好的模型对一些特征的敏感程度,但它们并不能提取出对预测真正重要的数据特征,而这才应该是可解释方法需要解决的最有趣的问题。具体来说,作者们指出的四个问题是:
1. 数据分布偏移(Data Distribution Shifts)
首先,事后解释方法将不可避免的遭受数据分布偏移的影响。直觉上,这是因为给定的预先训练好的模型(记作),总是在原始输入图上进行训练的:它从来没有在任何子图上进行过训练。因而极有可能在上是欠拟合的,故而导致并不能真正反应各个子图的重要性。
2. 与标签伪相关的数据特征(Spuriously Correlated Patterns)
其次,预先训练得到的模型可能会过拟合训练数据中与标签信息伪相关,甚至是无关的特征。这是由于大多数模型本身是基于最大化互信息法则(Maximum Mutual Information Principle)来进行训练的,因此在训练中自然会捕捉尽可能多的输入特征,而这也是不变因果特征学习这个方向产生的主要动机。
在这种情况下,事后解释方法很可能会将这些伪相关或者无关的特征提取出来,当作数据中的关键特征,而这可能会将人们引入到一个错误的方向。【自编码器预训练 接 线性分类器.....】
3. 初始化问题(Initialization Issuses)
随后,作者们从优化和信息瓶颈理论切入,指出事后解释方法对不同的的初始化是敏感的。在同一个数据集上,基于不同的随机种子训练得到的,事后解释方法可能会得出差异较大解释结果。而过去的事后解释方法,在评估时往往会忽略这一点,只基于某一个固定的,仅在不同的随机种子上训练解释器。这可能会得到过于乐观的结果,而使得事后解释方法的性能没有得到全面的评估。【使用不同随机种子训练的,我就看到一个论文这样用......】
4. 潜在的有偏见的约束(Potentially Biased Constraints)
最后,由于上述各种问题,事后解释方法有时很难得出符合人们直觉的解释子图。故这些方法中往往嵌入稀疏化约束(Sparsity Constraint),或连接性约束(Connectivity Constraint)等,来得到人们更能理解的数据特征。这些约束极大的要求人们对数据集和任务自身具有一定的先验知识,否则这些约束很可能极大的影响模型的解释结果。一个优秀的可解释模型应当自身即能够抓取适当的数据关键特征而不用附加其它约束。本文提出的随机注意力机制能够在没有上述约束的情况下,取得远远更好的可解释性能。【图结构学习吧,很重要的一点就是加图正则化.......】
1.3 Preliminaries
指:子图和图标签相关。如:为确定分子的溶解度,羟基-OH 是一个正相关的子图,就好像它存在,分子通常溶于水。寻找与标签相关的子图是可解释图学习的一个共同目标。
具有注意力的 GNN 经常会产生低保真度的注意力权重,因为它为每条边学习多个权值,因此将这些权值与不规则的图结构结合起来,进行与图标签相关的特征选择是有问题的。
两种类型的注意力模型:
-
- 注意力权重归一化,求和为 1 ,如 GAT;
- 在没有归一化的情况下学习 之间的权重,如 GGNN;
Learning to Explain (L2X)
L2X [ 2018 ] 研究了正则特征空间中的特征选择问题,并提出了一种互信息(MI)最大化规则来选择固定数量的特征。
具体来说,让 表示两个随机变量 和 之间的 MI。较大的 MI 表明两个随机变量之间存在一定的高相关性。因此,根据输入特征 ,L2X 是搜索一组大小为 的索引 ,其中 。使由索引 (用 表示) 的子空间中的特征与标签 的互信息最大化,即,
本文模型的灵感来自于 L2X ,由于 图特征 及 可解释的特征 并没有固定的维数,所以不能直接应用 L2X ,因此,本文建议使用 Sec 3.1 中使用信息约束。
熵定义为 :
KL 散度定义为:
信息瓶颈 IB :
3 Graph Learning Interpretation via GIB
3.1 GIB-based Objective for Interpretation
其中 表示 的子图的集合。
注意,GIB 不施加任何潜在的偏差约束,如所选子图的大小或连通性。相反,GIB 使用信息约束 来选择只从 中继承了最具指示性的信息的 ,通过最大化 来预测标签 。因此,GS提供了模型解释。
3.2 Issues of Post-hoc GNN Interpretation Methods
几乎所有之前的方法都是事后方法,如 GNNExplainer [2019 ],PGExplainer [2020] 和 GraphMask [2021 ]。给定一个预先训练的预测器 ,他们试图找出影响模型预测最大的子图 ,同时保持预先训练的模型不变。这个过程首先最大化了 和 之间的 MI,并得到模型参数:
然后优化子图提取器 :
其中, 表示满足一些约束的子图子集。例如,GNNExcrener 和 PGExplainer 采用的基数约束。让我们暂时忽略不同约束之间的差异,只关注优化目标。
事后方法的目标函数 和 GIB()有些类似。然而,事后方法可能不能给出,甚至不能近似于 的最优解,因为 并不是经过联合训练的。从优化的角度来看,事后方法只从无约束空间的模型 到信息约束空间 的单步投影到 (见 Fig. 4),投影规则遵循诱导的 MI 降低最小化 的 。
后果:
- 首先, 可能不能完全从 中提取信息来预测 ,因为 最初被训练使 近似 ,而 遵循不同的 的分布。因此, 可能不能很好地近似 ,从而可能误导 的优化,并使 不能选择真实表示 的 。
- 其次,对 的积极优化可能会给出一个大的 MI : (或者 相当小的训练损失)通过选择有助于区分训练标签的特征,但本质上与标签无关或与总体水平上的标签虚假相关。即:无用的特征有时也能很好的区分图标签。根据经验,我们确实观察到 Mutag 上的过拟合问题,如 Fig.5 所示,特别是 PGExplainer 和 GraphMask。在前 5 到 10 个epoch,这两种模型成功地选择了良好的解释,同时有很大的训练损失。进一步的训练成功地减少了损失(在 10 个 epoch 之后),但却大大降低了解释性能。这也可能是为什么在这些事后方法的原始文献中,只建议在少数epoch上进行训练的原因。然而,在实际任务中,很难有地面真实解释标签来验证结果,并确定一个可靠的停止标准。
事后方法存在初始化问题。它们的可解释性对预先训练好的模型 高度敏感,Fig.5 中的大方差证明了这一点。只有当预训练的 近似于最优 时,才能大致保证其性能。因此,根据 GIB 原理等式对 进行联合训练通常需要根据 的 GIB 准则。
4 Stochastic Attention Mechanism for GIB
在本节中,首先给出 GIB 目标的变分界(),然后引入带有随机注意机制的模型GSAT。
4.1 A Tractable Objective for GIB
GSAT 是学习一个参数 的子图提取器来提取 。 通过注入的随机性阻断数据 中与标签无关的信息,同时允许 中保存的与标签相关的信息做出预测。
在 GSAT中, 本质是 上的分布,将该分布表示为 。 即为 。
对于一些 ,通过 GIB 得到了 的优化:
接下来,我们推导出 中这两项的一个可处理的变分上界。详细的推导在附录 B 中给出。
对于项 ,引入针对 的参数化变分近似 ,得到一个下界:
在本文模型, 本质上可以作为参数 的预测器 。
对于项 ,引入边际分布 的变分近似 。并得到上界:
插入上述两个不等式,得到了 的一个变分上界作为 GSAT 的目标:
接下来,在 GSAT 中指定 (又名 )、(又名 )和 。
4.2 GSAT and Stochastic Attention Mechanism
Stochastic Attention via
提取器 首先通过GNN将输入图 编码为一组节点表示形式 。对于每条边 , 包含一个 MLP 层后接 ,它将边连接 映射到 。然后,对于训练的每一次向前传递,从伯努利分布 中抽取随机注意。为确保梯度 是可计算的,应用 gumbel-softmax 重新参数化技巧。提取的图 将有一个被注意选择的子图为 。这里 是 项为 的矩阵,或者非边项为零。上述程序给出 的 的分布,描述了 ,所以 ,其中 是 的一个函数,这本质上使得注意 在给定输入图 的不同边上是有条件地独立的。
Prediction via
预测器 采用相同的 GNN 将提取的图 编码为图表示,最后将该表示通过 MLP 层加 softmax 对 的分布进行建模。该过程给出了变分分布 。
Marginal Distribution Control via
对于任何 都总是正确的。我们对 的定义如下。
对于 中的每个图 和 中的每两个有向节点对 ,我们采样 ,其中 是一个超参数。如果 ,我们删除 中的所有边,并添加所有边 。假设得到的图为 。此过程定义了分布 。由于 独立于图 ,假设其大小为 ,。一个大小为 的图的概率 是一个常数,因此不会影响模型。请注意,我们对 的选择具有使用标准高斯分布作为变分自动编码器的潜在分布的相似精神。
使用上面的 , 中的第一项减少为一个标准的交叉熵损失。使用 和 ,对于每一个 , 为 的大小,kl-散度项变为
其中 是一个没有任何可训练参数的常数。
4.3 The Interpretation Mechanism of GSAT
GSAT 的可解释性本质上来自于信息控制:GSAT通过注意向 中注入随机性来减少输入图中的信息。在训练中, 中的正则项将尝试为所有边缘分配较大的随机性,但在分类损失 (相当于交叉熵损失)的驱动下,GSAT 可以学习减少在任务相关子图上的注意力的随机性。因此,提供模型解释的不是整个 ,而是具有随机减少注意力的部分,又名 。因此,当 GSAT 提供解释时,在实践中,我们可以根据 对所有的边进行排序,并使用那些排名靠前的边(如果需要,给定一定的预算)作为检测到的子图进行解释。正如实验中(Table 5)所示的那样,注入随机性对性能的贡献非常显著,而我们的正则化项的贡献( ),当我们将其与稀疏性驱动的 进行比较时(Fig. 7)。
GSAT与以前的方法有本质的不同,因为我们没有使用任何稀疏性约束,如 ,或 到 {0,1}来选择大小约束(或连通性约束)子图。我们实际上观察到,在边缘正则化( )中,设置 远离 ,也就是说,使 远离稀疏,通常会提供更稳健的解释。这与我们的直觉相吻合,即GIB根据定义并不对选定的子图做出任何假设,而只是限制了来自原始图的信息。我们的实验表明,即使在优化过程中没有利用这些假设,GSAT也显著优于基线,即使与标签相关的子图满足这些假设。如果与标签相关的子图确实断开或大小变化,GSAT的改进预计会更大。
4.4. Further Comparison on Interpretation Mechanism
PGExcwaner 和 GraphMask 在他们的模型中也具有随机性。然而,他们的主要目标是在一个离散的子图选择空间上实现基于梯度的搜索,而不是像 GSAT 那样控制信息。因此,他们在原则上并没有像我们的那样得出信息正则化( ),但采用稀疏性约束来提取一个直接用于解释的小子图 。
IB-subgraph 考虑使用 GIB 作为目标,但不注入任何随机性来生成 ,因此其选择的子图 是 的确定性函数。具体来说,IB-subgraph 采样图 来估计 并优化确定性函数 以最小化这种MI估计。在这种情况下, 简化为熵 ,这倾向于给出一个小尺寸的 ,因为小图的空间很小,并且有一个熵的下上界。相比之下, 在GSAT中是随机的,而 GSAT 主要通过注入随机性来增加 来实现GIB。
4.5 Guaranteed Spurious Correlation Removal
GSAT可以消除训练数据中的虚假相关性,并保证了可解释性。我们可以证明,如果子图模式 与标签 之间存在对应关系,则模式 是 GIB 目标的最优解( )。
Theorem 4.1. Suppose each contains a subgraph such that is determined by in the sense that for some deterministic invertible function with randomness that is independent from . Then, for any , maximizes the GIB , where .
Proof. Consider the following derivation:
其中第三个相等式是由于 ,那么 没有比 持有更多的信息。
如果 , 的最大化 也可以最小化 , 的下界为 0。
是生成 的子图。这是因为(a) ,其中 独立于 ,所以 和 (b) ,其中 独立于 ,所以 。因此, 最大限度地提高了GIB ,其中 。
虽然 决定了 ,但在训练数据集中,数据 和 可能存在一些由环境造成的虚假相关性。也就是说, 可能与标签有一定的相关性,但这种相关性是虚假的,并不是决定其标签的真正原因(如 所示)。一个经过训练的 ,通过MI最大化来预测 的模型可以捕捉到这种虚假的相关性。如果这种相关性在测试阶段发生变化,模型就会出现性能下降。
然而,Theorem 4.1 表明,GSAT通过优化GIB目标,能够仅通过提取 来解决上述问题,从而消除了虚假的相关性,也提供了有保证的可解释性。
4.6. Fine-tuning and Interpreting a Pre-trained Model
GSAT还可以微调和解释一个预先训练好的GNN。给定一个由 预先训练的 ,GSAT 可以通过 , 进行微调,通过初始化 和 中使用的GNN作为预先训练的模型 。
我们观察到,这个框架几乎永远不会损害原始的预测性能(有时甚至会提高它)。此外,与从头开始训练GNN相比,该框架往往能获得更好的解释结果。
5 Experiments
由上一节可见,GSAT架构简单直接,但同时其性能又具有理论保障。这一章节将通过实验结果具体展示GSAT的可解释能力、泛化能力和各模块的消融实验结果。
可解释性
作者们在真实数据集和合成数据集上都对GSAT的可解释性进行了评估。作者们基于这些数据集中已知的解释标注对每个方法的解释结果评估了ROC AUC。如下图,GSAT对比过去的可解释工作,在6个数据集上提升了至多20%、平均12%的可解释性能。
泛化性能
由于GSAT能够帮助去除伪相关性,它同时也能帮助提升模型的分类泛化能力。如下图,GSAT在11个数据集上提升了平均3%的模型准确率,并且在OGBG-MolHiv榜单上达到SOTA(在不使用手工设计的专家特征的模型中)。
OOD泛化性能(伪相关性移除)
为了对比GSAT的移除伪相关特征的能力,作者们同时提供了和不变因果特征学习的方法的直接对比。如下图,可见GSAT能够在不利用因果分析框架的情况下,以更为简单的架构提升平均12%的OOD泛化能力。
消融实验
作者们提供了GSAT中各个模块的消融实验结果,如下表,可见当不注入随机性(NoStoch),或者不添加正则项()时,模型效果均会大幅下降。而当不注入随机性时,模型效果将遭受最大的下降。这一消融实验展示了注入的随机性在GSAT中扮演着极其重要的角色。
作者们同样实验了将从信息瓶颈中推导得来的KL散度正则项替换成过去的方法常用的 正则。下图对各正则项的系数进行了网格搜索,可见文中提出的信息正则项显著优于正则。
6 Conclusion
图随机注意(GSAT)是一种建立可解释图学习模型的新型注意机制。GSAT注入了随机性来阻断与标签无关的信息,并利用减少随机性来选择与标签相关的子图。这种基本原理是基于信息瓶颈原理建立的。GSAT具有许多具有变革性的特征。例如,它消除了图学习解释中的稀疏性、连续性或其他潜在的偏差假设,而没有性能衰减。它还可以消除伪相关,更好地提高模型泛化。作为一个副产品,我们还从信息瓶颈的优化角度揭示了事后解释方法背后的一个潜在的严重问题。
7 GSAT 简单版
在很长的一段时间里,人们认为注意力机制无法提供较好的可解释性,尤其是在图学习领域。而该论文的作者们提出了一种随机注意力机制,并特别的在图学习领域进行了推导和评估,作者们称该机制为GSAT,即图随机注意力(Graph Stochastic Attention)。后续实验表明该机制能够同时提供强大的可解释能力和泛化能力。
机制原理
随机注意力机制,顾名思义,即是在学习注意力时注入随机性。下图提供了其在图学习领域的一个例子。该任务目标是预测图中是否存在五节点环(由图中粉色节点包围),这些环中的边是自然则是对预测结果重要的边。该机制原理的直觉如下:
- 首先,每一条边将会习得一个 之间的注意力权重,该权重将指代每一条边在训练中的抽样概率。一个正则项需要被引入来鼓励每一条边习得较小的抽样概率,即维持较大的随机性。如下图中间样本。
- 随后,倘若对预测结果重要的边存在较大的随机性,那么它们在训练中将会被过于频繁的丢弃,而这将极大的影响分类损失(交叉熵)。因此被分类损失推动,重要的边最后则会维持较小的随机性,即习得较大的抽样概率(理想情况下接近于 )。如下图右侧样本。
- 最后,每条边的随机程度指代其对预测性能的重要程度,而越重要的边应有越大的抽样概率。如下图右侧样本虚线框中的子图即为低随机性的子图,它代表着对预测最为重要的子图。
训练目标
现在的问题即是,上述的正则项应当如何选取呢?事实上这也非常直觉。因为作者们的目标是控制训练图中的随机性,而从信息论的角度来说,作者们即是希望控制图中的信息量。那么一个显而易见的选择就是信息瓶颈理论(Information Bottleneck Principle)。通过注入信息瓶颈,GSAT能够天然的控制图中的信息量,从而达到预期的效果。
具体而言,图信息瓶颈损失可以写作:
其中 代表两个随机变量之间的互信息量(Mutual Information), 是一个正则系数, 代表 信息瓶颈注入的强度。 是一个负责从原图 中提取子图 的模型, 而 则是负责对提取出的子 图 进行下游任务 的预测的模型。
互信息量自身不易优化,作者们为上述目标中的两项分别推导出了变分上界(Variational Upper Bound)来优化该目标。
- 对于第一项 , 易得其变分上界即为 , 而这事实上就是 基于 进行预测后产生的交叉熵损失。
- 对于第二项中的 , 易得其变分上界为 , 其中 即 为 基于 得到的每条边的采样概率; 而 即是对各边采样概率分布的一个正则, 因为该 散度本质上将鼓励习得的每条边的采样分布逼近 的分布。举例来说, 倘若 是 一个参数为 的伯努利分布, 那么这一项则将鼓励每条边的采样概率接近 , 而这正好符合 作者们对随机注意力机制中正则项工作原理的期待。
有保障的可解释性和OOD泛化(伪相关性去除)能力
由上文可知,最终GSAT的训练目标即是一个分类损失(鼓励高分类性能),加上一个KL散度的正则项(鼓励高随机性)。理想情况下,我们期待当模型仅将重要的边维持较小的随机性时,该训练目标应该被最小化,因为在这种情况下我们可能可以在达到最高分类性能的同时,取得最高的整体随机性。而作者们则在论文的定理4.1中证明了这一点,使得GSAT的性能具有理论保障。
具体来说,论文中定理4.1表明: 给定一个任务,如果我们假设输入图 中包含一个子图 , 并 且其标签 将由下式决定: , 其中 是一个可逆的且无随机性的函数, 是与 无关的随机噪声。那么对于任何的 , 能够最小化上文提出的信息瓶颈损失。
这意味着GSAT能够在不利用因果分析工具的情况下,天然的找出真正重要的子图 ,并且移除可能存在的伪相关特征,从而提供有保障的可解释性和OOD泛化能力。
GSAT 模型架构
有了文中提出的两个变分上界, 那么GSAT的模型架构问题则变得一目了然。现在只需要对 和 进行适当的参数化。
直觉来说,如下图:
- 的输入是原始图 ,其对每一条边输出一个注意力权重,那么 显然可以是一个这 样工作的GNN。
- 的输入是子图 ,其输出对该样本的标签预测,那么 显然可以是一个这样工作的 GNN。
- 因此,在GSAT中, 将会接受原图 作为输入,然后输出每一条边的随机注意力的值。紧接 着,基于随机注意力的取值,一个子图 将会被采样出来。最后 将会被喂给 进行最后的 标签预测。
- 尽管 和 可以是两个不同的 ,但作者们发现这里用同一个 GNN 效果就足够好。另外,架构中 最后的采样操作本身是不可导的,因此作者们提出利用Gumbel-softmax Trick来重 参数化这一步骤,使其可导。
因上求缘,果上努力~~~~ 作者:别关注我了,私信我吧,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/16534336.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!
2021-08-01 输出单元