论文信息

论文标题:Self-Attention Graph Pooling
论文作者:Junhyun Lee, Inyeop Lee, Jaewoo Kang
论文来源:2019, ICML
论文地址:download 
论文代码:download

1 Preamble

   对图使用下采样 downsampling (pooling)。

2 Introduction

  图池化三种类型:

    • Topology based pooling;
    • Global pooling;
    • Hierarchical pooling;

  关于 Hierarchical pooling 聚类分配矩阵:

    S(l)=softmax(GNNl(A(l),X(l)))A(l+1)=S(l)A(l)S(l)(1)

  其中,S(l)Rnl×nl+1nl 代表 第 l 层的节点数量。

  gPool 取得了与 DiffPool 相当的性能,gPool 的存储复杂度为 O(|V|+|E|),而 DiffPool 需要 O(k|V|2),其中 VEk 分别表示顶点、边和池化率。gPool 使用一个可学习的向量 p 来计算投影分数,然后使用这些分数来选择排名靠前的节点。投影得分由 p 与所有节点的特征之间的点积得到。这些分数表示可以保留的节点的信息量。下面的公式大致描述了 gPool 中的池化过程:

    y=X(l)p(l)/p(l)idx=toprank(y,kN)A(l+1)=Aidx,idx(l)(2)

3 Method

  框架如下:

   

3.1 Self-Attention Graph Pooling

Self-attention mask

  本文使用图卷积来获得自注意分数:

    Z=σ(D~12A~D~12XΘatt)(3)

  其中:

    • 自注意得分 ZRN×1
    • 邻接矩阵 A~RN×N
    • 注意力参数矩阵 ΘattRF×1
    • 特征矩阵 XRN×F
    • 度矩阵 D~RN×N

  保留部分重要节点:

    idx=toprank(Z,kN)Zmask =Zidx(4)

  基于自注意得分 Z ,保留前 kN 个节点,其中 k(0,1] 代表着池化率,Zmask 是 feature attention mask。

Graph pooling

  获得新特征矩阵和邻接矩阵:

     X=Xidx,:Xout =XZmask Aout =Aidx,idx(5)

  其中,  哈达玛积。

Variation of SAGPool

  利用图特征矩阵 X 和拓扑结构 A ,计算注意力得分矩阵 Z 的通用形式:

    Z=σ(GNN(X,A))(6)

  比如  SAGPool augmentation ,加入二跳邻居信息:

    Z=σ(GNN(X,A+A2))(7)

  比如  SAGPool serial ,堆叠多层 GNN:

    Z=σ(GNN2(σ(GNN1(X,A)),A))(8)

  比如  SAGPool parallel ,平均多重注意力分数。M 个 GNN 的平均注意得分如下:

    Z=1Mmσ(GNNm(X,A))(9)

3.2 Model Architecture

  本节用来验证模块的有效性。

Convolution layer

  图卷积 GCN:

    h(l+1)=σ(D~12A~D~12h(l)Θ)(10)

  与 Eq.3 不同的是,ΘRF×F

Readout layer

  根据 JK-net architecture 的思想:

    s=1Ni=1Nximaxi=1Nxi(11)

  其中:

    • N 代表着节点的个数;
    • xi 代表着第 i 个节点的特征向量;

  代码:

x = F.relu(self.conv1(x, edge_index))
x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
x1 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

  即:平均池化和 最大池化进行拼接。

Global pooling architecture & Hierarchical pooling architecture

  对比如下:

  

  Model Code:

复制代码
class Net(torch.nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.args = args
        self.num_features = args.num_features
        self.nhid = args.nhid
        self.num_classes = args.num_classes
        self.pooling_ratio = args.pooling_ratio
        self.dropout_ratio = args.dropout_ratio

        self.conv1 = GCNConv(self.num_features, self.nhid)
        self.pool1 = SAGPool(self.nhid, ratio=self.pooling_ratio)
        self.conv2 = GCNConv(self.nhid, self.nhid)
        self.pool2 = SAGPool(self.nhid, ratio=self.pooling_ratio)
        self.conv3 = GCNConv(self.nhid, self.nhid)
        self.pool3 = SAGPool(self.nhid, ratio=self.pooling_ratio)

        self.lin1 = torch.nn.Linear(self.nhid * 2, self.nhid)
        self.lin2 = torch.nn.Linear(self.nhid, self.nhid // 2)
        self.lin3 = torch.nn.Linear(self.nhid // 2, self.num_classes)

    def forward(self, data):
        # 读取每个 batch 中的图数据
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # 第一次做 Self-Attention Graph Pooling=======
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        # 第一次 Readout layer
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # 第二次做 Self-Attention Graph Pooling=======
        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        # 第二次 Readout layer
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # 第三次做 Self-Attention Graph Pooling=======
        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        # 第三次 Readout layer
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # 跳跃连接
        x = x1 + x2 + x3

        # MLP
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        x = F.relu(self.lin2(x))
        x = F.log_softmax(self.lin3(x), dim=-1)
        return x
复制代码

  SAGPool Code:

复制代码
class SAGPool(torch.nn.Module):
    def __init__(self,in_channels,ratio=0.8,Conv=GCNConv,non_linearity=torch.tanh):
        super(SAGPool,self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.score_layer = Conv(in_channels,1)
        self.non_linearity = non_linearity
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        #x = x.unsqueeze(-1) if x.dim() == 1 else x
        score = self.score_layer(x,edge_index).squeeze()
        perm = topk(score, self.ratio, batch)
        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
复制代码

4 Experiments

数据集

  

基线实验

  

SAGPool 的变体

  

5 Conclusion

  本文提出了一种基于自注意的SAGPool图池化方法。我们的方法具有以下特征:分层池、同时考虑节点特征和图拓扑、合理的复杂度和端到端表示学习。SAGPool使用一致数量的参数,而不管输入图的大小如何。我们工作的扩展可能包括使用可学习的池化比率来获得每个图的最优聚类大小,并研究每个池化层中多个注意掩模的影响,其中最终的表示可以通过聚合不同的层次表示来获得。

修改历史

2022-05-08 创建文章
2022-08-12 修订文章

论文解读目录

posted @   别关注我了,私信我吧  阅读(1349)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· C#/.NET/.NET Core优秀项目和框架2025年2月简报
· DeepSeek在M芯片Mac上本地化部署
Live2D
点击右上角即可分享
微信分享提示