Neural Bellman-Ford Networks A General Graph Neural Network Framework for Link Prediction

Zhu Z., Zhang Z., Xhonneux L. and Tang J. Neural Bellman-Ford networks: A general graph neural network framework for link prediction. NIPS, 2021.

一种通用的 inductive 的图推理方式, 可用于 homogeneous graph 或 multi-relational graphs (如 knowledge graph).

符号说明

  • \(\mathcal{G} = (\mathcal{V}, \mathcal{E}, \mathcal{R})\), multi-relational graph;
  • \(\mathcal{V}\), node (entity) set;
  • \(\mathcal{R}\), edge (relation) type set;
  • \(\mathcal{E}\), edge (relation) set, 其中的元素表示为 \(e \in \mathcal{E} = (u, r, v)\), \(u, v \in \mathcal{V}, r \in \mathcal{R}\);
  • \(\mathcal{N}(u)\) 表示结点 (实体) \(u\) 的一阶邻居
  • \(\bm{A}\), 邻接矩阵

Motivation

  • 本文讨论的主要任务是, 判断两个实体 \((u, v)\) 是否存在某个关系 (query relation) \(q \in \mathcal{R}\), 通俗地将就是 link prediction 任务, 在知识图谱中便是常闻的知识图谱补全任务.

  • 以往的方法主要涉及三种主要的技术路线: Path-based methods, Embedding methods, GNNs.

    1. 其中 Embedding 的方法如 DeepWalk, TransE, RotatE 等, 往往局限在 transductive 的常见下, 对于未知的结点就难以做出判断了.
    2. Path-based 的方法由于往往是基于图的拓扑结构信息的, 所以 inductive setting 下发挥的比较好. 不过有一个问题是, 显式地计算这些图的性质是非常耗时的.
    3. GNN 的方法, 对于不使用 node features (如 SEAL, GraIL) 的情况下, 就是 inductive 的, 虽然之前的方法的效率也并不太高.
  • Path Formulation: 之前的基于 Path 的方法, 往往是通过编码头实体 \(u\) 和尾实体 \(v\) 见的不同的 walks 的数量来得到最后的 pair representation \(\bm{h}_q(u, v)\), 本文将这些方法统一为如下的形式:

    \[\tag{1} \bm{h}_q(u, v) = \bm{h}_q(P_1) \oplus \bm{h}_q(P_2) \oplus \cdots \oplus \bm{h}_q(P_{|\mathcal{P}_{uv}|}) \triangleq {\huge \oplus}_{P \in \mathcal{P}_{uv}} \bm{h}_q (P) \\ \bm{h}_q (P = (e_1, e_2, \ldots, e_{|P|})) = \bm{w}_q(e_1) \otimes \bm{w}_q (e_2) \otimes \cdots \otimes \bm{w}_q (e_{|P|}) \triangleq {\huge \otimes}_{i=1}^{|P|} \bm{w}_q (e_i). \]

    其中 \(\mathcal{P}_{uv}\) 表示 \(u \rightarrow v\) 的 paths 的集合, \(\bm{w}_q(e)\) 表示 edge \(e\) 的向量表示. \(\oplus\) 是可交换的 summation operator, \(\otimes\) 表示 muliplication operator (但是并不一定满足交换律), 不同的指标这两个符号会有具体的形式.

  • (1) 可以简写为:

    \[ \bm{h}_q(u, v) = {\huge \oplus}_{P \in \mathcal{P_{uv}}} {\huge \otimes}_{i=1}^{|P|} \bm{w}_q(e_i). \]


  • Katz index: \(\oplus = +, \otimes = \times\), 且 \(\bm{w}_q(e) = \beta w_e\).

proof:

Katz index 的定义为:

\[\text{Katz}(u, v) = \sum_{t=1}^{\infty} \beta^t \bm{e}_u^T \bm{A}^t \bm{e}_v, \]

其中 \(\beta \in (0, 1)\) 是一个 attenuation factor, \(\bm{e}_u\) 为 one-hot 向量, 仅在 \(u\) 处为 1. 因为 \(\bm{e}_u^T \bm{A}^t \bm{e}_v\) 表示 \((u, v)\) 间长度为 \(t\) 的 path 的数目, 所以 Katz index 整体是统计 \((u, v)\) 间所有 path 的数目的指标.

