Graph Neural Networks with Adaptive Residual
概
基于 UGNN 框架的一个更加鲁棒的改进.
符号说明
- \(\mathbf{A} \in \mathbb{R}^{n \times n}\), 邻接矩阵;
- \(\mathbf{D} = \text{diag}([d_1, d_2, \ldots, d_n]), \quad d_i = \sum_{j} A_{ij}\).
- \(\mathbf{\tilde{A}} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{1/2}\);
AirGNN
-
下面是在不同的图任务上的一个训练结果:
-
可以发现, 残差连接可以帮助 GNNs 利用更多的层去区别正常的结点, 但是却使得在异常结点上的分类恶化.
-
我们可以这样认为, 简单的没有残差连接的图网络能够平滑结点表示, 所以此时随着层数的加深, 对于异常结点的分类会更好. 相反, 如果加了残差连接, 最后的结点表示始终会受到一开始的异常结点表示的影响, 所以结果并不太好.
-
但是, 我们也不能直接移除残差连接, 因为这是加深 GNN 的几乎必须的技巧.
-
一般的 GCN 都可以归结为如下的方式:
\[\mathbf{X}_{out} = \text{argmin}_{\mathbf{X} \in \mathbb{R}^{n \times d}} \: \lambda \|\mathbf{X} - \mathbf{X}_{in}\|_F^2 + (1 - \lambda) \frac{1}{2} \text{tr}(\mathbf{X}^T (\mathbf{I} - \mathbf{\tilde{A}}) \mathbf{X}). \] -
\(\|\mathbf{X} - \mathbf{X}_{in}\|_F^2 = \sum_{i=1}^n \|\mathbf{X}_i - (\mathbf{X}_{in})_i\|_2^2\), 我们知道, \(\|\cdot\|_2^2\) 对于异常值是敏感的, 所以作者转而改写成如下的更加鲁棒的方式:
\[\text{argmin}_{\mathbf{X} \in \mathbb{R}^{n \times d}} \: \lambda \|\mathbf{X} - \mathbf{X}_{in}\|_{21} + (1 - \lambda) \text{tr}(\mathbf{X}^T (\mathbf{I} - \mathbf{\tilde{A}}) \mathbf{X}), \]其中
\[\|\mathbf{X} - \mathbf{X}_{in}\|_{21} := \sum_{i=1}^n \|\mathbf{X}_i - (\mathbf{X}_{in})_i \|_2. \] -
通过 proximal gradient descent 来求解上面的问题, 得到如下的迭代方式:
-
一个直观的理解是:
- 当结点 \(i\) 的特征异常的时候, 通常 \(\|\mathbf{Y}_i - (\mathbf{X}_{in})_i\|_2\) 比较大;
- 这就导致 \(\beta_i\) 比较大;
- 此时 \(\mathbf{X}_i^{k+1}\) 更多由它的邻居决定 (即 \(\mathbf{Y}_i^k\)), 否则由它本身 \(\mathbf{X}_{in}\) 决定.
代码
[official]