GraphSAGE:如何将图神经网络扩展到数十亿个连接

GraphSAGE:如何将图神经网络扩展到数十亿个连接

UberEats 有什么共同点? **** 和Pinterest?他们的推荐系统由 GraphSAGE 提供支持,具有数百万和数十亿个节点和边缘。

  • Pinterest 开发了自己的版本,称为 PinSAGE,向用户推荐最相关的图像(Pins)。该资源的图表包含 180 亿个连接和 30 亿个节点。
  • 优食 还使用 GraphSAGE 的修改版本来推荐餐点、餐馆和美食。该平台声称支持超过 600,000 家餐厅和 6600 万用户。

在本教程中,由于 Google Colab 的限制,我们使用的数据集包含 2 万个节点,而不是数十亿个节点。在学习的过程中,我们会坚持原有的架构 GraphSAGE ,以及从以前的选项中触及一些有趣的功能。

代码可以用这个记事本运行 谷歌公司 .

1. PubMed 数据集

t-SNE график PubMed. Изображение автора

考研 是数据集的一部分 小行星 **** (麻省理工学院许可证)。这是您需要了解的内容。

  • 它包含来自 PubMed 数据库的 19,717 篇糖尿病研究论文。
  • 节点的特征是 500 维的 TF-IDF 加权词向量,这是一种在没有转换器的情况下总结文档的相当方便的方法。
  • 该任务被简化为三类:实验性糖尿病、1 型糖尿病和 2 型糖尿病。

