PyTorch Geometric Temporal
一、简介
PyTorch Geometric Temporal是PyTorch Geometric的一个时间图神经网络扩展库。它建立在开源深度学习和图形处理库之上。PyTorch Geometric Temporal由最先进的深度学习和参数学习方法组成,用于处理时空信号。它是第一个用于几何结构的时间深度学习的开源库,并在动态和静态图上提供常量时差图神经网络。我们使用离散时间图快照(discrete time graph snapshots)来实现这一点。
实现的方法涵盖了广泛的数据挖掘(WWW, KDD),人工智能和机器学习(AAAI, ICONIP, ICLR)会议,研讨会,以及来自著名期刊的文章。
二、数据结构
PyTorch geometry Temporal被设计为提供易于使用的数据迭代器,这些迭代器作用于包含时间快照(Temporal snapshots)的时空数据上;这些迭代器服务于通过块对角化batch技巧将单个图或者多个图组合在一起的快照(snapshots)。
2.1 Temporal Signal Iterators 时间信号迭代器
PyTorch geometry Temporal为包含时间快照的时空数据集提供了数据迭代器。有三种类型的数据迭代器:
1) StaticGraphTemporalSignal——定义于静态图的时间信号;
2)DynamicGraphTemporalSignal——定义于动态图的时间信号;
3)DynamicGraphStaticSignal——定义于动态图的静态信号;
Temporal Data Snapshots
一个temporal data snapshot是 PyTorch GeometricData的对象;更多的细节可以查看这里,返回的时间快照具有以下属性:
edge_index: 边的索引;用于节点特征的构建;(哪些节点相连)
edge_attr: 边的属性——为变得索引进行了加权;
x: 顶点的特征;
y: 顶点的目标(target);
Temporal Signal Iterators with Batches
带有batch的时间信号迭代器;
PyTorch Geometric Temporal 提供了数据迭代器用于批量的时空数据集,这些数据集中包含了批量的时间快照;
有三种类型的批量时间迭代器:
和前文的三种类型的批量时间快照是对应的
StaticGraphTemporalSignalBatch ——批量静态图的时间信号;
DynamicGraphTemporalSignalBatch——批量动态图的时间信号;temporal signal
DynamicGraphStaticSignalBatch ——批量动态图的静态信号;static signals
Temporal Batch Snapshots
时间批量快照
这是PyTorch Geometric 的一个Batch对象;更多细节可以查看这里。
返回的时间批量快照具有下面的属性:
edge_index : LongTensor
edge_attr : FloatTensor
x : FloatTensor
y : FloatTensor 或者LongTensor
batch : batch的索引 LongTensor
三、基准数据集
我们发布了一些数据集用于对比时间图神经网络算法的表现;
相关的机器学习任务是节点级和图级监督学习。
整合的数据集:
匈牙利水痘数据集可以通过以下代码片段加载。公共get_dataset方法返回的数据集是一个StaticGraphTemporalSignal对象。
1 from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 2 3 loader = ChickenpoxDatasetLoader() 4 5 dataset = loader.get_dataset()
Spatiotemporal Signal Splitting
时空数据信号的分割
我们提供了用于创建数据迭代器临时分段的函数。这些函数返回train和test数据迭代器,这些迭代器使用固定的train-test比率分割原始迭代器。来自早期时间段的快照构成训练数据集,来自后期时间段的快照构成测试数据集。这样,时间预测就可以在类似现实生活的场景中进行评估。函数split_temporal_signal接受StaticGraphTemporalSignal或DynamicGraphTemporalSignal对象,并根据train_ratio指定的分割比率返回两个迭代器。
1 from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 2 from torch_geometric_temporal.signal import split_temporal_signal 3 4 loader = ChickenpoxDatasetLoader() 5 6 dataset = loader.get_dataset() 7 8 train_dataset, test_dataset = split_temporal_signal(dataset, train_ratio=0.8)
Applications
在下面,我们将概述两个案例研究,其中PyTorch几何时间可用于解决现实世界相关的机器学习问题。一个是关于流行病学预测另一个是关于网络流量预测。
Epidemiological Forecasting
我们在这个案例研究中使用的是匈牙利水痘病例数据集。我们将训练一个回归器,使用递归图卷积网络预测每个县每周报告的病例。首先,我们将加载数据集并创建一个适当的时空分割。
1 from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 2 from torch_geometric_temporal.signal import temporal_signal_split 3 4 loader = ChickenpoxDatasetLoader() 5 6 dataset = loader.get_dataset() 7 8 train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
在接下来的步骤中,我们将定义用于解决监督任务的递归图神经网络体系结构。构造函数定义了一个DCRNN层和一个前馈层。
值得注意的是,最终的非线性并没有集成到递归图卷积运算中。这个设计原则是一贯使用的,它取自 PyTorch Geometric。因此,我们手动定义了循环层和线性层之间的ReLU非线性。当我们解决零均值目标的回归问题时,最后的线性层没有跟随非线性。
1 import torch 2 import torch.nn.functional as F 3 from torch_geometric_temporal.nn.recurrent import DCRNN 4 5 class RecurrentGCN(torch.nn.Module): 6 def __init__(self, node_features): 7 super(RecurrentGCN, self).__init__() 8 self.recurrent = DCRNN(node_features, 32, 1) 9 self.linear = torch.nn.Linear(32, 1) 10 11 def forward(self, x, edge_index, edge_weight): 12 h = self.recurrent(x, edge_index, edge_weight) 13 h = F.relu(h) 14 h = self.linear(h) 15 return h
让我们定义一个模型(我们有4个节点特性),并在200个epoch的训练分割(前20%的时间快照)上训练它。当每个时间快照的损失累积时,我们反向传播。我们将使用学习率为0.01的Adam优化器。tqdm函数用于测量每个训练阶段的运行时需求。
1 from tqdm import tqdm 2 3 model = RecurrentGCN(node_features = 4) 4 5 optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 6 7 model.train() 8 9 for epoch in tqdm(range(200)): 10 cost = 0 11 for time, snapshot in enumerate(train_dataset): 12 y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 13 cost = cost + torch.mean((y_hat-snapshot.y)**2) 14 cost = cost / (time+1) 15 cost.backward() 16 optimizer.step() 17 optimizer.zero_grad()
使用测试的数据集,我们将评估训练过的递归图卷积网络的性能,并计算所有空间单位和时间周期的均方误差。
1 model.eval() 2 cost = 0 3 for time, snapshot in enumerate(test_dataset): 4 y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 5 cost = cost + torch.mean((y_hat-snapshot.y)**2) 6 cost = cost / (time+1) 7 cost = cost.item() 8 print("MSE: {:.4f}".format(cost))
Web Traffic Prediction
在这个案例研究中,我们使用的是Wikipedia Maths数据集。
我们将训练一个递归图神经网络,使用递归图卷积网络来预测维基百科页面的每日访问量。
首先,我们将加载数据集并使用14个滞后流量变量。接下来,我们创建一个适当的时空分割,使用50%天的数据来训练模型。
1 from torch_geometric_temporal.dataset import WikiMathsDatasetLoader 2 from torch_geometric_temporal.signal import temporal_signal_split 3 4 loader = WikiMathsDatasetLoader() 5 6 dataset = loader.get_dataset(lags=14) 7 8 train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)
在接下来的步骤中,我们将定义用于解决监督任务的递归图神经网络体系结构。构造函数定义了一个GConvGRU层和一个前馈层。需要再次注意的是,非线性并没有集成到递归图卷积运算中。卷积模型有固定数量的滤波器(可以参数化),并且考虑二阶邻域。
1 import torch 2 import torch.nn.functional as F 3 from torch_geometric_temporal.nn.recurrent import GConvGRU 4 5 class RecurrentGCN(torch.nn.Module): 6 def __init__(self, node_features, filters): 7 super(RecurrentGCN, self).__init__() 8 self.recurrent = GConvGRU(node_features, filters, 2) 9 self.linear = torch.nn.Linear(filters, 1) 10 11 def forward(self, x, edge_index, edge_weight): 12 h = self.recurrent(x, edge_index, edge_weight) 13 h = F.relu(h) 14 h = self.linear(h) 15 return h
让我们定义一个模型(我们有14个节点特征),并在50个epoch的训练分割(前50%的时间快照)上训练它。我们反向传播每个时间快照的损失。我们将使用学习率为0.01的Adam优化器。tqdm函数用于测量每个训练阶段的运行时需求。
1 from tqdm import tqdm 2 3 model = RecurrentGCN(node_features=14, filters=32) 4 5 optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 6 7 model.train() 8 9 for epoch in tqdm(range(50)): 10 for time, snapshot in enumerate(train_dataset): 11 y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 12 cost = torch.mean((y_hat-snapshot.y)**2) 13 cost.backward() 14 optimizer.step() 15 optimizer.zero_grad()
使用测试流量数据,我们将评估训练过的递归图卷积网络的性能,并计算所有网页和天数的均方误差。
1 model.eval() 2 cost = 0 3 for time, snapshot in enumerate(test_dataset): 4 y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 5 cost = cost + torch.mean((y_hat-snapshot.y)**2) 6 cost = cost / (time+1) 7 cost = cost.item() 8 print("MSE: {:.4f}".format(cost)) 9 >>> MSE: 0.7760