如何解决图神经网络(GNN)训练中过度平滑的问题
如何解决图神经网络(GNN)训练中过度平滑的问题?
转自知乎
https://www.zhihu.com/question/346942899/answer/835292740
泻药..首先要搞清楚图神经网络不能加深的原因是什么。常见的原因有三种:1)数据集太小,overfitting的问题,在一些数据上training acc为100%的大概率是这个问题,需要通过防止过拟合的技术来解决 2)vanishing gradient,这是CNN里一样存在的问题,当层数太深导致网络的参数不能得到有效的训练。这个问题可以加skip connections可以有效解决 3)over smoothing
其他同学@也提到了我们ICCV Oral的工作:DeepGCNs,这个工作主要是解决了vanishing gradient和over smoothing的问题,最开始是在点云上做的实验,正在做的TPAMI版本我们把14层的图网络MRConv用到了PPI数据,达到了F1 score 99.4的效果,是目前的start-of-the-art。PPI部分的实验代码近期会开源。
点云实验的代码、论文、slides都已开源。论文还有很多可以改善的地方,我们也还在做一些后续工作,欢迎交流:
Arxiv paper:
DeepGCNs: Can GCNs Go as Deep as CNNs?Github:
Tensorflow:
lightaime/deep_gcnsPytorch:
lightaime/deep_gcns_torch
都说GNN实际是个热传导,所以如果导热率太高,时间太长,最终就是温度达到单一温度。所以要降低导热率,或者缩短传导时间,才能形成有局部特征的分布模式。从消息传递的角度,就是要增加势能函数的差异性,或者说是降低系统温度,以及减少消息传递的循环次数。
更正一下题目中的几个小误区:
原题:如何解决图神经网络(GNN)训练中过度平滑的问题?即在图神经网络的训练过程中,随着网络层数的增加和迭代次数的增加,每个节点的隐层表征会趋向于收敛到同一个值(即空间上的同一个位置)。
不是所有图神经网络都有 over-smooth 的问题,例如,基于 RandomWalk + RNN、基于 Attention 的模型大多不会有这个问题,是可以放心叠深度的~只有部分图卷积神经网络会有该问题。
不是每个节点的表征都趋向于收敛到同一个值,更准确的说,是同一连通分量内的节点的表征会趋向于收敛到同一个值。这对表征图中不通簇的特征、表征图的特征都有好处。但是,有很多任务的图是连通图,只有一个连通分量,或较少的连通分量,这就导致了节点的表征会趋向于收敛到一个值或几个值的问题。
注:在图论中,无向图的连通分量是一个子图,其中任何两个顶点通过路径相互连接。
为什么 GCN 中会存在 over-smooth 的问题
首先,回顾一下全连接神经网络和 Kipf 图卷积神经网络的公式:
其中, 为激活函数, 为节点特征, 为训练参数, , 为邻接矩阵, , 为图中的所有节点。可以发现图卷积神经网络只多了对节点信息进行汇聚的权重 。从 (无归一化)到 (归一化),再到 (对称归一化),对于该权重的研究已然汗牛充栋。
学有余力的同学可以往下看通式上 over-smooth 的证明,这里先以 为例,进行一个直观的解释:
首先,中间层的 由任务相关的 反向传播进行优化,可以理解为任务相关的模式提取能力,我们将其统一在图卷积后进行,多层卷积公式可以近似为:
其中, 可以看作被提取的多个隐藏层。化简该式:
其中,邻接矩阵的幂, 表示节点 和节点 之间长度为 的 walk 的数量。而它的度, 代表节点 到所有节点之间长度为 的 walk 的数量。
这时, 则代表以节点 为起点,随机完成 步的 walk 最后抵达节点 的概率。
随着 walk 步数的增多,远距离节点的抵达难度越来越小,被随机选中的概率越来越大。当 时,连通分量中的节点 到达连通分量中任意节点的概率都趋于一致,为 ,其中 代表连通分量中节点的总数,即 ,其中 、 代表连通分量的邻接矩阵和度矩阵。
令连通分量中的特征向量为 ,且 , 代表连通分量中节点的特征维度。节点信息的汇聚可以表示为:
连通分量中每个节点的特征都为所有节点特征的平均,也就是我们开始的时候说的,同一连通分量内的节点的表征趋向于收敛到同一个值。
在感性地认识到图卷积与连通分量之间的关联后,有的工作想到利用特征分解(特征向量对应连通分量)给出 over-smooth 定理的证明[1]:
over-smooth 定理:假设图 由 个连通分量 构成,其中第 个连通分量可以用向量 表示:
那么,当图中不存在二分连通分量时,有:
其中, 和 表示线性组合 的系数,且:
本想写自己的证明过程,但由于篇幅较长喧宾夺主,有机会再贴~
如何解决 over-smooth 的问题
在了解为什么 GCN 中会存在 over-smooth 问题后,剩下的工作就是对症下药了:
问题:图卷积会使同一连通分量内的节点的表征会趋向于收敛到同一个值。
- 针对“图卷积”:在当前任务上,是否能够使用 RNN + RandomWalk(数据为图结构,边已然存在)或是否能够使用 Attention(数据为流形结构,边不存在,但含有隐式的相关关系)?
- 针对“同一连通分量内的节点”:在当前任务上,是否可以对图进行 cut 等预处理?如果可以,将图分为越多的连通分量,over-smooth 就会越不明显。极端情况下,节点都不相互连通,则完全不存在 over-smooth 现象(但也无法获取周围节点的信息)。
如果上述方法均不适用,仍有以下 deeper 和 wider 的措施可以保证 GCN 在过参数化时对模型的训练和拟合不产生负面影响。个人感觉,这类方法的实质是不同深度的 GCN 模型的 ensamble:
巨人肩膀上的模型深度 —— residual 等
Kipf 在提出 GCN 时,就发现了添加更多的卷积层似乎无法提高图模型的效果,并通过试验将其归因于 over-smooth:多层 GCN 可能导致节点趋同化,没有区别性。但是,早期的研究认为这是由 GCN 过分强调了相邻节点的关联而忽视了节点自身的特点导致的。 所以 Kipf 给出的解决方案是添加残差连接[2],将节点自身特点从上一层直接传输到下一层:
在这个思路下,陆续有工作借鉴 DenseNet,将 residual 连接替换为 dense 连接,提出了自己的 module [3][4]:
其中, 表示拼接节点的特征向量。
最近,也有些工作认为直接将使用残差连接矫枉过正,残差模块完全忽略了相邻节点的权重,因而选择在 的基础上,对节点自身进行加强[5]:
在此基础上,作者进一步考虑了相邻节点的数量,提出了新的正则化方法:
另辟蹊径的模型宽度 —— multi-hops 等
随着图卷积渗透到各个领域,一些研究开始放弃深度上的拓展,选择效仿 Inception 的思路拓宽网络的宽度,通过不同尺度感受野的组合对提高模型对节点的表征能力。N-GCN[6]通过在不同尺度下进行卷积,再融合所有尺度的卷积结果得到节点的特征表示:
其中, , 表示拼接节点的特征向量。原文中尝试了 和 等不同的归一化方法对当前节点 阶临域的进行信息汇聚,取得了还不错的效果。
也有一些工作认为 GCN 的各层的卷积结果是一个有序的序列:对于一个 层的 GCN,第 层捕获了 -hop 邻居节点的信息,其中 ,相邻层 和 之间有依赖关系。因而,这类方法选择使用 RNN 对各层之间的长期依赖建模[7]:
即为:
随着图卷积的日益成熟,深层的图卷积已经在各个领域开花结果啦~ 相信在不久的将来,pruning 和 NAS 还会碰撞出新的火花,童鞋们加油呀!另外,有的同学私信想看我的论文中是怎样处理 over-smooth 的~可是由于写作技巧太差我的论文还没发粗去(最开始导师都看不懂我写的是啥,感谢一路走来没有放弃我的导师和师兄,现在已经勉强能看了),等以后有机会再分享叭~