根据条件我们有:

\[\begin{array}{ll} \bm{h}_q(u, v) &= \sum_{P \in \mathcal{P}_{uv}} \prod_{e \in P} \beta w_e \\ &= \sum_{t=1}^{\infty }\sum_{P \in \mathcal{P}_{uv}, |P|= t} \prod_{e \in P} \beta w_e \\ &= \sum_{t=1}^{\infty } \beta^t \sum_{P \in \mathcal{P}_{uv}, |P|= t} \prod_{e \in P} w_e, \end{array} \]

\(w_e \equiv 1\) 的时候就完全退化为了 Katz.


  • Personalized PageRank: \(\oplus = +, \otimes = \times\), 且 \(\bm{w}_q (e) = \alpha w_{uv} / \sum_{v' \in \mathcal{N}(u)} w_{uv'}\).

proof:

PPR 的定义为:

\[ \text{PPR}(u, v) = (1 - \alpha) \sum_{t=1}^{\infty} \alpha^t \bm{e}_u^T (\bm{D}^{-1} \bm{A})^t \bm{e}_v, \]

其中 \(\bm{D}\) 表示 degree matrix.

根据条件我们有

\[\begin{array}{ll} \bm{h}_q(u, v) &= \sum_{P \in \mathcal{P}_{uv}} \prod_{(a, b) \in P} \frac{\alpha w_{ab}}{\sum_{b' \in \mathcal{N}(a)} w_{ab'}} \\ &= \sum_{t=1}^{\infty } \sum_{P \in \mathcal{P}_{uv}, |P|=t} \prod_{(a, b) \in P} \frac{\alpha w_{ab}}{\sum_{b' \in \mathcal{N}(a)} w_{ab'}} \\ &= \sum_{t=1}^{\infty } \alpha^t \sum_{P \in \mathcal{P}_{uv}, |P|=t} \prod_{(a, b) \in P} \frac{w_{ab}}{\sum_{b' \in \mathcal{N}(a)} w_{ab'}} \\ &\propto \text{PPR}(u, v). \end{array} \]


  • Graph distance: \(\oplus = \min, \otimes = +\), 且 \(\bm{w}_q(e) = w_e\).

proof:

Graph distance 定义为 \((u, v)\) 间的最短路径长度.

根据条件, 我们有

\[\begin{array}{ll} \bm{h}_q(u, v) &= \min_{P \in \mathcal{P}_{uv}} \sum_{e \in P} w_{e} = \text{GD}(u, v). \end{array} \]


  • Widest path: \(\oplus = \max, \otimes = \min\), 且 \(\bm{w}_q (e) = w_e\).

proof:

Widest path 定义为

\[ \text{WP}(u, v) = \max_{P \in \mathcal{P}_{uv}} \min_{e \in P} w_e, \]

即找到 \((u, v)\) 间最大化 minimum edge weight 的 path.

根据上面的条件, 我们有

\[\begin{array}{ll} \bm{h}_q(u, v) &= \max_{P \in \mathcal{P}_{uv}} \min_{e \in P} w_{e} = \text{WP}(u, v). \end{array} \]


  • Most reliable path: \(\oplus = \max, \otimes = \times\), 且 \(\bm{w}_q (e) = w_e\).

Most reliable path 指的是 \((u, v)\) 间概率最大的 path:

\[\text{MRP}(u, v) = \max_{P \in \mathcal{P}_{uv}} \prod_{e \in P} w_e. \]

根据上面的条件, 我们有

\[\bm{h}_q(u, v) = \max_{P \in \mathcal{P}_{uv}} \prod_{e \in P} w_{e} = \text{MRP}(u, v). \]


