【论文阅读】DSTAGNN Dynamic Spatial-Temporal Aware Graph Neural Network for Traffic Flow Forecasting
DSTAGNN Dynamic Spatial-Temporal Aware Graph Neural Network for Traffic Flow Forecasting
Info
- title: DSTAGNN: Dynamic Spatial-Temporal Aware Graph Neural Network for Traffic Flow Forecasting
- publish: ICML 2022
- url: dblp: DSTAGNN: Dynamic Spatial-Temporal Aware Graph Neural Network for Traffic Flow Forecasting. (uni-trier.de)
- corresponding author: ShiYong Lan
- first author:
- ShiYong Lan
- Yitong Ma
Model
Network Structure
Details
0. 前置知识
节点:\(V\)
节点数量:\(N\)
节点间连接(边):\(E\)
图:\(\mathcal{G}=(V,E)\)
邻接矩阵:\(A\in\mathbb{R}^{N\times N}\)
t时刻的图信号:\(X^t\in\mathbb{R}^{N\times C_p}\),其中\(C_p\)表示交通参数的类型(如交通流量、速度),本论文只有一个交通参数,即\(C_p=1\)
数据记录:\(X^{(t-M+1):t}\in\mathbb{R}^{N\times C_p\times M}\)
预测目标:\(X^{(t+1):(t+T)} \in \mathbb{R}^{N \times C_{p} \times T}\)
模型\(\mathcal{F}\):\(X^{(t+1):(t+T)}=\mathcal{F}\left[X^{(t-M+1): t} ; \mathcal{G}\right]\)
1. 节点特征的表示与转换
Take the traffic flow \(X^f\in R^{D\times d_t\times N}\) at the N recording points for D days as an example, where dt is the number of recording times per day (if recordings are taken once every 5 minutes, then \(d_t\) = 288).
For each recording point, the oneday traffic data is treated as a vector, then a set of multi-day traffic data is denoted as a vector sequence. For example, the vector sequence obtained at recording point n (\(n\in N\)) is denoted as \(X^f_n = (w_{n1}, w_{n2}, \cdots, w_{nD})\), \(w_{nd}\in\mathbb{R}^{d_t}\), where \(d\in[1, D]\).
提取交通流量信息:
In this way, the vector sequence of the recording point n is transformed into a probability distribution \(P_n\{X_d = m_{nd}\}\), and each day has a probability mass \(m_{nd}\in[0, 1]\) and \(\sum_d m_{nd}=1\), which denotes the proportion of traffic volume in a certain day over a period of time.
P.S.
- 为什么流量信息是一天的流量占一段时间的比?
- 这里为什么用二范数,而不是直接对一天各个时段的流量进行求和?
2. 时空感知距离与时空关联图
概率分布变换代价(使用Wassertein距离):
利用余弦距离作为变换代价:
使用Wassertein距离计算两个节点的差异:
节点间的关联度矩阵\(\boldsymbol{A}_{S T A D} \in \mathbb{R}^{N \times N}\):
通过设定稀疏水平\(P_{sp}\)(超参数),选取每行 \(N_r=N\times P_{sp}\) 个节点(关联度最大的几个),其余的设为0,就获得了时空关联图 \(\boldsymbol{A}_{S T R G} \in \mathbb{R}^{N \times N}\)。
对时空关联图 \(\boldsymbol{A}_{STRG}\) 进行二值化,获得时空感知图 \(\boldsymbol{A}_{STAG}\)。
P.S.
- 为什么用余弦变换作为代价?
- 为什么用Wassertein距离计算节点的差异(关联)?
- 为什么关联图中只选了 \(P_{sp}\) 个节点?
3. 时空注意力块
这里用到了多头注意力机制,参考Attention Is All You Need。
时间注意力
- 这里的 \(\mathcal{X}^{(l)}\in\mathbb{R}^{N\times c^{(l-1)}\times M}\) 是输入,经过reshape,变成 \(\mathcal{X}'^{(l)}\in\mathbb{R}^{c^{(l-1)}\times M\times N}\)。
- \(\mathcal{X}^{\prime(l)} \boldsymbol{W}_{q}^{(l)} \triangleq Q^{(l)}, \quad \boldsymbol{X}^{(l)} \boldsymbol{W}_{k}^{(l)} \triangleq K^{(l)}, \quad \boldsymbol{X}^{(l)} \boldsymbol{W}_{v}^{(l)} \triangleq V^{(l)}\) Q和K用来计算不同节点之间的相关性(注意力系数),\(W_{q,k,v}^{(l)}\in \mathbb{R}^{N\times d}\),\(Q^{(l)}, K^{(l)}, V^{(l)} \in \mathbb{R}^{c^{(l-1)} \times M \times d}\)。
- 注意力*:
- Scaled Dot-Product(论文Attention Is All You Need里来的。)\(A^{(l)}=\frac{Q^{(l)} K^{(l)^{\top}}}{\sqrt{d_{h}}}+A^{(l-1)}\)
- 注意力:\(\operatorname{Att}\left(Q^{(l)}, K^{(l)}, V^{(l)}\right)=\operatorname{Softmax}\left(A^{(l)}\right) V^{(l)}\)
- 多头注意力机制(论文Attention Is All You Need里来的。)\(O^{(h)}=\operatorname{Att}\left(Q \boldsymbol{W}_{q}^{(h)}, K \boldsymbol{W}_{k}^{(h)}, V \boldsymbol{W}_{v}^{(h)}\right)\) 其中 \(\boldsymbol{W}_{q, k, v}^{(h)} \in \mathbb{R}^{d \times d_{h}}\left(d_{h}=d / H\right)\)。
- Concat:\(O=\left[O^{(1)}, O^{(2)}, \ldots, O^{(H)}\right]\) 经过Reshape,获得 \(\boldsymbol{O} \in \mathbb{R}^{c^{(l-1)} \times M \times H \times d_{h}}\)。
- Linear层把输出的大小还原,获得 \(O^{\prime} \in \mathbb{R}^{c^{(l-1)} \times M \times N}\)。
- 带残差的LayerNorm:\(Y=\operatorname{LayerNorm}\left(\operatorname{Linear}\left(\operatorname{Reshape}(O)\right)+X^{\prime}\right)\) 输出 \(Y \in \mathbb{R}^{c^{(l-1)} \times M \times N}\)。
P.S. - 原文中的LayerNorm公式错了。
- \(d\) 和 \(d_h\)是如何决定的?
- Scaled Dot-product 为什么用残差?
空间注意力
- Reshape:\(Y \in \mathbb{R}^{c^{(l-1)} \times M \times N}\) 转换成 \(Y^{\#} \in \mathbb{R}^{c^{(l-1)} \times N \times M}\)。
- Conv:
- 先把 \(Y^{\#} \in \mathbb{R}^{c^{(l-1)} \times N \times M}\) 中的时间维度 \(M\) 映射到 \(d_E\) 维的高维空间。(怎么映射?为什么要映射?)
- 对特征维 \(c^{(l-1)}\) 做一维卷积,获得二维矩阵\(Y'\in \mathbb{R}^{N\times d_E}\)。(卷积核大小是 \(c^{(l-1)}\)?)
- Embedding:
Then, we add positional information to \(Y'\) through an embedding layer to get \(Y_E\).
- 注意力机制(这里的空间注意力其实只算了注意力中的相关性,在下一步卷积中使用):
Instead of using the self-attention fully generated from \(Y_E\) as in conventional transformers, here we introduce the temporal-spatial relevance graph \(A_{STRG}\) with learned correlation between nodes to amend the attention in the SA module. Thus the improved spatial attention with H heads is denoted as: \(\begin{array}{c} \boldsymbol{P}^{(h)}=\operatorname{Softmax}\left(\frac{\left(\boldsymbol{Y}_{E} \boldsymbol{W}_{k}^{\prime(h)}\right)^{\top}\left(\boldsymbol{Y}_{E} \boldsymbol{W}_{q}^{\prime(h)}\right)}{\sqrt{d_{h}}}+\boldsymbol{W}_{m}^{(h)} \odot \boldsymbol{A}_{S T R G}\right) \\ \mathcal{P}=\left[\boldsymbol{P}^{(1)}, \boldsymbol{P}^{(2)}, \ldots, \boldsymbol{P}^{(H)}\right] \end{array}\) where \(\boldsymbol{W}_{k}^{\prime(h)}, \boldsymbol{W}_{q}^{\prime(h)} \in \mathbb{R}^{d_{E} \times d_{h}}, \boldsymbol{W}_{m}^{(h)} \in \mathbb{R}^{N \times N}\) are learnable parameters, \(\odot\) is the element-wise Hadamard product, \(\boldsymbol{W}_m^{(h)}\) is used to amend \(\boldsymbol{A}_{STRG}\) for adjusting the attention of each head \(\boldsymbol{P}^{(h)}\in\mathbb{R}^{N\times N}\), and the output \(\mathcal{P}\in\mathbb{R}^{H\times N\times N}\) denotes the dynamic spatial-temporal attention tensor by combining the outputs from each head.
P.S.
- Conv中的高维映射怎么映射?为什么要映射?
- Conv中的一维卷积怎么把整个 \(c^{(l-1)}\) 卷没了?
- 这里为什么不用残差了,而是用 \(A_{STRG}\)?
4. Spatial-Temporal Convolution Block
空间图卷积
这里使用了ChebyNet卷积,待深入。
图信号:\(x = \boldsymbol{x_t}\in\mathbb{R}^N\)
归一化的拉普拉斯矩阵:\(\tilde{\boldsymbol{L}}=\frac{2}{\lambda_{\max }}\left(\boldsymbol{D}-\boldsymbol{A}^*\right)-\boldsymbol{I}_N\)
邻接矩阵:\(\boldsymbol{A}^*=\boldsymbol{A}_{S T A G}\)
参数:\(\boldsymbol{\theta} \in \mathbb{R}^K\)
第k个头的空间-时间注意力矩阵:\(\boldsymbol{P}^{(k)} \in \mathbb{R}^{N \times N}\)
P.S.
- 为什么把P直接哈达玛乘在切比雪夫多项式上?
时间门控卷积(Temporal gated convolution)
论文提出了M-GTU (Multi-scale Gated Tanh Unit) 卷积模块捕捉交通流数据中的动态时间信息。
M-GTU由三个GTU模块组成,每个GTU由Convolution和Gating两步组成。
- 卷积:
- 输入:\(\mathcal{Z}^{(l)}\in\mathbb{R}^{N\times M\times C^{(l)}}\)。
- 卷积核:\(\Gamma\in\mathbb{R}^{1 \times S \times c^{(l) \times 2c^{(l)}}}\),其中,\(c^{(l)}\) 是输入的通道数(节点的特征数),\(2c^{(l)}\) 是输出的通道数,卷积核的大小是 \(1\times S\),1是节点个数维度上的大小,S是时间维度上的大小。
- 卷积公式:\(\mathcal{Z}'^{(l)}=\Gamma*\mathcal{Z}^{(l)}\in\mathbb{R}^{N \times(M-(S-1))\times 2C^{(l)}}\)
- Gating:把 \(\mathcal{Z}'^{(l)}\) 沿着通道维度等分成两份 \(E\) 和 \(F\),则输出为 \(\phi(E)\odot\sigma(F)\in\mathbb{R}^{N\times{(M-(S-1))\times2C^{(l)}}}\)
- Concat:虽然图上有Pooling,但是代码中并没有,使用三个不同的卷积核 \(S_1,S_2,S_3\) 卷积后,将三个输出沿着时间轴拼接。
- Linear:将Concat的结果放进Linear层,恢复原来的大小。
self.fcmy = nn.Sequential(
nn.Linear(3 * num_of_timesteps - 12, num_of_timesteps),
nn.Dropout(0.05),
)
P.S.
- 为什么要用GTU?
Datasets
文章使用了PEMS03、PEMS04、PEMS07、PEMS08三个数据集。