Generalised f-Mean Aggregation for Graph Neural Networks

Kortvelesy R., Morad S. and Prorok A. Generalised f-mean aggregation for graph neural networks. NIPS, 2023.

基于 MPNN 架构的 GNN 主要在于 aggregator 和 update function 两部分, 一般来说后者是参数化的主要方式. 本文提出一种新的参数化 aggregator 的方法, 能够覆盖绝大部分经典的 aggregators.

符号说明

  • \(\mathcal{X} = \{x_1, x_2, \ldots, x_n\}\), 一批结点的 embedding, \(|\mathcal{X}| = n\), node embedding, \(x_i \in \mathbb{R}^d\);
  • \(\odot: \mathbb{R}^{n \times d}\), aggregation function.

GenAgg

  • 作者提出这样的 generalized f-mean:

    \[f^{-1}(\frac{1}{n} \sum_{i} f(x_i)), \]

    比如 \(f(x) = \frac{1}{x}\) 的时候, 有

    \[ \odot(\mathcal{X}) = \frac{n}{\sum \frac{1}{x_i}} \]

    为 harmonic mean.

  • 不过这种定义太强了, 有些常用的 aggregator (如 'sum') 没法满足, 所以本文首先提出一种 augmented f-mean:

    \[f^{-1} \bigg( n^{\alpha - 1} \sum_{i} f(x_i - \beta \mu), \bigg) \]

    其中 \(\alpha, \beta\) 是可学习的参数, \(\mu\) 可以是均值 \(\mu = \frac{1}{n} \sum x_i\).
    上表列出不同 \(\langle f, \alpha, \beta \rangle\) 下的 aggregation function.

  • 除了给定具体的 \(f\) 外, 我们也可以直接用神经网络去拟合 \(f\), 为了保证 \(f\) 的可逆性, 我们可以用 normalizing flows 的技术实现.

  • 不过, 另一方面, 我们可以采用很简单的方式, 用两个不同的 MLP 来分别作为 \(f, f^{-1}\), 同时在训练的时候施加如下的约束:

    \[ \mathcal{L}_{inv}(\theta_1, \theta_2) = \mathbb{E}\bigg[ \big( |f_{\theta_2}^{-1}(f_{\theta_1}(x))| - |x| \big)^2 \bigg]. \]

  • 注意, 一般来说, \(f: \mathbb{R}^1 \rightarrow \mathbb{R}^1\), 为了增加一些表达能力, 我们也可以用

    \[ f: \mathbb{R}^{1} \rightarrow \mathbb{R}^d, \quad f^{-1}: \mathbb{R}^{d} \rightarrow \mathbb{R}^1. \]

代码

[official]

posted @ 2023-12-27 17:03  馒头and花卷  阅读(13)  评论(0编辑  收藏  举报