Abstract
我们提出了图注意网络(GATs),一种基于图结构数据的新型神经网络结构,利用masked self-attentional layers来解决先前基于图卷积或其近似的方法的缺点。通过堆叠节点能够关注其邻域特征的层,我们允许(隐式地)为邻域中的不同节点指定不同的权重,而不需要任何昂贵的矩阵操作(如转置),也不依赖于预先了解图的结构。通过这种方式,我们同时解决了基于谱方法的图神经网络的几个关键挑战,并使我们的模型易于适用于归纳和转换问题。我们的GAT模型已经在四个已建立的转换和归纳图基准中取得或匹配了最先进的结果:Cora、Citeseer和Pubmed引文网络数据集,以及蛋白质相互作用数据集(其中测试图在训练过程中仍然不可见)。
Introduction
在之前的模型中,已经有很多基于神经网络的工作被用于处理图结构的数据。例如,最早的GNN网络可以被用于处理有环图、有向图或无向图。然而,GNN网络本身必须使整个网络达到不动点之后才可以进行计算。针对这一问题,通过将GRU引入到网络结构中,进一步提出了GGNN网络。后来,人们开始关注将卷积操作引入到图领域中,这一类算法可以被分为谱方法(spectral approaches)与非谱方法(non-spectral approaches)两大类。
谱方法是基于对图进行谱表示的一类方法。其上的卷积操作与图拉普拉斯矩阵的特征值分解有关,因此,往往需要进行密集的矩阵运算,而且整个计算并不是局部的。为了解决这一问题,GCN可以有效地对节点的一阶邻居进行处理,而且可以避免复杂的矩阵运算。然而,这些模型都依赖于图的结构,因此,在特定图结构上训练得到的模型往往不可以直接被使用到其他图结构上。
不同于谱方法,非谱方法是直接在图上(而不是在图的谱上)定义卷积。这类方法的一个挑战在于,如何定义一个可以处理可变大小邻居且共享参数的操作。针对这一问题,MoNet(mixture model CNN),可以有效地将CNN结构引入到图上。类似地,GraphSAGE模型使用一种归纳的方法来计算节点表示。具体来说,该模型首先从每个节点的邻节点中抽取出固定数量的节点,然后再使用特定的方式来融合这些邻节点的信息(如直接对这些节点的特征向量求平均,或者将其输入到一个RNN中),这一方法已经在很多大型归纳学习问题中取得了很好的效果。
在本文中,作者提出了一种基于attention的节点分类网络——GAT。其基本思想是,根据每个节点在其邻节点上的attention,来对节点表示进行更新。GAT具有以下几个特点:(1)计算速度快,可以在不同的节点上进行并行计算;(2)可以同时对拥有不同度的节点进行处理;(3)可以被直接用于解决归纳学习问题,即可以对从未见过的图结构进行处理。
Graph Attention Layer
GAT模型是堆叠多个GAT Layer实现的,了解了GAT Layer就相当于弄明白了GAT模型