NBFNet

  • 尽管公式 (1) 能够统一多种 path 的指标, 但是这些指标的计算复杂度都是很高的 (path 的数量随着长度增加而指数级地膨胀). 故而作者建议采取一种更加一般化的 Bellman-Ford algorithm:

    \[\tag{2} \bm{h}_q^{(0)} (u, v) \leftarrow \mathbb{I}_q(u = v), \\ \bm{h}_q^{(t)} (u, v) \leftarrow {\huge(} {\huge \oplus}_{(u', r, v) \in \mathcal{E}(v)} \bm{h}_q^{t-1}(u, u') \otimes \bm{w}_q (u', r, v) {\huge)} \oplus \bm{h}_q^{(0)} (u, v). \]

    \(\bm{w}_q(u, r, v)\) 为 edge \(e=(u, r, v)\) 的表示,
    此外 \(\mathbb{I}_q(u = v)\) 为一 indicator function:

    \[\mathbb{I}_q(u=v) = \left\{ \begin{array}{ll} \textcircled{{\small 1}}_q & u=v, \\ \textcircled{{\small 0}}_q & u\not= v. \end{array} \right., \]

    注意到 \(\textcircled{{\small 1}}_q, \textcircled{{\small 0}}_q\) 分别是 summation identitymultiplication identity. 如果整个系统满足半环的性质, 则上述的一些 path 的指标都可以通过 (2) 来解决 (证明请回看原文).

  • 如果我们放松半环的要求, 我们可以得到一个一般信息传播形式:

    \[\tag{3} \bm{h}_q^{(0)} (u, v) \leftarrow \text{INDICATOR}(u, v, q), \\ \bm{h}_q^{(t)} (u, v) \leftarrow \text{AGGREGATE} {\huge(\{} \text{MESSAGE} {\Big(} \bm{h}_x^{(t-1)}, \bm{w}_q(u', r, v) {\Big)} {\Big|} (u', r, v) \in \mathcal{E}(v) {\huge\}} \cup {\huge\{} \bm{h}_v^{(0)} {\huge\}} {\huge)}. \]

    对比 (2) 和 (3) 可以发现:

    \[\mathbb{I} \longrightarrow \text{INDICATOR}, \\ {\huge \oplus} \longrightarrow \text{AGGREGATE}, \\ {\otimes} \longrightarrow \text{MESSAGE}. \]

    其实总体来说, 和一般的 MPNN 的结构还是很像的, 只是:
    1. 特殊的 Indicator function, 而不是结点的特征;
    2. 最后得到的是 \((u, v)\) 在 query relation \(q\) 下的 pair representation, 而不是某个结点的表示.
    其实 (3) 可以这样理解, \(\bm{h}^{(0)}\) 给了一个起始的条件 (通常是只有结点 \(u\) 处有非零的向量表示), 然后这些信息在传播过程中确定 query \(q\) 下的合适的 target \(v\).

  • 一些建议的设计方案:

    1. Message function: 可以设计成知识图谱中的 relational operators (比如 scaling, translation 等);
    2. Aggregatie function: 可以是普通的 summation, max, min, 作者建议在其后跟随 'a linear transformation' + 'a non-linear activation';
    3. Indicator function: 用 \(\mathbb{I}_q(u=v)\), 且令

      \[ \textcircled{{\small 1}}_q := \bm{q} \quad \text{for each } q \in \mathcal{R}, \]

      为可训练的向量, 而 \(\textcircled{{\small 0}}_q \equiv \bm{0}\) (作者发现这样的效果比较好).
    4. 关于 edge 的建模方式, 作者建议如果关系数目 \(|\mathcal{R}|\) 比较多, 可以设定为

      \[ \bm{w}_q(u, r, v) = \bm{W}_r \bm{q} + \bm{b}_r, \]

      如果很少, 为了避免过拟合, 可以

      \[ \bm{w}_q(u, r, v) = \bm{b}_r. \]

  • 最后的 link prediction, 作者通过如下方式计算概率

    \[ p(v|u, q) = \text{Sigmoid}\Big( f\big( \bm{h}_q(u, v) \big) \Big). \]

    如果是 undirected graph (对于 homogeneous graph) 可以

    \[ p(v|u, q) = \text{Sigmoid}\Big( f\big( \bm{h}_q(u, v) + \bm{h}_q(v, u) \big) \Big). \]

  • 最后通过 BCE 进行训练.

代码

[official]

[official-PyG]

posted @ 2024-03-02 11:01  馒头and花卷  阅读(77)  评论(3编辑  收藏  举报