Feature Overcorrelation in Deep Graph Neural Networks: A New Perspective

Jin W., Liu X., Ma Y., Aggarwal C. and Tang J. Feature overcorrelation in deep graph neural networks: a new perspective. In ACM International Conference on Knowledge Discovery and Data Mining (KDD), 2022.

GNN 有一个很严重的弊端: over-smoothing, 这会导致 GNN 的层数不能过深. 这篇文章指出, 影响网络性能的可能并不是 over-smoothing (或者说它并不罪魁祸首), 真正的问题是特征的 over-correlation. 于是作者通过最小化特征间的相关度, 最大化特征和初始特征的互信息来解决这一问题.

符号说明

  • G=(V,E,X), 图;
  • V={v1,v2,,vN}, 结点;
  • EV×V, 边;
  • XRN×d, 结点的特征;
  • A{0,1}N×N, 邻接矩阵;
  • 一般的 GNN layer:

    Hi,:(l)=Transform(Propagate(Hj,:(l1)|vjN(vi){vi}))

    其中 N(vi) 表示 vi 的一阶邻居.

over-correlation 的现象

  1. pearson correlation coefficient:

    ρ(x,y)=i=1N(xix¯)(yiy¯)i=1N(xix¯)2i=1N(yiy¯)2;

  2. 由此, 我们可定义特征 HRN×d 上维度间的相关度:

    ρ(H:,i,H:,j),i,j[d]:={1,2,,d};

  3. 由此定义整个特征 H 的一个相关度指标:

    Corr(H):=1d(d1)ij|p(H:,i,H:,j)|[0,1],

    注意到, 当 H:,i,H:,ji,j 是线性相关的时候, Corr(H) 达到极端的 1.

  4. 同时我们定义整个特征 H 上的一个平滑度:

    SMV(H):=1N(N1)ijD(Hi,:,Hj,:)[0,1],

    其中

    D(x,y)=12xxyy2,

    当所有的结点都成比例, 即 Hi,:=cHj,:, 此时有一个最光滑的情况 SMV(H)=0.

由上图所示, 当层数逐渐增加的时候, 结点的特征间的相关度 Corr(H) 会迅速上升, 最后达到接近 1 的峰值, 此时 GNN 几乎丧失了判断能力. 此外, 虽然在层数增加的过程中, SMV(H) 也在逐步下降, 但是并不如 Corr(H) 来的显著.

另一个很有意思的现象是, 当我们采用 Transform 为 MLPs, 并加多 MLP 的层数的时候, 网络的会逐步趋向过拟合. 此时如下图 (b) 所示, Corr 的增长非常迅速, 且无论 ReLU 是否采用, 而 SMV 则并不一定, 这也说明特征间的相关度更像是罪魁祸首. 此外, 过参数化的网络更容易招致这一点.

解决方法

  1. 作者希望直接最小化特征间的相关度:

    minH1N1(HH¯)T(HH¯)Cov MatrixIF2;

  2. 这类似于下列的标准化后的损失:

    D(H)=(HH¯)T(HH¯)(HH¯)T(HH¯)FIdF2,

    于是在各层上的损失可以归结为:

    LD=i=1K1D(H(i));

  3. 此外, 除了减少相关度外, 我们还希望特征 H 不丧失太多输入特征 X 的信息, 我们希望最大化二者的互信息:

    maxMI(H,X);

  4. 我们没法直接计算出二者的互信息 (因为我们并不知道分布), 故我们采用最大化它的一个下界:

    MI(H,X)EP(H,X)[f(H,X)]logEP(X)P(H)[ef(A,X)],

    这里 f(H,X) 为 energy function;

  5. 具体的, 这里

    M(H(k),X)=EP(hi(k),xi)[f(hi(k),xi)]+logEP(hi(k))P(xi)[ef(hi(k),xi)],

    其中 f(,) 是一个二分类函数:

    f(hi,xi)=σ(xiTWhi),

    话说, 从 energy 的角度来说, 应该没有 σ 吧. 最后在各层上的损失就为

    LM=i[t,2t,,K1tt]M(H(i),X),

    注意, 这里作者每隔 t 层加一个损失, 用于加速训练;

  6. 最后总的损失为

    L=Lclass+αLD+βLM.

代码

[official]

posted @   馒头and花卷  阅读(57)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 零经验选手,Compose 一天开发一款小游戏!
· 因为Apifox不支持离线,我果断选择了Apipost!
· 通过 API 将Deepseek响应流式内容输出到前端
历史上的今天:
2021-09-16 Hough Transform
点击右上角即可分享
微信分享提示