Stochastic Training of Graph Convolutional Networks with Variance Reduction

Chen J., Zhu J. and Song L. Stochastic training of graph convolutional networks with variance reduction. ICML, 2018.

我们都知道, GCN 虽然形式简单, 但是对于结点个数非常多的情形是不易操作的: 多层的卷积之后基本上每个结点的感受野都会变得非常大 (指数级上升), 这对导致 mini-batch 训练的思想在图的任务中是不那么普遍的. GraphSage 通过采样结点的方式缓解了这个问题, 但是不想以往的 mini-batch 那样有很好的收敛性保证. 本文主要解决的就是这两个问题.

符号说明

  • \(\mathcal{G} = (\mathcal{V}, \mathcal{E})\), 图;
  • \(V = |\mathcal{V}|, E = |\mathcal{E}|\);
  • \(A\), 邻接矩阵;
  • \(\tilde{A} = A + I\);
  • \(P = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\);
  • graph convolution layer:

    \[Z^{(l+1)} = P H^{(l)} W^{(l)}, \\ H^{(l+1)} = \sigma(Z^{l+1}). \]

Motivation

  • 我们首先将任务场景限定在 node-level 的问题上, 假设训练集 \(\mathcal{V}_l \subset \mathcal{V}\), 我们希望根据这些结点以及他们的标签为其余的结点 \(\mathcal{V} \setminus \mathcal{V}_l\) 打标签;

  • 假设我们通过如下损失进行训练:

    \[\mathcal{L} = \frac{1}{\mathcal{V}_l} \sum_{v \in \mathcal{V}_l} f(y_v, z_v^{(L)}), \]

    则每次训练都要计算如下的梯度:

    \[\nabla \mathcal{L} = \frac{1}{\mathcal{V}_l} \sum_{v \in \mathcal{V}_l} \nabla f (y_v, z_v^{(L)}), \]

    作者认为这一步的计算是非常大的.

  • 一些方法, 比如 GraphSage 采用采样邻居的方式来模拟 mini-batch 的训练方式, 它相当于:

    1. 采样部分结点;
    2. 利用这些结点构建新的邻接矩阵 \(\hat{P}\);
    3. 然后通过如下方式进行更新:

      \[Z^{(l+1)} = \hat{P}^{(l)} H^{(l)} W^{(l)}, \\ H^{(l+1)} = \sigma(Z^{(l+1)}), \]

      这就大大减轻了计算.
  • 但是这种方式, 由于非线性 \(\sigma\) 的存在, 无法保证 \(\hat{P}\) 会收敛到 \(\hat{P}\), 也因此整体的训练方式也缺乏严格的收敛保证.

本文方法

  • 作者维护历史变量 \(\bar{H}^{(l)}\), 然后通过如下方式更新:

    \[Z^{(l+1)} = (\hat{P}^{(l)}(H^{(l)} - \bar{H}^{l}) + P\bar{H}^{(l)}) W, \]

    由于 \(\bar{H}^{(l)}\) 本身是不带梯度的, 所以在反向计算梯度的时候:

    \[\begin{array}{ll} \mathrm{d} Z^{(l+1)} &= \mathrm{d} (\hat{P}^{(l)}(H^{(l)} - \bar{H}^{l}) + P\bar{H}^{(l)}) W \\ &= (\hat{P}^{(l)} \mathrm{d} H^{(l)}) W \\ & \quad + (\hat{P}^{(l)}(H^{(l)} - \bar{H}^{l}) + P\bar{H}^{(l)}) \mathrm{d} W, \end{array} \]

  • 说实话, 从训练的方式来看, 并没有减少计算量的感觉. 不过, 因为这种方式更加稳定, 所以方差比较小, 因此只需要为每个结点采样 2 个邻居就可以了, 这应该是计算量降低的主要原因.

代码

official

posted @ 2023-04-15 15:52  馒头and花卷  阅读(85)  评论(0编辑  收藏  举报