笔记:GAT入门学习
GAT图注意力网络
GAT 采用了 Attention 机制,可以为不同节点分配不同权重,训练时依赖于成对的相邻节点,而不依赖具体的网络结构,可以用于 inductive 任务。
假设 Graph 包含 $N$ 个节点,每个节点的特征向量为 $h_i$,维度是 $F$,如下所示:
\begin{gathered}
\boldsymbol{h}=\left\{h_{1}, h_{2}, \ldots, h_{N}\right\} \\
h_{1} \in R^{F}
\end{gathered}
节点 $j$ 是节点 $i$ 的邻居,则可以使用 Attention 机制计算节点 $j$ 对于节点 $i$ 的重要性,即 Attention Score:
\begin{gathered}
e_{i j}=\operatorname{Attention}\left(W h_{i}, W h_{j}\right) \\
\alpha_{i j}=\operatorname{Softmax}_{j}\left(e_{i j}\right)=\frac{\exp \left(e_{i j}\right)}{\sum_{k \in N_{i}} \exp \left(e_{i k}\right)}
\end{gathered}
注意这个 $w$ 都是同一个
GAT 具体的 Attention 做法如下,把节点 $i、j$ 的特征向量 $h'_i$、$h'_j$ 拼接在一起,然后和一个 $2F'$ 维的向量 $a$ 计算内积。激活函数采用 LeakyReLU,公式如下:
$$
\alpha_{i j}=\frac{\exp \left(\operatorname{LeakyReLU}\left(a^{T}\left[W h_{i} \| W h_{j}\right]\right)\right)}{\sum_{k \in N_{i}} \exp \left(\operatorname{LeakyReLU}\left(a^{T}\left[W h_{i} \| W h_{k}\right]\right)\right)}
$$
|| 表示拼接操作
经过 Attention 之后节点 $i$ 的特征向量如下:
$$h_{i}^{\prime}=\sigma\left(\sum_{j \in N_{i}} \alpha_{i j} W h_{j}\right)$$
GAT 也可以采用 Multi-Head Attention,如果有 K 个 Attention,则需要把 K 个 Attention 生成的向量拼接在一起,如下:
$$h_{i}^{\prime}=\operatorname{concat}\left(\sigma\left(\sum_{j \in N_{i}} \alpha_{i j}^{k} W^{k} h_{j}\right)\right)$$
但是如果是最后一层,则 K 个 Attention 的输出不进行拼接,而是求平均:
$$h_{i}^{\prime}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in N_{i}} \alpha_{i j}^{k} W^{k} h_{j}\right)$$
网络结构:
样例来自 https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gat.py
class GAT(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(GAT, self).__init__()
# num_features: Alias for num_node_features.
self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)
# On the Pubmed dataset, use heads=8 in conv2.
self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=False,
dropout=0.6)
def forward(self, x, edge_index):
ipdb.set_trace()
x_copy = x.clone()
x = F.dropout(x, p=0.6, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return x + x_copy # Residual connection, 避免孤立节点变成全0
# return F.log_softmax(x, dim=-1) # log_softmax ??
return x # 我觉得这个位置还不要softmax
参考链接:https://ai.baidu.com/forum/topic/show/972764