学习笔记:ST-MetaNet
Urban Traffic Prediction from Spatio-Temporal Data Using Deep Meta Learning
使用深度元学习进行城市交通预测
期刊会议: KDD2019
作者:Zheyi Pan, Yuxuan Liang, Weifeng Wang, Yong Yu, Yu Zheng, Junbo Zhang
论文地址: https://dl.acm.org/doi/10.1145/3292500.3330884
代码地址: (mxnet) https://github.com/panzheyi/ST-MetaNet
总结
感觉这篇论文的元学习并不正宗。原因有二:
- 参数更新上,元学习经典的更新方法是,普通网络的参数和元学习的参是分开更新的,是区别对待的。
- 这篇论文的元学习实际上学习的是节点嵌入,它的目的并不是元学习的”让普通网络参数的潜力最大“。
背景
现有的方法对所有点一视同仁,采用相同的模型,无法区分内在联系,如地理位置和时空关系等,这些关系在没有前置知识时很难学习到。还有一系列研究采用多任务学习方法,对每个节点建立多个子模型,结合相似度约束,放在一起训练。但是这些方法利用的关系太弱,导致模型效果不佳。本模型从节点属性和边属性上提取元知识,用于建模时空关联,生成预测网络的权重。
模型
此论文采用Seq2Seq框架,即用输入的数据通过若干个RNN组合而成的编码器得到一个上下文变量,再用这个上下文变量经过若干个RNN组合而成的解码器得到输出。
从下面的架构图中可以看出,框架有三层,分别是RNN、Meta-GAT和Meta-RNN。其中Meta的作用就是给GAT和RNN中的可学习参数提供了初始化的值。
图1:ST-MetaNet结构图 |
ST-MetaNet中所有的元学习结构都是全连接层。元知识分为点元知识(NMK)和边元知识(EMK)。
下图以Meta-GAT为例。首先元知识学习器(Meta-knowledge Learner)通过输入的点属性(POI,GPS位置等)和边属性(连通性、距离等),利用全连接层学习到元知识MK(包括NMK和EMK)。MK包含了点的关系和点之间的关系。Meta-GAT再利用MK,通过全连接层,得到GAT的权重。
图2:Meta-GAT过程图 |
参数更新方法
ST-MetaNet包含两种参数:
- \(\omega_1\)是普通网络的参数,其梯度就是 \(\nabla_{\omega_{1}}\mathcal{L}_{\mathrm{train}}\)
- \(\omega_2\)是元学习部分的参数,其梯度为\(\nabla_{\omega_{2}}\mathcal{L}_{\mathrm{train}}=\nabla_{\theta}\mathcal{L}_{\mathrm{train}}\nabla_{\omega_{2}}\theta\)
其中\(\theta\)是元学习器生成的参数。
其实还是不太懂这两项的更新有什么区别。在代码中,没看到有什么区分。感觉这个模型的元学习部分的参数和普通网络的参数是一样的,没有区别。因此,感觉这个模型的元学习并不正宗。
元知识的含义
As shown in Figure 3 (b), two meta-knowledge learners respectively employ different FCNs, in which input is the attribute of a node or an edge, and the corresponding output is the embedding (vector representation) of that node or edge.
如图3(b)所示,两个元知识学习器分别使用不同的FCN,其中输入是一个节点或边的属性,对应的输出是该节点或边缘的嵌入(矢量表示)。
原文这句话,意味着节点嵌入是在元知识学习器的输出。论文进行了一个实验来分析节点嵌入,下图是一个节点嵌入与其最相似的十个节点嵌入的相似度值,展示了ST-MetaNet对于节点嵌入的优势。
图3:节点嵌入效果比较 |
可以总结出两点:
- 明明元知识学习只是用全连接层实现的啊,为什么这样的节点嵌入效果更好呢?
- 这张图证明了论文的元知识学习实际上学习到的是节点之间的关系,以此来指导网络的初始参数,实现节点的不同
ST-MetaNet+
Spatio-Temporal Meta Learning for Urban Traffic Prediction
期刊会议:IEEE TKDE 2020
作者:Zheyi Pan; Wentao Zhang; Yuxuan Liang; Weinan Zhang; Yong Yu; Junbo Zhang; Yu Zheng
论文地址: https://ieeexplore.ieee.org/document/9096591
代码地址:无
此论文和ST-MetaNet是同一作者,模型图如下,可以看出很相似,是上一个模型的改进。
图4:ST-MetaNet+结构图 |
以Meta-GAT+为例,主要改进之处,就是将Meta-GAT中的Meta learner变成了Context learner + Fusion Gate + Meta Learner,也就是增加了RNN对GAT权重的影响。
图5:Meta-GAT+过程图 |
两个对比试验:
表1:TAXI-BJ |
表2:PEMS-BAY |
与老熟人Graph WaveNet进行对比,发现打得有来有回。在TAXI-BJ上ST-MetaNet+更好一些,而在PEMS-BAY上,Graph WaveNet则略优。对此,论文的解释是,在PEMS-BAY上,地理信息较少,只有GPS位置和距离,因此本模型效果平庸。(看了下TAXI-BJ数据集,数据分为32x32的网格,有出入流量,和PEMS-BAY其实也差不多。一些节日、天气信息,其实是所有节点的共同属性,不影响元学习。因此这个理由不是很充分。)