TF-GNN踩坑记录(三)

引言

Batch size问题

 

在Tensorflow-GNN中使用batch size除了需要注意上面的链接问题之外,最近我在调试的发现,使用了merge_batch_to_components() 之后,使用TF-GNN的Readout模块,它会默认merge之后的graph为一张图读出所有节点的数据组成一个矩阵,而不区分batch中的每一张子图,故会导致数据的结构被修改,导致模型的表现与预期的差距较大。

 

解决方案

out = tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "sum", node_set_name="node")(graph)

使用Pool替代Readout,该代码的具体的作用是从merge之后的图中,读出每一张子图(component)上 的节点数据,并对每个子图的节点数据进行pooling,如这里使用加法做为pooling的方式,并把这些子图pooling之后的数据拼接成一个矩阵存储在context中。这个矩阵的数据和原始输入的graph是一一对应的,如输入的batch size是32,这个矩阵的行数也为32行,每一行对应一张graph。

 

posted @ 2022-08-26 11:30  LoveFishO  阅读(126)  评论(0编辑  收藏  举报