学习笔记:STAEformer

Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting
时空自适应嵌入让Transforer成为交通预测目前的最优模型
会议:CIKM2023
作者:Hangchen LiuZheng DongRenhe JiangJiewen DengJinliang DengQuanjun ChenXuan Song
论文地址:https://arxiv.org/abs/2308.10425v5
代码地址:https://github.com/XDZhelheim/STAEformer.git

复现

GPU: A100
本代码自带PEMS数据集,经过验证,四个数据集流量数据没有问题。
指标还不错,就是速度太慢了

数据集 节点数量 MAE MAPE RMSE epoch 时间消耗(h,min)
PEMS03 358 15.69 15.47 28.62 58 3h30
PEMS04 307 18.23 12.04 30.30 51 1h10
PEMS07 883 18.81 8.22 32.23 72 15h30
PEMS08 170 13.58 8.89 23.43 88 1h30

模型

模型图

嵌入\(Z\)

隐藏时空表征\(Z=E_f||E_p||E_a\in\mathbb R^{T\times N\times d_h}\),其中\(d_h=3d_f+d_a\)
具体构成如下

特征嵌入

\(E_f\in\mathbb R^{T\times N\times d_F}\)

作者认为,全连接能够保留原本的信息:

\[E_f=FC(X_{t-T+1:t}) \]

周期嵌入

\(E_p\in\mathbb R^{T\times N\times 2d_f}\)

有两个周期,一个是星期数(day-of-week)\(T_w\in\mathbb R^{N_w\times d_f}\),其中\(N_w=7\)
另一个是一天中的时间戳数(timestamp-of-day)\(T_d\in\mathbb R^{N_d\times d_f}\),其中\(N_d=288\)
对应的,还有每一个时间戳相应的标记\(W^t\in\mathbb R^T\)\(D^t\in\mathbb R^T\)
将二者拼合(究竟拼合什么),得到\(E_p\)

时空自适应嵌入

\(E_a\in\mathbb R^{T\times N\times d_a}\)

作者认为时间关系受周期和时间序列顺序影响,且不同传感器的时间序列具有不同的时间模式,因此没有使用预定义的邻接矩阵,而是设计了一个自适应嵌入\(E_a\)

Transformer与预测部分

原文用"vanilla transformer",vanilla有“原生”的意思,因此指的是“对网络结构没有很大调整的transformer模型”

模型在时间和空间上分别应用transformer。

在时间transformer上,

\[Q^{(te)}=ZW_{Q}^{(te)},K^{(te)}=ZW_{K}^{(te)},V^{(te)}=ZW_{V}^{(te)} \]

计算自注意力分数:

\[A^{(te)}=Softmax\left(\frac{Q^{(te)}K^{(te)}}{\sqrt{d_{h}}}\right) \]

最后得到输出\(Z^{(te)}\in\mathbb R^{T\times N\times d_h}\)

\[Z^{(te)}=A^{(te)}V^{(te)} \]

空间transformer同理,得到输出\(Z^{(sp)}\in\mathbb R^{T\times N\times d_h}\)

\[Z^{(sp)}=SelfAttention(Z^{(te)}) \]

上述讲述transformer的过程省略了的归一化、残差连接、多头注意力机制等内容。最终对于输出\(Z'\in\mathbb R^{T\times N\times d_h}\),经过全连接层得到预测结果:

\[\hat Y=FC(Z') \]

实验结果

对比试验

对比实验
  • 本模型性能对超参数不敏感[1]
  • \(d_f=24,\ d_a=80,\ L=3,\ heads=4,\ lr=0.001,\ batch\_size=16\)

消融实验

消融实验
  • 实验表明自适应图很重要

Case Study

对比空间嵌入
用PDFormer中的空间嵌入替换时空自适应嵌入\(E_a\),并进行试验:
通过将输入沿时间轴进行变换,我们的时空自适应嵌入比空间嵌入表现出更大的性能下降,这表明它具有捕获时间信息的能力
好奇怪的验证方式

时空自适应嵌入可视化

可视化

左边使用t-SNE对空间维度进行可视化,结果表明不同的点形成簇。
t-SNE是什么? 一种降维方法

右边对范围为12的时间戳进行了相关系数计算,结果表明每一帧与附近的帧高度相关,随着时间的远离,相关性降低。这说明正确地模拟了时间序列相关性

总结

  • 周期的用法我也早有思考
  • 可视化可以学习一下

  1. section4.1 "is not sensitive to the hyper-parameters" ↩︎

posted @ 2023-10-20 10:59  white514  阅读(1110)  评论(4编辑  收藏  举报