输入和输出
Graph Attention Layer的输入是节点特征集合(列向量):
其中代表节点个数,。
输出是一组新的节点特征集合:
其中。
节点对其邻居注意力值的计算
注意力机制定义如下:
表示注意力机制计算出来的节点对于节点的重要程度。实际计算中只需要计算节点周围有连边的邻居的注意力分数,因此首先计算节点对所有结点的注意力得分然后再将那些非邻居的节点的注意力得分清零,文中称为masked self-attentional。
注意力分数是上面计算出的的归一化结果:
论文中的注意力机制是通过一个单层前馈神经网络实现的,激活函数采用,因此注意力分数的计算可重写为:
其中表示拼接操作,是一个单层神经网络。
节点信息的聚合
一旦节点相对于其邻居的注意力分数计算成功后,就可采用如下公式更新节点特征:
式中的和注意力得分公式中的是参数共享的。
多头注意力
前面讨论的是单注意力的情况,文中还介绍了使用多头注意力机制从不同的角度聚合信息,多头注意力机制可以类比卷积神经网络里面的多卷积核。用表示第个注意力计算出的注意力分数,将所有的特征拼接起来当作最终的特征:
因此如果采用多头注意力机制,那么得到的新特征的维度是原来特征维度的倍。对于非输出层不需要拼接,只需要将得到的个特征向量取平均即可。
与其他模型的比较
- GAT效率很高。相比于其他图模型,GAT无需使用特征值分解等复杂的矩阵运算。单层GAT的时间复杂度为
- 相比于GCN,每个节点的重要性可以是不同的,因此,GAT具有更强的表示能力。
- 对于图中的所有边,attention机制是共享的。因此GAT也是一种局部模型。也就是说,在使用GAT时,我们无需访问整个图,而只需要访问所关注节点的邻节点即可。这一特点的作用主要有:(1)可以处理有向图(若不存在,仅需忽略即可);(2)可以被直接用于进行归纳学习。
- 最新的归纳学习方法(GraphSAGE)通过从每个节点的邻居中抽取固定数量的节点,从而保证其计算的一致性。这意味着,在执行推断时,我们无法访问所有的邻居。然而,本文所提出的模型是建立在所有邻节点上的,而且无需假设任何节点顺序。
代码实现
代码实现参考了https://github.com/Diego999/pyGAT,关键地方进行了注释
文件一
GATLayer.py 定义了单层图注意力层
import torch
import torch.nn as nn
import torch.nn.functional as f
class GraphAttentionLayer(nn.Module):
def __init__(self, in_feature, out_feature, dropout, concat=True):
super(GraphAttentionLayer, self).__init__()
self.in_feature = in_feature
self.out_feature = out_feature
self.dropout = dropout
self.concat = concat
self.W = nn.Parameter(torch.empty([self.in_feature, self.out_feature]))
self.a = nn.Parameter(torch.empty([2 * self.out_feature, 1]))
nn.init.xavier_uniform_(self.W.data) # 或 self.W
nn.init.xavier_uniform_(self.a.data)
self.activate = nn.LeakyReLU(0.2)
def forward(self, h, adj):
"""
:param h: 初始节点特征 shape (N, in_feature)
:param adj: 邻接矩阵 shape (N, N)
:return:
"""
wh = torch.mm(h, self.W)
e = self._get_attention(wh)
e_zero = -9e15 * torch.ones_like(e) # 很小的数,目的是使 exp(x) -> 0,注意 exp(0) = 1
attention = torch.where(adj > 0, e, e_zero)
# (N, N) 两两节点之间的注意力分数
attention = torch.softmax(attention, dim=1)
# attention = f.dropout(attention, self.dropout) # 不 dropout 效果反而好一点,疑惑中......
h_prime = torch.mm(attention, wh)
if self.concat:
return f.elu(h_prime)
else:
return h_prime
def _get_attention(self, wh):
wh1 = torch.mm(wh, self.a[:self.out_feature, :])
wh2 = torch.mm(wh, self.a[self.out_feature:, :])
e = wh1 + wh2.T # (N, N) 两两节点之间的注意力
return self.activate(e)
def __repr__(self):
# self.__class__.__name__ 表示类名,即 GraphAttentionLayer
return self.__class__.__name__ + '(' + str(self.in_feature) + ' -> ' + str(self.out_feature) + ')'
文件二
import torch
import torch.nn as nn
import torch.nn.functional as f
from GATLayer import GraphAttentionLayer
class GAT(nn.Module):
def __init__(self, n_feat, n_hid, n_class, dropout, n_heads):
super(GAT, self).__init__()
self.dropout = dropout
# 计算两层
self.attentions = nn.ModuleList(
[GraphAttentionLayer(n_feat, n_hid, self.dropout, concat=True) for _ in range(n_heads)]
) # 第一层采用多头注意力机制
# 第二层采用单头注意力机制
self.out_att = GraphAttentionLayer(n_hid * n_heads, n_class, self.dropout, concat=False)
def forward(self, x, adj):
x = f.dropout(x, self.dropout)
x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
x = f.dropout(x, self.dropout)
x = f.elu(self.out_att(x, adj))
return f.log_softmax(x, dim=1)
if __name__ == "__main__":
model = GAT(100, 10, 7, 0.6, 3)
print(model)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)