Cluster-GCN An Efficient Algorithm for Training Deep Convolution Networks
概
以往的 GraphSage, FastGCN 等方法, 虽然能够实现 mini-batch 的训练, 但是他们所采样的方式效率是很低: 所采样的点之间往往可能具有很少的边, 导致整体的结果非常稀疏. 本文提出了一种高效的采样方式, 首先将所有的点聚类, 再采样.
符号说明
- \(G = (\mathcal{V, E}, A)\), 图;
- \(|\mathcal{V}| = N\);
- \(X \in \mathbb{R}^{N \times F}\), 特征矩阵;
- GCN 的每一层可以表述为:\[Z^{(l+1)} = A' X^{(l)} W^{(l)}, \: X^{(l+1)} = \sigma(Z^{(l+1)}), \]其中 \(A'\) 是 normalized 邻接矩阵.
- 最后的损失可以表述为\[\tag{1} \mathcal{L} = \frac{1}{|\mathcal{Y}_L|} \sum_{i \in \mathcal{Y}_L} \text{loss}(y_i, z_i^L), \]其中 \(\mathcal{Y}_L\) 表示所有打了标签的结点的集合.
Motivation
-
(1) 是一个整体的在所有的打过标签的结点上的损失, 这在应对特别大规模的数据的时候就很麻烦了, 所以我们更希望的是采用 mini-batch 的方式:
\[\mathcal{L} = \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \text{loss}(y_i, z_i^L). \] -
但是有一个问题, 就是图的聚合操作, 需要用到所选的结点邻居, 所以, 即使 \(|\mathcal{B}|\) 本身很小, 为了精准地计算 \(z^L\), 所需的结点也是很多的 (随着层数指数增长). 所以我们只能采样一批点, 然后在较小的邻接矩阵 \(\hat{A}\) 上做聚合操作.
-
倘若我们采用随机采样的方式, 就容易导致采样的点之间的 edges 很少 (因为我们很难保证恰好采样到那些关系比较紧密的结点). 假设采样的点为 \(\mathcal{B}\), 实际上就是 \(\|A_{\mathcal{B, B}}\|_0\) 很小, 这会使得训练效率异常低下.
Cluster-GCN
-
本文的思想很简单, 希望通过聚类, 先将结点切分为多个紧密联系的群体 (通过聚类算法 METIS):
\[[\mathcal{V}_1, \cdots, \mathcal{V}_c], \]则我们同样得到 \(c\) 个子图:
\[[\{\mathcal{V}_1, \mathcal{E}_1, A_{11}\}, \cdots, \{\mathcal{V}_c, \mathcal{E}_c, A_{cc}\}]. \] -
于是乎, 在实际上训练的时候, 我们可以直接选择某个子图 \(G_i\) 作为一个 batch 用于训练.
-
这种做法, 实际上相当于用
\[\bar{A} = \left[ \begin{array}{ccc} A_{11} & \cdots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \cdots & A_{cc} \\ \end{array} \right] \]去逼近 \(A\), 由于我们舍去了很多 Links, 必然会导致性能的下降.
-
故本文在实际中, 会选择一个 (比预想) 较大的 \(c\), 然后每次采样的时候, 从中选择 \(q\) 个 clusters 作为一个 batch.