论文解读(SAGPool)《Self-Attention Graph Pooling》
论文信息
论文标题:Self-Attention Graph Pooling
论文作者:Junhyun Lee, Inyeop Lee, Jaewoo Kang
论文来源:2019, ICML
论文地址:download
论文代码:download
1 Preamble
对图使用下采样 downsampling (pooling)。
2 Introduction
图池化三种类型:
-
- Topology based pooling;
-
Global pooling;
- Hierarchical pooling;
关于 Hierarchical pooling 聚类分配矩阵:
$\begin{array}{j}S^{(l)}=\operatorname{softmax}\left(\mathrm{GNN}_{l}\left(A^{(l)}, X^{(l)}\right)\right) \\A^{(l+1)}=S^{(l) \top} A^{(l)} S^{(l)}\end{array} \quad\quad\quad\quad(1)$
其中,$S^{(l)} \in \mathbb{R}^{n_{l} \times n_{l+1}}$ ,$n_{l}$ 代表 第 $l$ 层的节点数量。
gPool 取得了与 DiffPool 相当的性能,gPool 的存储复杂度为 $\mathcal{O}(|V|+|E|)$,而 DiffPool 需要 $\mathcal{O}\left(k|V|^{2}\right)$,其中 $V$、$E$ 和 $k$ 分别表示顶点、边和池化率。gPool 使用一个可学习的向量 $p$ 来计算投影分数,然后使用这些分数来选择排名靠前的节点。投影得分由 $p$ 与所有节点的特征之间的点积得到。这些分数表示可以保留的节点的信息量。下面的公式大致描述了 gPool 中的池化过程:
$\begin{array}{l} y=X^{(l)} \mathbf{p}^{(l)} /\left\|\mathbf{p}^{(l)}\right\|\\ \mathrm{idx}=\operatorname{top}-\operatorname{rank}(y,\lceil k N\rceil)\\A^{(l+1)}=A_{\mathrm{idx}, \mathrm{idx}}^{(l)}\end{array} \quad\quad\quad\quad(2)$
3 Method
框架如下:
3.1 Self-Attention Graph Pooling
Self-attention mask
本文使用图卷积来获得自注意分数:
$Z=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} X \Theta_{a t t}\right) \quad\quad\quad\quad(3)$
其中:
-
- 自注意得分 $Z \in \mathbb{R}^{N \times 1}$;
- 邻接矩阵 $\tilde{A} \in \mathbb{R}^{N \times N}$;
- 注意力参数矩阵 $\Theta_{a t t} \in \mathbb{R}^{F \times 1}$;
- 特征矩阵 $X \in \mathbb{R}^{N \times F}$;
- 度矩阵 $\tilde{D} \in \mathbb{R}^{N \times N}$;
保留部分重要节点:
$\begin{array}{l} \mathrm{idx}=\operatorname{top}-\operatorname{rank}(Z,\lceil k N\rceil)\\Z_{\text {mask }}=Z_{\mathrm{idx}}\end{array} \quad\quad\quad\quad(4)$
基于自注意得分 $Z$ ,保留前 $ \lceil k N\rceil$ 个节点,其中 $k \in(0,1]$ 代表着池化率,$Z_{\text{mask}}$ 是 feature attention mask。
Graph pooling
获得新特征矩阵和邻接矩阵:
$\begin{array}{l} X^{\prime}=X_{\mathrm{idx},:}\\X_{\text {out }}=X^{\prime} \odot Z_{\text {mask }}\\A_{\text {out }}=A_{\mathrm{idx}, \mathrm{idx}}\end{array} \quad\quad\quad\quad(5)$
其中,$\odot$ 哈达玛积。
Variation of SAGPool
$Z=\sigma(\operatorname{GNN}(X, A)) \quad\quad\quad\quad(6)$
$Z=\sigma\left(\operatorname{GNN}\left(X, A+A^{2}\right)\right) \quad\quad\quad\quad(7)$
$Z=\sigma\left(\mathrm{GNN}_{2}\left(\sigma\left(\mathrm{GNN}_{1}(X, A)\right), A\right)\right) \quad\quad\quad\quad(8)$
$Z=\frac{1}{M} \sum_{m} \sigma\left(\mathrm{GNN}_{m}(X, A)\right) \quad\quad\quad\quad(9)$
3.2 Model Architecture
本节用来验证模块的有效性。
Convolution layer
图卷积 GCN:
$h^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} h^{(l)} \Theta\right) \quad\quad\quad\quad(10)$
与 $\text{Eq.3}$ 不同的是,$\Theta \in \mathbb{R}^{F \times F^{\prime}}$ 。
Readout layer
根据 JK-net architecture 的思想:
$s=\frac{1}{N} \sum_{i=1}^{N} x_{i} \| \max _{i=1}^{N} x_{i} \quad\quad\quad\quad(11)$
其中:
- $N$ 代表着节点的个数;
- $x_{i}$ 代表着第 $i$ 个节点的特征向量;
代码:
x = F.relu(self.conv1(x, edge_index))
x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
x1 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)
即:平均池化和 最大池化进行拼接。
Global pooling architecture & Hierarchical pooling architecture
对比如下:
Model Code:
class Net(torch.nn.Module):
def __init__(self, args):
super(Net, self).__init__()
self.args = args
self.num_features = args.num_features
self.nhid = args.nhid
self.num_classes = args.num_classes
self.pooling_ratio = args.pooling_ratio
self.dropout_ratio = args.dropout_ratio
self.conv1 = GCNConv(self.num_features, self.nhid)
self.pool1 = SAGPool(self.nhid, ratio=self.pooling_ratio)
self.conv2 = GCNConv(self.nhid, self.nhid)
self.pool2 = SAGPool(self.nhid, ratio=self.pooling_ratio)
self.conv3 = GCNConv(self.nhid, self.nhid)
self.pool3 = SAGPool(self.nhid, ratio=self.pooling_ratio)
self.lin1 = torch.nn.Linear(self.nhid * 2, self.nhid)
self.lin2 = torch.nn.Linear(self.nhid, self.nhid // 2)
self.lin3 = torch.nn.Linear(self.nhid // 2, self.num_classes)
def forward(self, data):
# 读取每个 batch 中的图数据
x, edge_index, batch = data.x, data.edge_index, data.batch
# 第一次做 Self-Attention Graph Pooling=======
x = F.relu(self.conv1(x, edge_index))
x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
# 第一次 Readout layer
x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
# 第二次做 Self-Attention Graph Pooling=======
x = F.relu(self.conv2(x, edge_index))
x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
# 第二次 Readout layer
x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
# 第三次做 Self-Attention Graph Pooling=======
x = F.relu(self.conv3(x, edge_index))
x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
# 第三次 Readout layer
x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
# 跳跃连接
x = x1 + x2 + x3
# MLP
x = F.relu(self.lin1(x))
x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x = F.relu(self.lin2(x))
x = F.log_softmax(self.lin3(x), dim=-1)
return x
SAGPool Code:
class SAGPool(torch.nn.Module):
def __init__(self,in_channels,ratio=0.8,Conv=GCNConv,non_linearity=torch.tanh):
super(SAGPool,self).__init__()
self.in_channels = in_channels
self.ratio = ratio
self.score_layer = Conv(in_channels,1)
self.non_linearity = non_linearity
def forward(self, x, edge_index, edge_attr=None, batch=None):
if batch is None:
batch = edge_index.new_zeros(x.size(0))
#x = x.unsqueeze(-1) if x.dim() == 1 else x
score = self.score_layer(x,edge_index).squeeze()
perm = topk(score, self.ratio, batch)
x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
batch = batch[perm]
edge_index, edge_attr = filter_adj(
edge_index, edge_attr, perm, num_nodes=score.size(0))
return x, edge_index, edge_attr, batch, perm
4 Experiments
数据集
基线实验
SAGPool 的变体
5 Conclusion
本文提出了一种基于自注意的SAGPool图池化方法。我们的方法具有以下特征:分层池、同时考虑节点特征和图拓扑、合理的复杂度和端到端表示学习。SAGPool使用一致数量的参数,而不管输入图的大小如何。我们工作的扩展可能包括使用可学习的池化比率来获得每个图的最优聚类大小,并研究每个池化层中多个注意掩模的影响,其中最终的表示可以通过聚合不同的层次表示来获得。
修改历史
2022-05-08 创建文章
2022-08-12 修订文章
因上求缘,果上努力~~~~ 作者:图神经网络,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/16230073.html