学习笔记:SSTBAN 用于长期交通预测的自监督时空瓶颈注意力

Self-Supervised Spatial-Temporal Bottleneck Attentive Network for Efficient Long-term Traffic Forecasting
用于高效长期交通预测的自监督时空瓶颈注意力网络
期刊会议:ICDE2023
论文地址:https://ieeexplore.ieee.org/document/10184658
代码地址:https://github.com/guoshnBJTU/SSTBAN

长期交通预测存在的问题:

  • 难以平衡准确率和效率。随着时间跨度增大,要么无法捕捉长期动态性,要么以二次计算复杂度为代价获取全局接受域。
  • 高质量的训练数据需求与模型的泛化能力的矛盾。如何提升数据的利用效率值得思考。
    SSTBAN采用多任务框架,结合自监督学习器对历史交通数据产生鲁棒的潜在表示,从而提高其泛化性能和预测的鲁棒性。此外,作者还设计了一个时空瓶颈注意机制,在编码全局时空动态的同时降低了计算复杂度。

长期预测需求分析:
与有助于及时决策的短期预测相比,长期预测为旅行者和管理员提供了必要的支持信息,以优化旅行计划和运输资源管理。特别是未来几个小时的流量预测信息,有助于用户提前制定路由计划。

选择注意力机制的原因:
目前STGNNs分为RNN-based,CNN-based和attention-based方法。RNN-based存在梯度消失问题,不利于长期预测,且序列顺序的预测方式使得模型训练时间随着预测时间线性增加;CNN-based的kernel大小限制了长期动态性的捕捉能力;Attention-based更灵活,不会受到空间和时间距离的影响,但是存在二次计算复杂度的问题,这是要解决的问题。

数据的利用效率:高质量数据需求与泛化性矛盾
现有方法普遍有较强的高质量数据需求,当训练数据存在噪声时,就会导致过拟合或是学习到虚假的关系,泛化能力不佳。于是引入了常用于NLP和CV中的自监督学习。这要正确认知NLP/CV和时空交通预测的区别:

  • NLP/CV中的基础模式,如形状和语义,在广泛数据集中是通用的;而交通数据集中则鲜有这样的共同特征,比如可能数据的特征都不一样。
  • NLP只需捕捉序列特征——即时间,CV只需要捕捉空间特征,但是STGNN要同时捕捉时空特征。

贡献:

  • 第一次提出了一种采用自监督学习器的时空交通预测模型,满足了泛化和鲁棒需求。
  • 设计了一种时空瓶颈注意力机制,能够高效捕捉长期时空动态,将时间复杂度由二次方降低至线性。(RNN:?)
  • 在九个数据集上进行了实验,证明了在精度和效率上的优势。

模型

图: SSTBAN架构

模型包含两个分支:第一个是时空预测分支,第二个是时空自监督学习分支,因此是个多任务框架。

在训练阶段,两个分支一起工作。在分支一中,原始的数据依次经过ST Encoder、Transformer Attention,最后由ST Forecasting Decoder预测;在分支二中,首先随机mask掉一些数据,将破损的数据经过ST Encoder来用剩余的数据提取特征,经过ST Reconstruction Decoder来补全丢掉的数据,并将补全的数据和分支一的完整数据进行对齐比较(为了避免噪声的影响,这里的比较是放在了潜在空间中的)。训练损失也包含了两个,一个是预测误差的MAE,另一个是对齐的MSE。

在两个分支中,encoder和两个decoder由一样的时空瓶颈注意力模块(STBA)和时空嵌入模块(STE)构成。

STBA目的是捕捉长期的时空动态性,且维持低的计算复杂度。
STE目的是提取不同时间切片和节点的独特性,来弥补基于注意力机制的STBA对顺序的不敏感缺点。我们通过端到端的方式训练空间嵌入\(E_{SP}\in R^{N\times d}\),它在所有时间中共享;通过time-of-day和day-of-week,用one-hot和MLPs得到输入时间嵌入\(E_{TP}\in R^{P\times d}\)和输出时间嵌入\(E_{TP}'\in R^{Q\times d}\),它在所有节点中共享;将它们相加得到输入序列嵌入\(\mathcal{E}\in R^{P\times N\times d}\)和输入序列嵌入\(\mathcal{E}'\in R^{Q\times N\times d}\)

时空瓶颈注意力 STBA

图: 时空瓶颈注意力STBA,时间瓶颈注意力TBA,空间SBA

图中\(\mathcal{Z}^{(l-1)}=(\mathcal{H}^{(l-1)}||\mathcal{E})\in\mathbb{R}^{P\times N\times2d}\)

STBA包含了空间注意力(SBA)和时间注意力(TBA)。它们并没有直接和其他点相连,而是和参考点相连,而参考点的数量远小于时间点和空间点。我们还希望参考点能够编码通用的全局信息。由于整体形状像瓶颈而得名。

TBA平行地处理每个点的输入。(这里的过程还不是很懂。)

