TF-GNN踩坑记录(三)
引言
在Tensorflow-GNN中使用batch size除了需要注意上面的链接问题之外,最近我在调试的发现,使用了merge_batch_to_components() 之后,使用TF-GNN的Readout模块,它会默认merge之后的graph为一张图读出所有节点的数据组成一个矩阵,而不区分batch中的每一张子图,故会导致数据的结构被修改,导致模型的表现与预期的差距较大。
解决方案
out = tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "sum", node_set_name="node")(graph)
使用Pool替代Readout,该代码的具体的作用是从merge之后的图中,读出每一张子图(component)上 的节点数据,并对每个子图的节点数据进行pooling,如这里使用加法做为pooling的方式,并把这些子图pooling之后的数据拼接成一个矩阵存储在context中。这个矩阵的数据和原始输入的graph是一一对应的,如输入的batch size是32,这个矩阵的行数也为32行,每一行对应一张graph。
本文来自博客园,作者:LoveFishO,转载请注明原文链接:https://www.cnblogs.com/lovefisho/p/16627062.html