学习笔记:STAEformer
Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting
时空自适应嵌入让Transforer成为交通预测目前的最优模型
会议:CIKM2023
作者:Hangchen Liu, Zheng Dong, Renhe Jiang, Jiewen Deng, Jinliang Deng, Quanjun Chen, Xuan Song
论文地址:https://arxiv.org/abs/2308.10425v5
代码地址:https://github.com/XDZhelheim/STAEformer.git
复现
GPU: A100
本代码自带PEMS数据集,经过验证,四个数据集流量数据没有问题。
指标还不错,就是速度太慢了
数据集 | 节点数量 | MAE | MAPE | RMSE | epoch | 时间消耗(h,min) |
---|---|---|---|---|---|---|
PEMS03 | 358 | 15.69 | 15.47 | 28.62 | 58 | 3h30 |
PEMS04 | 307 | 18.23 | 12.04 | 30.30 | 51 | 1h10 |
PEMS07 | 883 | 18.81 | 8.22 | 32.23 | 72 | 15h30 |
PEMS08 | 170 | 13.58 | 8.89 | 23.43 | 88 | 1h30 |
模型
模型图 |
嵌入\(Z\)
隐藏时空表征\(Z=E_f||E_p||E_a\in\mathbb R^{T\times N\times d_h}\),其中\(d_h=3d_f+d_a\)
具体构成如下
特征嵌入
即\(E_f\in\mathbb R^{T\times N\times d_F}\)
作者认为,全连接能够保留原本的信息:
周期嵌入
即\(E_p\in\mathbb R^{T\times N\times 2d_f}\)
有两个周期,一个是星期数(day-of-week)\(T_w\in\mathbb R^{N_w\times d_f}\),其中\(N_w=7\);
另一个是一天中的时间戳数(timestamp-of-day)\(T_d\in\mathbb R^{N_d\times d_f}\),其中\(N_d=288\)。
对应的,还有每一个时间戳相应的标记\(W^t\in\mathbb R^T\)和\(D^t\in\mathbb R^T\)
将二者拼合(究竟拼合什么),得到\(E_p\)
时空自适应嵌入
即\(E_a\in\mathbb R^{T\times N\times d_a}\)
作者认为时间关系受周期和时间序列顺序影响,且不同传感器的时间序列具有不同的时间模式,因此没有使用预定义的邻接矩阵,而是设计了一个自适应嵌入\(E_a\)
Transformer与预测部分
原文用"vanilla transformer",vanilla有“原生”的意思,因此指的是“对网络结构没有很大调整的transformer模型”
模型在时间和空间上分别应用transformer。
在时间transformer上,
计算自注意力分数:
最后得到输出\(Z^{(te)}\in\mathbb R^{T\times N\times d_h}\):
空间transformer同理,得到输出\(Z^{(sp)}\in\mathbb R^{T\times N\times d_h}\):
上述讲述transformer的过程省略了的归一化、残差连接、多头注意力机制等内容。最终对于输出\(Z'\in\mathbb R^{T\times N\times d_h}\),经过全连接层得到预测结果:
实验结果
对比试验
对比实验 |
- 本模型性能对超参数不敏感[1]
- \(d_f=24,\ d_a=80,\ L=3,\ heads=4,\ lr=0.001,\ batch\_size=16\)
消融实验
消融实验 |
- 实验表明自适应图很重要
Case Study
对比空间嵌入
用PDFormer中的空间嵌入替换时空自适应嵌入\(E_a\),并进行试验:
通过将输入沿时间轴进行变换,我们的时空自适应嵌入比空间嵌入表现出更大的性能下降,这表明它具有捕获时间信息的能力
好奇怪的验证方式
时空自适应嵌入可视化
可视化 |
左边使用t-SNE对空间维度进行可视化,结果表明不同的点形成簇。
t-SNE是什么? 一种降维方法
右边对范围为12的时间戳进行了相关系数计算,结果表明每一帧与附近的帧高度相关,随着时间的远离,相关性降低。这说明正确地模拟了时间序列相关性
总结
- 周期的用法我也早有思考
- 可视化可以学习一下
section4.1 "is not sensitive to the hyper-parameters" ↩︎