Generalised f-Mean Aggregation for Graph Neural Networks
概
基于 MPNN 架构的 GNN 主要在于 aggregator 和 update function 两部分, 一般来说后者是参数化的主要方式. 本文提出一种新的参数化 aggregator 的方法, 能够覆盖绝大部分经典的 aggregators.
符号说明
- \(\mathcal{X} = \{x_1, x_2, \ldots, x_n\}\), 一批结点的 embedding, \(|\mathcal{X}| = n\), node embedding, \(x_i \in \mathbb{R}^d\);
- \(\odot: \mathbb{R}^{n \times d}\), aggregation function.
GenAgg
-
作者提出这样的 generalized f-mean:
\[f^{-1}(\frac{1}{n} \sum_{i} f(x_i)), \]比如 \(f(x) = \frac{1}{x}\) 的时候, 有
\[ \odot(\mathcal{X}) = \frac{n}{\sum \frac{1}{x_i}} \]为 harmonic mean.
-
不过这种定义太强了, 有些常用的 aggregator (如 'sum') 没法满足, 所以本文首先提出一种 augmented f-mean:
\[f^{-1} \bigg( n^{\alpha - 1} \sum_{i} f(x_i - \beta \mu), \bigg) \]其中 \(\alpha, \beta\) 是可学习的参数, \(\mu\) 可以是均值 \(\mu = \frac{1}{n} \sum x_i\).
上表列出不同 \(\langle f, \alpha, \beta \rangle\) 下的 aggregation function. -
除了给定具体的 \(f\) 外, 我们也可以直接用神经网络去拟合 \(f\), 为了保证 \(f\) 的可逆性, 我们可以用 normalizing flows 的技术实现.
-
不过, 另一方面, 我们可以采用很简单的方式, 用两个不同的 MLP 来分别作为 \(f, f^{-1}\), 同时在训练的时候施加如下的约束:
\[ \mathcal{L}_{inv}(\theta_1, \theta_2) = \mathbb{E}\bigg[ \big( |f_{\theta_2}^{-1}(f_{\theta_1}(x))| - |x| \big)^2 \bigg]. \] -
注意, 一般来说, \(f: \mathbb{R}^1 \rightarrow \mathbb{R}^1\), 为了增加一些表达能力, 我们也可以用
\[ f: \mathbb{R}^{1} \rightarrow \mathbb{R}^d, \quad f^{-1}: \mathbb{R}^{d} \rightarrow \mathbb{R}^1. \]
代码
[official]