STBA具有以下特点:

  • 由于参考点的设置,运算复杂度从\(O(N^2)\)降低到了\(O(NN')\),因为\(N'\)是个小的超参数。相比于GCN,STBA不需要预定义的图结构,同时能动态调整节点间关系强度。
  • 参考点起到了编码全局模式的作用,可以理解输入,如用来聚类。

时空预测分支

图:分支一 时空预测分支结构

组成部分:
(1)时空编码器:由时空瓶颈注意力组成,映射到潜在表征空间
(2)Transformer attention:将潜在空间下的历史信号适配于预测信号尺寸。为了缓解长期预测存在的比较严重的误差传播问题,我们通过自适应地融合历史中的不同特征,用注意力机制直接把每步的历史信号和预测信号连接起来。即

\[\mathcal{H^{\prime}}_{:,v}^{(0)}=\mathrm{MHSA}(\mathcal{E}_{:,v}^{\prime},\mathcal{E}_{:,v},\mathcal{H}_{:,v}^{(L)})\in\mathbb{R}^{Q\times d} \]

(3)时空预测解码器:由若干层时空瓶颈注意力,最后加上全连接层组成。

时空自监督学习分支

这一分支从mask掉部分信号的不完整数据中,理解时空关系,并在潜在空间中重构缺失的信号。目的是训练潜在空间表征能力。包括如下部分:
(1)Masking:考虑到mask掉单独一个时间点的数据,很容易通过前后数据算出来,因此在时间或空间维度上mask掉连续的段,以此来学习趋势模式。Mask策略是,将输入数据分成若干patch,并将一定比例的patch全部清0。

图:Masking算法

(2)时空编码器:和分支一中的一样。只是,被mask掉的数据不参与时空瓶颈注意力的计算。
(3)时空重构编码器:输入(2)提供的残缺潜在表征,以及指示mask位置的token向量。由若干时空瓶颈组成,并将重构后的表征与分支一的完整表征匹配。

实验

数据集
分别是Seattle Loop,PEMS04, PEMS08
这个Seattle Loop我还是第一次见。

Loop Seattle 数据集由部署在西雅图地区高速公路(I-5、I-405、I-90和SR-520)上的感应环路探测器收集,包含来自323个传感器站的交通状态数据。

图:数据集信息

超参数表
\(L\):ST Encoder中STBA的数量
\(L'\):STF Decoder中STBA的数量
\(d\):多头注意力机制的维数
\(h\):多头注意力机制的头数
\(l_m\):Masking过程的patch length
\(\alpha_m\):Masking过程的mask率
\(\lambda\):预测损失和对齐损失的权重。越大代表对齐损失占比越大。
时间和空间参考点的数量都是3。

图:超参数设置

对比试验

实际上作者选的这些基线模型都比较老,说服力比较差。我在本文的最后放上了自己做的一点对比实验,可以当作参考。

图:PEMS对比试验
图:SeattleLoop对比实验

随着时间跨度增加,SSTBAN的优势也在增加。

图:预测表现与预测长度的关系

鲁棒测试

作者还进行了以下两个实验。可以看到,模型在这两个方面还是有优点的。遗憾就是,对比的模型只有两个,且比较古老,依然是缺乏说服力。
注:GMAN AAAI2020,DMSTGCN KDD2021

图:减少训练数据
图:随机添加噪声

消融实验

将STBA与普通注意力网络进行对比。也觉是说,STBA在减小时间复杂度的同时,还能增加准确率。

图:消融实验

算力消耗

可以看出,时间消耗和空间小号还是比较小的。不过在实验中,模型的时间和空间占用与batch size等参数设置有关,所以这个只能做参考吧。在我用PEMS08做复现,预测48步时,设置batch size=8,显存占用20G左右。

图:算力实验

复现

以下是我的复现结果。以下实验每个仅做了一次,并没有重复实验,所以仅作参考。

batch size对时间的影响较大。越大,时间越短,但相应的占用显存越多。在SSTBAN用PEMS08预测48步的实验中,设置batch size=16时,40G显存的A100就已经跑不动了,可以看出SSTBAN的空间占用还是比较大的。SSTBAN的特点是,训练一个epoch耗时比较长,但是epoch数量少,就触发早停了。

注:TrendGCN属于CIKM2023

模型 数据集 步数 epoch MAE MAPE RMSE 时间
TrendGCN 08 12,12 120(batch64) 15.11 9.68 24.25 0h41
TrendGCN 08 24,24 120(batch128) 16.84 10.77 27.14 0h40
TrendGCN 08 36,36 120(batch64) 17.70 11.95 28.63 1h30
TrendGCN 08 48,48 120(batch128) 18.86 12.91 29.97 1h20
模型 数据集 步数 epoch MAE MAPE RMSE 时间
SSTBAN 08 12,12 71(batch32) 15.36 10.79 24.26 0h50
SSTBAN 08 24,24 16(batch32) 15.40 10.68 26.20 1h30
SSTBAN 08 36,36 18(batch4) 16.56 11.65 29.33 3h30
SSTBAN 08 48,48 15(batch8) 17.29 15.10 29.04 2h40
SSTBAN 04 36,36 15(batch8) 21.09 15.29 37.42 2h
posted @ 2024-03-23 10:33  white514  阅读(281)  评论(0编辑  收藏  举报