Neural Bellman-Ford Networks A General Graph Neural Network Framework for Link Prediction
概
一种通用的 inductive 的图推理方式, 可用于 homogeneous graph 或 multi-relational graphs (如 knowledge graph).
符号说明
- \(\mathcal{G} = (\mathcal{V}, \mathcal{E}, \mathcal{R})\), multi-relational graph;
- \(\mathcal{V}\), node (entity) set;
- \(\mathcal{R}\), edge (relation) type set;
- \(\mathcal{E}\), edge (relation) set, 其中的元素表示为 \(e \in \mathcal{E} = (u, r, v)\), \(u, v \in \mathcal{V}, r \in \mathcal{R}\);
- \(\mathcal{N}(u)\) 表示结点 (实体) \(u\) 的一阶邻居
- \(\bm{A}\), 邻接矩阵
Motivation
-
本文讨论的主要任务是, 判断两个实体 \((u, v)\) 是否存在某个关系 (query relation) \(q \in \mathcal{R}\), 通俗地将就是 link prediction 任务, 在知识图谱中便是常闻的知识图谱补全任务.
-
以往的方法主要涉及三种主要的技术路线: Path-based methods, Embedding methods, GNNs.
- 其中 Embedding 的方法如 DeepWalk, TransE, RotatE 等, 往往局限在 transductive 的常见下, 对于未知的结点就难以做出判断了.
- Path-based 的方法由于往往是基于图的拓扑结构信息的, 所以 inductive setting 下发挥的比较好. 不过有一个问题是, 显式地计算这些图的性质是非常耗时的.
- GNN 的方法, 对于不使用 node features (如 SEAL, GraIL) 的情况下, 就是 inductive 的, 虽然之前的方法的效率也并不太高.
-
Path Formulation: 之前的基于 Path 的方法, 往往是通过编码头实体 \(u\) 和尾实体 \(v\) 见的不同的 walks 的数量来得到最后的 pair representation \(\bm{h}_q(u, v)\), 本文将这些方法统一为如下的形式:
\[\tag{1} \bm{h}_q(u, v) = \bm{h}_q(P_1) \oplus \bm{h}_q(P_2) \oplus \cdots \oplus \bm{h}_q(P_{|\mathcal{P}_{uv}|}) \triangleq {\huge \oplus}_{P \in \mathcal{P}_{uv}} \bm{h}_q (P) \\ \bm{h}_q (P = (e_1, e_2, \ldots, e_{|P|})) = \bm{w}_q(e_1) \otimes \bm{w}_q (e_2) \otimes \cdots \otimes \bm{w}_q (e_{|P|}) \triangleq {\huge \otimes}_{i=1}^{|P|} \bm{w}_q (e_i). \]其中 \(\mathcal{P}_{uv}\) 表示 \(u \rightarrow v\) 的 paths 的集合, \(\bm{w}_q(e)\) 表示 edge \(e\) 的向量表示. \(\oplus\) 是可交换的 summation operator, \(\otimes\) 表示 muliplication operator (但是并不一定满足交换律), 不同的指标这两个符号会有具体的形式.
-
(1) 可以简写为:
\[ \bm{h}_q(u, v) = {\huge \oplus}_{P \in \mathcal{P_{uv}}} {\huge \otimes}_{i=1}^{|P|} \bm{w}_q(e_i). \]
- Katz index: \(\oplus = +, \otimes = \times\), 且 \(\bm{w}_q(e) = \beta w_e\).
proof:
Katz index 的定义为:
其中 \(\beta \in (0, 1)\) 是一个 attenuation factor, \(\bm{e}_u\) 为 one-hot 向量, 仅在 \(u\) 处为 1. 因为 \(\bm{e}_u^T \bm{A}^t \bm{e}_v\) 表示 \((u, v)\) 间长度为 \(t\) 的 path 的数目, 所以 Katz index 整体是统计 \((u, v)\) 间所有 path 的数目的指标.
根据条件我们有:
\(w_e \equiv 1\) 的时候就完全退化为了 Katz.
- Personalized PageRank: \(\oplus = +, \otimes = \times\), 且 \(\bm{w}_q (e) = \alpha w_{uv} / \sum_{v' \in \mathcal{N}(u)} w_{uv'}\).
proof:
PPR 的定义为:
其中 \(\bm{D}\) 表示 degree matrix.
根据条件我们有
- Graph distance: \(\oplus = \min, \otimes = +\), 且 \(\bm{w}_q(e) = w_e\).
proof:
Graph distance 定义为 \((u, v)\) 间的最短路径长度.
根据条件, 我们有
- Widest path: \(\oplus = \max, \otimes = \min\), 且 \(\bm{w}_q (e) = w_e\).
proof:
Widest path 定义为
即找到 \((u, v)\) 间最大化 minimum edge weight 的 path.
根据上面的条件, 我们有
- Most reliable path: \(\oplus = \max, \otimes = \times\), 且 \(\bm{w}_q (e) = w_e\).
Most reliable path 指的是 \((u, v)\) 间概率最大的 path:
根据上面的条件, 我们有
NBFNet
-
尽管公式 (1) 能够统一多种 path 的指标, 但是这些指标的计算复杂度都是很高的 (path 的数量随着长度增加而指数级地膨胀). 故而作者建议采取一种更加一般化的 Bellman-Ford algorithm:
\[\tag{2} \bm{h}_q^{(0)} (u, v) \leftarrow \mathbb{I}_q(u = v), \\ \bm{h}_q^{(t)} (u, v) \leftarrow {\huge(} {\huge \oplus}_{(u', r, v) \in \mathcal{E}(v)} \bm{h}_q^{t-1}(u, u') \otimes \bm{w}_q (u', r, v) {\huge)} \oplus \bm{h}_q^{(0)} (u, v). \]\(\bm{w}_q(u, r, v)\) 为 edge \(e=(u, r, v)\) 的表示,
此外 \(\mathbb{I}_q(u = v)\) 为一 indicator function:\[\mathbb{I}_q(u=v) = \left\{ \begin{array}{ll} \textcircled{{\small 1}}_q & u=v, \\ \textcircled{{\small 0}}_q & u\not= v. \end{array} \right., \]注意到 \(\textcircled{{\small 1}}_q, \textcircled{{\small 0}}_q\) 分别是 summation identity 和 multiplication identity. 如果整个系统满足半环的性质, 则上述的一些 path 的指标都可以通过 (2) 来解决 (证明请回看原文).
-
如果我们放松半环的要求, 我们可以得到一个一般信息传播形式:
\[\tag{3} \bm{h}_q^{(0)} (u, v) \leftarrow \text{INDICATOR}(u, v, q), \\ \bm{h}_q^{(t)} (u, v) \leftarrow \text{AGGREGATE} {\huge(\{} \text{MESSAGE} {\Big(} \bm{h}_x^{(t-1)}, \bm{w}_q(u', r, v) {\Big)} {\Big|} (u', r, v) \in \mathcal{E}(v) {\huge\}} \cup {\huge\{} \bm{h}_v^{(0)} {\huge\}} {\huge)}. \]对比 (2) 和 (3) 可以发现:
\[\mathbb{I} \longrightarrow \text{INDICATOR}, \\ {\huge \oplus} \longrightarrow \text{AGGREGATE}, \\ {\otimes} \longrightarrow \text{MESSAGE}. \]其实总体来说, 和一般的 MPNN 的结构还是很像的, 只是:
1. 特殊的 Indicator function, 而不是结点的特征;
2. 最后得到的是 \((u, v)\) 在 query relation \(q\) 下的 pair representation, 而不是某个结点的表示.
其实 (3) 可以这样理解, \(\bm{h}^{(0)}\) 给了一个起始的条件 (通常是只有结点 \(u\) 处有非零的向量表示), 然后这些信息在传播过程中确定 query \(q\) 下的合适的 target \(v\). -
一些建议的设计方案:
- Message function: 可以设计成知识图谱中的 relational operators (比如 scaling, translation 等);
- Aggregatie function: 可以是普通的 summation, max, min, 作者建议在其后跟随 'a linear transformation' + 'a non-linear activation';
- Indicator function: 用 \(\mathbb{I}_q(u=v)\), 且令\[ \textcircled{{\small 1}}_q := \bm{q} \quad \text{for each } q \in \mathcal{R}, \]为可训练的向量, 而 \(\textcircled{{\small 0}}_q \equiv \bm{0}\) (作者发现这样的效果比较好).
- 关于 edge 的建模方式, 作者建议如果关系数目 \(|\mathcal{R}|\) 比较多, 可以设定为\[ \bm{w}_q(u, r, v) = \bm{W}_r \bm{q} + \bm{b}_r, \]如果很少, 为了避免过拟合, 可以\[ \bm{w}_q(u, r, v) = \bm{b}_r. \]
-
最后的 link prediction, 作者通过如下方式计算概率
\[ p(v|u, q) = \text{Sigmoid}\Big( f\big( \bm{h}_q(u, v) \big) \Big). \]如果是 undirected graph (对于 homogeneous graph) 可以
\[ p(v|u, q) = \text{Sigmoid}\Big( f\big( \bm{h}_q(u, v) + \bm{h}_q(v, u) \big) \Big). \] -
最后通过 BCE 进行训练.
代码
[official]