论文解读(GCA)《Graph Contrastive Learning with Adaptive Augmentation》

论文信息

论文标题:Graph Contrastive Learning with Adaptive Augmentation
论文作者:Yanqiao Zhu、Yichen Xu3、Feng Yu4、Qiang Liu、Shu Wu、Liang Wang
论文来源:2021, WWW
论文地址:download
论文代码:download

摘要

  近年来,对比学习(CL)已成为一种成功的无监督图表示学习方法。大多数图CL方法首先对输入图进行随机增强,以获得两个图视图,并最大化两个视图中表示的一致性。尽管图CL方法得到了繁荣的发展,但图增强方案的设计——CL中的一个关键组成部分——仍然很少被探索。我们认为,数据增强方案应该保留图的内在结构和属性,这将迫使模型学习对不重要的节点和边缘的扰动不敏感的表示。然而,现有的方法大多采用统一的数据增强方案,如统一降边和统一变换特征,导致性能次优。在本文中,我们提出了一种新的具有自适应增强的图对比表示学习方法,该方法包含了图的拓扑和语义方面的各种先验。具体来说,在拓扑层面上,我们设计了基于节点中心性度量的增强方案来突出重要的连接结构。在节点属性级别上,我们通过向不重要的节点特征添加更多的噪声来破坏节点特征,以强制模型识别底层的语义信息。我们在各种真实世界的数据集上进行了广泛的节点分类实验。实验结果表明,我们提出的方法始终优于现有的先进基线,甚至超过一些监督的方法,这验证了所提出的对比框架自适应增强的有效性。

1-介绍

  出发角度:倾向于保持重要的结构和属性不变,同时干扰可能不重要的边连接和特征。

  自适应数据增强方面:

    • 拓扑结构:基于节点中心性度量,突出重要连接;
    • 语义信息:对不重要的节点属性添加噪声;

2 Method

2.1 Framework

  框架如下:

    

 

  形如 GRACE 架构。

  算法流程:

   

  编码器

    $\begin{aligned}\mathrm{GC}_{i}(\boldsymbol{X}, \boldsymbol{A}) &=\sigma\left(\hat{D}^{-\frac{1}{2}} \hat{\boldsymbol{A}} \hat{D}^{-\frac{1}{2}} \boldsymbol{X} \boldsymbol{W}_{i}\right)\quad\quad\quad(12) \\f(\boldsymbol{X}, \boldsymbol{A}) &=\mathrm{GC}_{2}\left(\mathrm{GC}_{1}(\boldsymbol{X}, \boldsymbol{A}), \boldsymbol{A}\right)\quad\quad\quad(13)\end{aligned}$

  损失函数

    $\mathcal{J}=\frac{1}{2 N} \sum\limits _{i=1}^{N}\left[\ell\left(\boldsymbol{u}_{i}, v_{i}\right)+\ell\left(v_{i}, \boldsymbol{u}_{i}\right)\right]\quad\quad\quad(2)$

  其中:

    $log {\large \frac{e^{\theta\left(u_{i}, v_{i}\right) / \tau}}{\underbrace{e^{\theta\left(\boldsymbol{u}_{i}, \boldsymbol{v}_{i}\right) / \tau}}_{\text {positive pair }}+\underbrace{\sum_{k \neq i} e^{\theta\left(\boldsymbol{u}_{i}, \boldsymbol{v}_{k}\right) / \tau}}_{\text {inter-view negative pairs }}+\underbrace{\sum_{k \neq i} e^{\theta\left(\boldsymbol{u}_{i}, \boldsymbol{u}_{k}\right) / \tau}}_{\text {intra-view negative pairs }}}}\quad\quad\quad(1) $

2.2 Adaptive Graph Augmentation