目标是达到 70% 的准确率。

 从 torch_geometric.datasets 导入 Planetoid  
  
 数据集 = Planetoid(root='.', name="Pubmed")  
 数据 = 数据集[0]  
  
 # 输出数据集的信息  
 print(f'数据集:{数据集}')  
 打印(' -  -  -  -  -  -  -  -  - -'print(f'图数:{len(dataset)}')  
 print(f'节点数:{data.x.shape[0]}')  
 print(f'特征数量:{dataset.num_features}')  
 print(f'类数:{dataset.num_classes}')  
  
 # 输出关于图的信息  
 打印(f'\n图表:')  
 打印(' -  -  - 'print(f'训练节点:{sum(data.train_mask).item()}')  
 print(f'评估节点:{sum(data.val_mask).item()}')  
 print(f'测试节点:{sum(data.test_mask).item()}')  
 print(f'边缘是定向的:{data.is_directed()}')  
 print(f'Graph 有孤立的节点:{data.has_isolated_nodes()}')  
 print(f'Graph 有循环:{data.has_self_loops()}')


 数据集:Pubmed()  
 ------------------  
 图数:1  
 节点数:19717  
 功能数量:500  
 班数:3 图形:  
 ------  
 训练节点: **60**  
 评估节点: **500**  
 测试节点: **1000**  
 边缘是有向的:假  
 图有孤立的节点:假  
 图表有循环:假

与完整图相比,PubMed 的训练节点数量非常少——它只需要 60 个样本就可以学习如何对 1000 个测试节点进行分类。

尽管存在挑战,GGNs 仍然能够实现高水平的准确性。这是已知方法中的排行榜,根据 带代码的论文

使用此特定设置(60 个训练节点和 1000 个测试节点),我无法在 PubMed 上找到 GraphSAGE 的结果,所以不要指望高精度。但是在处理大图时,另一个指标可能会变得同样重要:训练时间。

2.理论上的GraphSAGE

Изображение автора

Изображение автора

GraphSAGE 算法可以分为两个步骤。

  1. 邻居采样。
  2. 聚合。

邻居抽样

Mini-bunching 是机器学习中常用的技术。它将数据集拆分为更小的批次,从而使您可以更有效地训练模型。以下是这种技术的一些好处。

  1. 提高了准确性。 小批量有助于减少破坏(梯度被平均)以及错误率的变化。
  2. 速度提高。 小批量是并行处理的,比大批量需要更少的训练时间。
  3. 改进的缩放。 整个数据集可以超过 GPU 内存量,但小批量可以绕过这个限制。

mini-bunching 技术非常方便,已成为传统神经网络中工作的标准。然而,对于图形数据,事情就不那么简单了,因为将数据集分成小块会导致节点之间的重要链接断开。

该怎么办?近年来,已经开发了几种策略来创建图的迷你包,包括 邻居抽样 .还有其他技术可以在 PyG 文档 ,例如子图聚类。

Выборка по соседям. Изображение автора

邻居抽样技术只考虑固定数量的随机邻居。这是这个过程的样子。

  1. 我们确定邻居的数量(1 个转换),这些邻居的邻居数量(2 个转换),等等。
  2. 选择器查看节点的邻居列表,这些邻居的邻居等等,然后随机选择预定数量的他们。
  3. 选择器输出包含目标节点和随机选择的相邻节点的子图。

对列表中的每个节点或整个图形重复此过程。但是为每个节点创建一个子图效率很低,相反我们可以批量处理它们。在这种情况下,每个子图都被多个目标节点使用。

邻居抽样还有另一个优点。一些节点非常受欢迎并充当枢纽,例如社交媒体名人。从计算的角度来看,获取这些节点的隐藏向量可能非常昂贵,因为它需要计算数千甚至数百万个邻居的隐藏向量。 GraphSAGE 通过忽略大多数节点来纠正这种情况。

在 PyG 中,邻居采样是通过一个对象来实现的[ 邻居加载器](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.NeighborLoader) .假设我们需要 5 个邻居和这 5 个邻居中的 10 个 ( num_neighbors )。如前所述,我们可以定义 批量大小 通过为多个目标节点创建子图来加速该过程。

 从 torch_geometric.loader 导入 NeighborLoader  
 从 torch_geometric.utils 导入 to_networkx  
  
 # 使用邻居采样创建包  
 train_loader = NeighborLoader(  
 数据,  
 num_neighbors=[5, 10],  
 批量大小=16,  
 input_nodes=data.train_mask,  
 )  
  
 # 输出每个子图  
 对于 i,枚举中的子图(train_loader):  
 print(f'子图 {i}: {子图}')  
  
 # 构建每个子图  
 fig = plt.figure(figsize=(16,16))  
 对于 idx, (subdata, pos) in enumerate(zip(train_loader, ['221', '222', '223', '224'])):  
 G = to_networkx(子数据,to_undirected=True)  
 ax = fig.add_subplot(pos)  
 ax.set_title(f'子图 {idx}')  
 plt.axis('关闭')  
 nx.draw_networkx(G,  
 pos=nx.spring_layout(G, 种子=0),  
 with_labels=真,  
 节点大小=200,  
 node_color=subdata.y,  
 cmap="酷",  
 字体大小=10  
 )  
 plt.show()


**子图 0** : 数据(x=[389, 500], edge_index=[2, 448], batch_size=16)  
 **子图1** : 数据(x=[264, 500], edge_index=[2, 314], batch_size=16)  
 **子图2** : 数据(x=[283, 500], edge_index=[2, 330], batch_size=16)  
 **子图 3** : 数据(x=[189, 500], edge_index=[2, 229], batch_size=12)

我们创建了 4 个不同大小的子图,保证了它们的并行处理并符合 GPU 的计算资源。

邻居的数量是一个重要的指标,因为截断一个图会删除很多信息,我们可以通过查看节点度(邻居的数量)来看到:

 从 torch_geometric.utils 导入学位  
 从集合导入计数器  
  
 def plot_degree(数据):  
 # 获取每个节点的度数列表  
 度数=度数(data.edge_index[0]).numpy()  
  
 # 计算每个度数的节点数  
 数字 = 计数器(度)  
  
 # 构建条形图  
 无花果,斧头 = plt.subplots(figsize=(18, 6))  
 ax.set_xlabel('节点度')  
 ax.set_ylabel('节点数')  
 plt.bar(numbers.keys(),  
 数字.值(),  
 颜色='#0A047A')  
  
 # 绘制原图节点的度数图  
 plot_degree(数据)  
  
 # 绘制最终子图节点的度数图  
 plot_degree(子数据)

Степень узлов в начальном графе

Степень узлов после выборки по соседям

在此示例中,子图节点的最大度数为 5,远低于初始最大值。在使用 GraphSAGE 时,记住这种权衡非常重要。

PinSAGE 使用不同的采样方法,即随机游走方法,它有两个主要功能。

  1. 选择一定数量的邻居(类似于 GraphSAGE)。
  2. 获取它们的相对重要性(重要节点比其他节点更频繁地出现)。

这种策略有点像快速注意机制。它为节点分配权重并增加最受欢迎的节点的相关性。

聚合

聚合过程确定如何组合特征向量以获得节点嵌入。原始文档提供了三种聚合特征的方法:

  • 平均值聚合器;
  • LSTM-聚合器;
  • 子样本聚合器。

Агрегация (изображение автора)

均值聚合器是所有聚合器中最简单的。操作方法与 GCN 方法类似。

  1. 目标节点的隐藏特征 H ᵥ 和他的邻居 H ᵤ 合并。
  2. 最终向量被平均。
  3. 应用了权重矩阵 W 的线性变换。

然后可以将结果输入非线性激活函数 σ(例如 tanh 和 ReLU)。这是我们将在 PyG 中使用的技术,也是 UberEats 选择的技术。

选择 LSTM 聚合器似乎是一个奇怪的想法,因为它的架构是顺序的——它为杂乱无章的节点设置了顺序。因此,作者随机打乱它们以强制 LSTM 只考虑隐藏的特征。这种技术在比较测试中显示出最好的结果。

子样本聚合器将每个邻居的潜在向量馈送到前馈神经网络中。最大值二次采样操作应用于结果。

3. GraphSAGE 与 PyTorch 几何

我们可以使用层轻松地将 GraphSAGE 架构嵌入到 PyTorch Geometric 中 SAGEConv .此实现与文档中的不太相同,因为它使用 2 个矩阵而不是一个:

创建一个有两层的网络 SAGEConv .

  • 第一个将使用 ReLU 作为激活函数和过滤层。
  • 第二个将直接输出节点的附件。

由于我们正在处理分类为多个类别的问题,我们将使用交叉熵作为损失函数。

为了展示 GraphSAGE 的好处,让我们将其与没有采样的 GCN 和 GAT 进行比较:

 类 GraphSAGE(torch.nn.Module):  
 """GraphSAGE"""  
 def __init__(self, dim_in, dim_h, dim_out):  
 超级().__init__()  
 self.sage1 = SAGEConv(dim_in, dim_h)  
 self.sage2 = SAGEConv(dim_h, dim_out)  
 self.optimizer = torch.optim.Adam(self.parameters(),  
 lr=0.01,  
 weight_decay=5e-4)  
  
 def forward(self, x, edge_index):  
 h = self.sage1(x, edge_index)  
 h = 火炬.relu(h)  
 h = F.dropout(h, p=0.5, training=self.training)  
 h = self.sage2(h, edge_index)  
 返回 h, F.log_softmax(h, dim=1)  
  
 def fit(self, data, epochs):  
 标准 = torch.nn.CrossEntropyLoss()  
 优化器 = self.optimizer  
  
 自我训练()  
 对于范围内的纪元(纪元+1):  
 累积 = 0  
 val_loss = 0  
 val_acc = 0  
  
 # 包训练  
 train_loader 中的批处理:  
 优化器.zero_grad()  
 _, out = self(batch.x, batch.edge_index)  
 损失=标准(out[batch.train_mask],batch.y[batch.train_mask])  
 acc += 准确度(out[batch.train_mask].argmax(dim=1),  
 batch.y[batch.train_mask])  
 loss.backward()  
 优化器.step()  
  
 # 合规确认  
 val_loss += 标准(out[batch.val_mask],batch.y[batch.val_mask])  
 val_acc += 准确度(out[batch.val_mask].argmax(dim=1),  
 batch.y[batch.val_mask])  
  
 # 每 10 个 epoch 输出指标  
 如果(纪元 % 10 == 0):  
 print(f'Epoch {epoch:>3} | 训练损失:{loss/len(train_loader):.3f} '  
 f'|训练加速:{acc/len(train_loader)*100:>6.2f}% |价值损失:'  
 f'{val_loss/len(train_loader):.2f} |价值累计:'  
 f'{val_acc/len(train_loader)*100:.2f}%')

在 GraphSAGE 中,我们查看由邻居采样过程生成的数据包(4 个子图)。正因为如此,计算准确性和验证损失的方式也不同。

以下是 GCN、GAT 和 GraphSAGE 的结果(在准确度和训练时间方面):

**全球网络** 测试精度: **78.40%(52.6 秒)  
 盖特** 测试精度: **77.10%(18分7秒)  
 GraphSAGE** 测试精度: **77.20%(12.4 秒)**

在准确性方面,这三个模型具有相似的结果。我们希望 GAT 性能更好,因为它的聚合机制更详细,但情况并非总是如此。

真正的区别在于训练时间:在这种情况下,GraphSAGE 比 GAT 快 88 倍,比 GCN 快 4 倍。

这就是 GraphSage 的真正威力。通过使用邻居采样截断图,我们丢失了很多信息。最新的节点附件可能不如使用 GCN 和 GAT 时那么好。但是,GraphSage 旨在提高可伸缩性。反过来,它可以导致构建大尺寸图以获得更好的准确性。

Изображение автора

这项工作是使用监督学习(节点分类)完成的,但 GraphSAGE 也可以在无监督的情况下进行训练。

在这种情况下,不能使用交叉熵损失。我们需要开发一个损失函数,使原始图中相邻的节点在嵌套空间中彼此靠近。相反,同一个函数必须保证图的远节点在嵌套空间中的距离相同。这些损失在用于使用的文档中进行了描述 GraphSage .

GraphSage 的修改,例如 PinSAGE 和 UberEats 使用的修改,都是针对推荐系统的。

他们的任务是对每个用户(pin、餐馆)最相关的元素进行排名,尽管它们之间存在显着差异。不仅要找到最近的投资,而且要尽可能准确地分配它们的重要性程度。这就是为什么这些系统也进行无监督训练,但使用不同的损失函数来测量输入数据点之间的相对距离。

结论

GraphSage 是一种用于处理大型图的非常快速的架构。它可能不如 GCN 和 GAT 准确,但在处理大量数据时它的使用很重要。 GraphSage 的高速是通过对用于图细化和快速聚合的邻居采样的深思熟虑的组合来实现的。在此示例中,使用了均值聚合器。

在本文中,我们做了以下工作。

  • 使用 PubMed 检查了一个新数据集。
  • 我们分析了邻居采样方法的操作原理,该方法在每次转换中考虑了预定数量的邻居。
  • 我们回顾了 GraphSage 文档中提供的三个聚合器,并专注于平均值聚合器。
  • 我们测试了三个模型(GraphSAGE、GAT 和 GCN)的准确性和训练时间。

另请阅读:

阅读我们 电报 , VK Yandex.Zen

文章翻译 马克西姆·拉邦 : **** GraphSAGE:将图神经网络扩展到数十亿个连接

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明

本文链接:https://www.qanswer.top/37138/07401710

posted @   哈哈哈来了啊啊啊  阅读(306)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示