学习笔记:CL4ST
Spatio-Temporal Meta Contrastive Learning
时空元对比学习
CIKM2023
作者:Jiabin Tang, Lianghao Xia, Jie Hu, Chao Huang
论文地址: https://dl.acm.org/doi/10.1145/3583780.3615065
代码地址: https://github.com/HKUDS/CL4ST
总结
是一篇使用了对比学习的模型,其中”可学习的视图生成“部分对边和点实现了数据增强,其中一些MLPs是通过”元网络“来计算参数的。这里或许体现了一些元学习的思想。但是元学习参数的更新和普通网络参数的更新没有进行区分,感觉也不是正宗的元学习。
从论文所提供的实验结果,结合我其它的复现结果,整体来看,CL4ST准确率与STAEformer和TPGNN相当,略差于PDFormer和TrendGCN,优于其它模型如Graph WaveNet、AGCRN等。准确率还不错。
从速度上看,单看PEMS04(论文只给了这个数据集的超参数设置),CL4ST是最慢的 7h40min,对比起来STAEFormer是1h30min,PDFormer是6h30min,TrendGCN是50min。速度是一个缺点。
复现
显卡A100,数据集 PEMS04,耗时7h40min,共145 epoch(最大200 epoch),MAE: 18.55, RMSE: 30.12, MAPE: 12.2764%, sMAPE: 19.1496%, Corr: 0.0000。此结果与论文给的差不多
模型
图1:CL4ST模型图 |
模型图如上,步骤如下:
一、点嵌入和边嵌入,元视图生成
通过图增强的常规方法mask(mean), keep, drop,结合GumbleSoftmax重参数化方法与MLP,计算嵌入f_v和f_e,和损失L_gen(KL散度)
二、注意力
对X使用多头图注意力获得空间图 Hs 。空间邻接矩阵 As由高斯核函数或邻居得到。由Hs获得时间图Ht。时间邻接矩阵形状是TxT的,全1
(有点疑惑,时间图怎么是由空间图计算得到的)
三、解码
采用位置嵌入和时间嵌入
将中间变量H、时空嵌入、数据X通过拼合和MLP,得到预测结果。
四、对比学习
应该是经典的方法吧:让同一个样本的增强数据的表征和原始数据的表征尽可能相似,而和不同样本的表征尽可能区分。以此来训练模型的表征能力。
五、优化目标
预测误差(包括原始数据和增强数据)、对比损失、元视图生成损失(包括空间和时间)之和
至于参数更新方法,论文没有说,应该可以看出,这里的元视图也不属于元学习。
表1:CL4ST对比试验结果 |
标注
这里标注一些我不理解的地方,虽然我也没打算继续钻研下去了:
- 这里端到端是什么意思呢?
our goal for learnable view generation is to design an end-to-end differentiable framework that can learn an augmented view on the graph G.
可学习视图生成的目的是设计一个能够学习增强视图的端到端可导框架。
- 为什么时间图是全1的呢?