2.2.1 Topology-level augmentation

  利用 $\text{Eq.3}$ 中的概率从原始边集合中采样一个边子集合

    $P\{(u, v) \in \widetilde{\mathcal{E}}\}=1-p_{u v}^{e}\quad\quad\quad(3)$

  其中:

    • $(u, v) \in \mathcal{E}$;
    • $p_{u v}^{e}$ 是删除边 $ (u, v)$ 的概率;
    • $\widetilde{\mathcal{E}}$ 将作为生成视图的边集合;

  分析知: $p_{u v}^{e}$ 应该反映边 $ (u, v)$ 的重要性,目的是大概率删除不重要的边,同时保留增强视图中重要的边。

  节点中心性量化了节点的重要性,本文为边 $(u, v)$ 定义边中心性 $w_{u v}^{e}$,用于衡量边$(u, v)$ 对两个相连节点的影响。给定节点中心性度量 $\varphi_{c}(\cdot): \mathcal{V} \rightarrow \mathbb{R}^{+}$,将边中心性定义为两个相邻节点中心性得分的均值,即 $w_{u v}^{e}=\left(\varphi_{c}(u)+\varphi_{c}(v)\right) / 2$。在有向图上,只使用尾部节点的中心性,即 $w_{u v}^{e}=\varphi_{c}(v) $,因为边的重要性通常是它们指向的节点。

  接下来,根据每条边的中心性值来计算它的概率。由于采用度作为节点中心性这种度量方法在不同数量级上变化差别过大,所以本文首先设置 $s_{u v}^{e}=\log w_{u v}^{e}$ 以缓解具有高度密集连接的节点的影响。然后通过将边中心性的值转换为概率:

    $p_{u v}^{e}=\underset{}{\text{min}}   \left(\frac{s_{\max }^{e}-s_{u v}^{e}}{s_{\max }^{e}-\mu_{s}^{e}} \cdot p_{e}, \quad p_{\tau}\right)\quad\quad\quad(4)$

  其中,$p_{e}$ 是一个控制去除边的总体概率的超参数,$s_{\max }^{e}$ 和 $\mu_{s}^{e}$ 是 $s_{u v}^{e}$ 的最大值和平均值。而 $p_{\tau}<1$ 是一个临界概率(cut-off probability),对于边中心性低(高概率删除)的边,采用 $p_{\tau}$ 删除,对于边中心性低的边(低概率删除),采用 ${\large \frac{s_{\max }^{e}-s_{u v}^{e}}{s_{\max }^{e}-\mu_{s}^{e}} \cdot p_{e}} $  删除。

  这里提供三种 节点中心性度量 方法:

  1、点度中心性(Degree centrality):节点度本身可以是一个中心性度量。在有向网络上,使用内度,因为有向图中的一个节点的影响主要是由指向它的节点赋予的。  

  2、特征向量中心性(Eigenvector centrality):基本思想是一个节点的中心性是相邻节点中心性的函数。也就是说,与你连接的人越重要,你也就越重要。

  3、PageRank中心性(PageRank centrality):基于有向图

  对于 PageRank 中心性分数计算公式如下:

    $\sigma=\alpha A D^{-1} \sigma+1\quad\quad\quad(5)$

  其中,$\sigma \in \mathbb{R}^{N}$ 是每个节点的 PageRank中心性得分的向量,$\alpha$ 是一个阻尼因子,它可以防止图中的 sinks 从连接到它们的节点中吸收所有 ranks。这里设置$\alpha=0.85$。对于无向图,我们对转换后的有向图执行PageRank,其中每条无向边都被转换为两条有向边。

  例子:

  

  从图中可以看出,三种方案存在细微差别,但都强调了连接两个教练(橙色节点)的边,而较少关注边缘节点。

