Graph Neural Networks with Adaptive Residual

Liu X., Ding J., Jin W., Xu H., Ma Y., Liu Z. and Tang J. Graph neural networks with adaptive residual. NIPS, 2021.

基于 UGNN 框架的一个更加鲁棒的改进.

符号说明

  • \(\mathbf{A} \in \mathbb{R}^{n \times n}\), 邻接矩阵;
  • \(\mathbf{D} = \text{diag}([d_1, d_2, \ldots, d_n]), \quad d_i = \sum_{j} A_{ij}\).
  • \(\mathbf{\tilde{A}} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{1/2}\);

AirGNN

  • 下面是在不同的图任务上的一个训练结果:

  • 可以发现, 残差连接可以帮助 GNNs 利用更多的层去区别正常的结点, 但是却使得在异常结点上的分类恶化.

  • 我们可以这样认为, 简单的没有残差连接的图网络能够平滑结点表示, 所以此时随着层数的加深, 对于异常结点的分类会更好. 相反, 如果加了残差连接, 最后的结点表示始终会受到一开始的异常结点表示的影响, 所以结果并不太好.

  • 但是, 我们也不能直接移除残差连接, 因为这是加深 GNN 的几乎必须的技巧.

  • 一般的 GCN 都可以归结为如下的方式:

    \[\mathbf{X}_{out} = \text{argmin}_{\mathbf{X} \in \mathbb{R}^{n \times d}} \: \lambda \|\mathbf{X} - \mathbf{X}_{in}\|_F^2 + (1 - \lambda) \frac{1}{2} \text{tr}(\mathbf{X}^T (\mathbf{I} - \mathbf{\tilde{A}}) \mathbf{X}). \]

  • \(\|\mathbf{X} - \mathbf{X}_{in}\|_F^2 = \sum_{i=1}^n \|\mathbf{X}_i - (\mathbf{X}_{in})_i\|_2^2\), 我们知道, \(\|\cdot\|_2^2\) 对于异常值是敏感的, 所以作者转而改写成如下的更加鲁棒的方式:

    \[\text{argmin}_{\mathbf{X} \in \mathbb{R}^{n \times d}} \: \lambda \|\mathbf{X} - \mathbf{X}_{in}\|_{21} + (1 - \lambda) \text{tr}(\mathbf{X}^T (\mathbf{I} - \mathbf{\tilde{A}}) \mathbf{X}), \]

    其中

    \[\|\mathbf{X} - \mathbf{X}_{in}\|_{21} := \sum_{i=1}^n \|\mathbf{X}_i - (\mathbf{X}_{in})_i \|_2. \]

  • 通过 proximal gradient descent 来求解上面的问题, 得到如下的迭代方式:

  • 一个直观的理解是:

    • 当结点 \(i\) 的特征异常的时候, 通常 \(\|\mathbf{Y}_i - (\mathbf{X}_{in})_i\|_2\) 比较大;
    • 这就导致 \(\beta_i\) 比较大;
    • 此时 \(\mathbf{X}_i^{k+1}\) 更多由它的邻居决定 (即 \(\mathbf{Y}_i^k\)), 否则由它本身 \(\mathbf{X}_{in}\) 决定.

代码

[official]

posted @ 2023-10-31 19:38  馒头and花卷  阅读(32)  评论(1编辑  收藏  举报