2.2.2 Node-attribute-level augmentation

  节点特征隐藏:

    $\widetilde{\boldsymbol{X}}=\left[x_{1} \circ \tilde{\boldsymbol{m}} ; \boldsymbol{x}_{2} \circ \tilde{\boldsymbol{m}} ; \cdots ; \boldsymbol{x}_{N} \circ \widetilde{\boldsymbol{m}}\right]^{\top}$

  其中:$\widetilde{m}_{i} \sim \operatorname{Bern}\left(1-p_{i}^{f}\right)$,即用 $1-p_{i}^{f}$ 的概率取 $1$,用 $p_{i}^{f}$ 的概率取 $0$ ;

  这里 $p_{i}^{f}$ 应该反映出节点特征的第 $i$ 个维数的重要性。我们假设经常出现在有影响的节点中的特征维度应该是重要的,并定义特征维度的权重如下。

  对于稀疏的 one-hot 节点特征,即 $x_{u i} \in\{0,1\}$,对于任何节点 $u$ 和特征维 $i$,我们计算维度 $i$ 的权重为

    $w_{i}^{f}=\sum\limits _{u \in \mathcal{V}} x_{u i} \cdot \varphi_{c}(u)\quad\quad\quad(7)$

  其中,

    • $\varphi_{c}(\cdot)$ 是一个用于量化节点重要性的节点中心性度量;
    • 第一项 $x_{u i} \in\{0,1\}$ 表示节点 $u$ 中维度 $i $ 的出现;
    • 第二项 $\varphi_{i}(u)$ 表示每次出现的节点重要性;

  对于稠密、连续的节点特征 $\boldsymbol{x}_{u}$,本文用绝对值  $\left|x_{u i}\right|$  来测量节点  $u$  的  $i$  维的特征的值的大小:

    $w_{i}^{f}=\sum\limits _{u \in \mathcal{V}}\left|x_{u i}\right| \cdot \varphi_{c}(u)\quad\quad\quad(8)$

  与 Topology-level augmentation 类似,我们对权值进行归一化,以获得表示特征重要性的概率。形式上:

    ${\large p_{i}^{f}=\min \left(\frac{s_{\max }^{f}-s_{i}^{f}}{s_{\max }-\mu_{s}^{f}} \cdot p_{f}, p_{\tau}\right)} \quad\quad\quad(9)$

  其中,$s_{i}^{f}=\log w_{i}^{f}$,$s_{\max }^{f}$  和 $\mu_{s}^{f}$ 分别为 $ s_{i}^{f}$ 的最大值和平均值, $p_{f}$ 是控制特征增强的总体幅度的超参数。 

3 Experiment

数据集

  

  【 Wiki-CSAmazon-ComputersAmazon-PhotoCoauthor-CSCoauthor-Physics  】

节点分类

  基线实验:

     

消融实验

   

灵敏度分析:

    

4 Conclusion

  开发了一种自适应数据增强对比学习框架。 

修改历史

2021-04-12 创建文章
2022-06-14 精读

编码器

from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Encoder(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, activation, base_model=GCNConv, k: int = 2, skip=False):
        super(Encoder, self).__init__()
        self.base_model = base_model
        assert k >= 2
        self.k = k
        self.skip = skip
        if not self.skip:
            self.conv = [base_model(in_channels, 2 * out_channels).jittable()]
            for _ in range(1, k - 1):
                self.conv.append(base_model(2 * out_channels, 2 * out_channels))
            self.conv.append(base_model(2 * out_channels, out_channels))
            self.conv = nn.ModuleList(self.conv)
            self.activation = activation
        else:
            self.fc_skip = nn.Linear(in_channels, out_channels)
            self.conv = [base_model(in_channels, out_channels)]
            for _ in range(1, k):
                self.conv.append(base_model(out_channels, out_channels))
            self.conv = nn.ModuleList(self.conv)

            self.activation = activation

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        if not self.skip:
            for i in range(self.k):
                x = self.activation(self.conv[i](x, edge_index))
            return x
        else:
            h = self.activation(self.conv[0](x, edge_index))
            hs = [self.fc_skip(x), h]
            for i in range(1, self.k):
                u = sum(hs)
                hs.append(self.activation(self.conv[i](u, edge_index)))
            return hs[-1]

节点度重要性

from torch_geometric.utils import degree, to_undirected
def degree_drop_weights(edge_index):
    edge_index_ = to_undirected(edge_index)
    deg = degree(edge_index_[1])    # 转换成无向图之后,计算节点的度
    deg_col = deg[edge_index[1]].to(torch.float32)   # 计算节点的入度,入度越大,则节点越重要
    s_col = torch.log(deg_col)
    weights = (s_col.max() - s_col) / (s_col.max() - s_col.mean())
    return weights

PageRank  重要性

from torch_scatter import scatter
def compute_pr(edge_index, damp: float = 0.85, k: int = 10):
    num_nodes = edge_index.max().item() + 1
    deg_out = degree(edge_index[0])
    x = torch.ones((num_nodes, )).to(edge_index.device).to(torch.float32)
    for i in range(k):
        edge_msg = x[edge_index[0]] / deg_out[edge_index[0]]
        agg_msg = scatter(edge_msg, edge_index[1], reduce='sum')
        x = (1 - damp) * x + damp * agg_msg
    return x

def pr_drop_weights(edge_index, aggr: str = 'sink', k: int = 10):
    pv = compute_pr(edge_index, k=k)
    pv_row = pv[edge_index[0]].to(torch.float32)
    pv_col = pv[edge_index[1]].to(torch.float32)
    s_row = torch.log(pv_row)
    s_col = torch.log(pv_col)
    if aggr == 'sink':
        s = s_col
    elif aggr == 'source':
        s = s_row
    elif aggr == 'mean':
        s = (s_col + s_row) * 0.5
    else:
        s = s_col
    weights = (s.max() - s) / (s.max() - s.mean())
    return weights

pr_drop_weights(data.edge_index, aggr='sink', k=200).to(device)

特征向量 重要性

from torch_geometric.utils import to_networkx
import networkx as nx
def eigenvector_centrality(data):
    graph = to_networkx(data)
    x = nx.eigenvector_centrality_numpy(graph)
    x = [x[i] for i in range(data.num_nodes)]
    return torch.tensor(x, dtype=torch.float32).to(data.edge_index.device)
def evc_drop_weights(data):
    evc = eigenvector_centrality(data)
    evc = evc.where(evc > 0, torch.zeros_like(evc))
    evc = evc + 1e-8
    s = evc.log()
    edge_index = data.edge_index
    s_row, s_col = s[edge_index[0]], s[edge_index[1]]
    s = s_col
    return (s.max() - s) / (s.max() - s.mean())

drop_weights = evc_drop_weights(data).to(device)

边丢弃自适应数据增强

def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.):
    edge_weights = edge_weights / edge_weights.mean() * p
    edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold)
    sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool)
    return edge_index[:, sel_mask]节点

节点丢弃自适应数据增强

def drop_feature_weighted(x, w, p: float, threshold: float = 0.7):
    w = w / w.mean() * p
    w = w.where(w < threshold, torch.ones_like(w) * threshold)
    drop_prob = w.repeat(x.size(0)).view(x.size(0), -1)
    drop_mask = torch.bernoulli(drop_prob).to(torch.bool)
    x = x.clone()
    x[drop_mask] = 0.
    return x

特征丢弃自适应数据增强

def feature_drop_weights(x, node_c):
    x = x.to(torch.bool).to(torch.float32)
    w = x.t() @ node_c
    w = w.log()
    s = (w.max() - w) / (w.max() - w.mean())
    return s

def feature_drop_weights_dense(x, node_c):
    x = x.abs()
    w = x.t() @ node_c
    w = w.log()
    s = (w.max() - w) / (w.max() - w.mean())
    return s
def drop_feature_weighted_2(x, w, p: float, threshold: float = 0.7):
    w = w / w.mean() * p
    w = w.where(w < threshold, torch.ones_like(w) * threshold)
    drop_prob = w
    drop_mask = torch.bernoulli(drop_prob).to(torch.bool)
    x = x.clone()
    x[:, drop_mask] = 0.
    return x

  

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

  

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 


论文解读目录

posted @ 2022-04-12 19:47  多发Paper哈  阅读(750)  评论(0编辑  收藏  举报
Live2D