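Reading the HDF5 and pkl dataset files of the Graph WaveNet model
Introduction: Graph WaveNet ships its traffic data as an .h5 file and its road-network structure as a .pkl file. How do you open them?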
HDF5
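HDF5 files usually have the extension .h5 or .hdf5. They contain two kinds of objects: groups (which work like folders) and datasets (which hold the actual data).
In Python, an .h5 file can be opened with either h5py or pandas.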
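To make the group/dataset structure concrete, here is a minimal h5py sketch. It writes a tiny throwaway file (a hypothetical toy.h5 with one group holding two datasets, not the real METR-LA data) and then walks every object in it with Group.visititems. The helper name list_hdf5_tree is made up for illustration.

```python
import os
import tempfile

import h5py
import numpy as np


def list_hdf5_tree(path):
    """Return (name, kind, shape) for every group and dataset below the root."""
    entries = []

    def visit(name, obj):
        # visititems calls this once per object, with its path relative to the root
        if isinstance(obj, h5py.Group):
            entries.append((name, "group", None))
        elif isinstance(obj, h5py.Dataset):
            entries.append((name, "dataset", obj.shape))

    with h5py.File(path, "r") as f:
        f.visititems(visit)
    return entries


# Build a toy file: one group ("folder") containing two datasets ("data")
path = os.path.join(tempfile.mkdtemp(), "toy.h5")
with h5py.File(path, "w") as f:
    g = f.create_group("df")
    g.create_dataset("values", data=np.zeros((4, 3)))
    g.create_dataset("columns", data=np.array([b"a", b"b", b"c"]))

for entry in list_hdf5_tree(path):
    print(entry)
```

Pointing list_hdf5_tree at metr-la.h5 instead should list the df group and the datasets inside it, without having to loop level by level.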
h5py
import h5py

path = 'metr-la.h5'
f = h5py.File(path, 'r')

# See what f (the root "folder") contains:
for k1 in f.keys():
    d = f[k1]
    print(d.name, ":", type(d))
# Output: /df : <class 'h5py._hl.group.Group'>
# So f contains exactly one group, /df

g = f['df']
for k2 in g:
    d = g[k2]
    print(d.name, ":", type(d))
"""
Output:
/df/axis0 : <class 'h5py._hl.dataset.Dataset'>
/df/axis1 : <class 'h5py._hl.dataset.Dataset'>
/df/block0_items : <class 'h5py._hl.dataset.Dataset'>
/df/block0_values : <class 'h5py._hl.dataset.Dataset'>
"""
# In other words, /df holds four datasets

# Slicing a dataset with [:] reads it into memory as a NumPy array
for k2 in g:
    d = g[k2][:]
    print(k2, ":", type(d), d.shape)
"""
Output:
axis0 : <class 'numpy.ndarray'> (207,)
axis1 : <class 'numpy.ndarray'> (34272,)
block0_items : <class 'numpy.ndarray'> (207,)
block0_values : <class 'numpy.ndarray'> (34272, 207)
"""
# They are really four numpy.ndarray objects, and this is how to extract them

f.close()
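These four arrays are the pieces of a pandas DataFrame saved in pandas' "fixed" HDF5 layout, so they can be put back together by hand. The sketch below does that under stated assumptions that match the shapes printed above but should be checked against your own file: there is a single block (all columns share one dtype), block0_items labels the columns of block0_values, and axis1 is a DatetimeIndex stored as int64 nanoseconds since the epoch. frame_from_fixed_group is a hypothetical helper, and the demo uses a plain dict of toy arrays in place of the real /df group.

```python
import numpy as np
import pandas as pd


def frame_from_fixed_group(g):
    """Rebuild a DataFrame from the axis0/axis1/block0_* arrays of one group.

    Hypothetical helper. Assumes a single block, that block0_items labels the
    columns of block0_values, and that axis1 holds int64 nanosecond timestamps.
    """
    axis0 = g["axis0"][:]
    items = g["block0_items"][:]
    if not np.array_equal(axis0, items):
        raise ValueError("block0_items differs from axis0: more than one block?")
    # Labels stored as fixed-length bytes come back as bytes; decode them to str
    columns = [c.decode() if isinstance(c, bytes) else c for c in items]
    # pd.to_datetime reads integers as nanoseconds since 1970-01-01 by default
    index = pd.to_datetime(g["axis1"][:])
    return pd.DataFrame(g["block0_values"][:], index=index, columns=columns)


# Toy stand-in for f['df']: 2 time steps x 2 sensors
g = {
    "axis0": np.array([b"773869", b"767541"]),
    "axis1": np.array([0, 300_000_000_000], dtype=np.int64),
    "block0_items": np.array([b"773869", b"767541"]),
    "block0_values": np.array([[1.0, 2.0], [3.0, 4.0]]),
}
df = frame_from_fixed_group(g)
print(df)
```

With the real file, call frame_from_fixed_group(f['df']) while f is still open; if the assumptions hold, the result should match what pd.read_hdf returns in the next section.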
pandas
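import pandas as pd

# pd.read_hdf requires the PyTables package (pip install tables)
path = 'metr-la.h5'
df = pd.read_hdf(path)  # a DataFrame
data = df.values        # the underlying NumPy array
print(data.shape)
# Output: (34272, 207)
# data is the traffic data matrix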
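Calling .values throws away the DataFrame's row index and column labels. Judging by the shapes printed in the h5py section, the rows are the 34272 time steps and the columns the 207 sensors, so it can be worth inspecting the DataFrame before converting it. A small sketch, assuming the index is a DatetimeIndex; describe_traffic_frame is a hypothetical helper and the demo frame is toy data:

```python
import numpy as np
import pandas as pd


def describe_traffic_frame(df):
    """Summarise a (time step x sensor) DataFrame such as the one pd.read_hdf returns."""
    # Differences between consecutive timestamps; the most common one is the sampling interval
    steps = df.index.to_series().diff().dropna()
    interval = None if steps.empty else steps.mode().iloc[0]
    return {
        "n_steps": df.shape[0],
        "n_sensors": df.shape[1],
        "interval": interval,
        "start": df.index[0],
        "end": df.index[-1],
    }


# Toy frame: 4 time steps, 5 minutes apart, for 3 sensors
idx = pd.date_range("2012-03-01", periods=4, freq="5min")
toy = pd.DataFrame(np.zeros((4, 3)), index=idx, columns=["a", "b", "c"])
info = describe_traffic_frame(toy)
print(info)
```

On the real data this would be describe_traffic_frame(pd.read_hdf('metr-la.h5')).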
pickle
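First, the pkl file that comes with the model is broken and has to be converted before it can be opened. The conversion script reads the file as raw bytes and rewrites every line ending as a bare \n, which drops the \r of Windows-style \r\n endings: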
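original = "adj_mx.pkl"
destination = "new_adj_mx.pkl"

# Read the raw bytes of the original file
with open(original, 'rb') as infile:
    content = infile.read()

# Write every line back out terminated by a bare b'\n'
# (bytes.splitlines() splits on b'\r\n', b'\n' and b'\r', so the \r bytes disappear)
outsize = 0
with open(destination, 'wb') as output:
    for line in content.splitlines():
        outsize += len(line) + 1
        output.write(line + b'\n')

# Report how many bytes the conversion removed
print("Done. Saved %s bytes." % (len(content) - outsize))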
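Why does rewriting line endings help? Text (protocol 0) pickles are line-oriented: opcodes such as S read their argument up to the next \n, so a stray \r left over from a \r\n conversion ends up inside the argument and loading fails with an UnpicklingError. The sketch below reproduces that on a tiny hand-written pickle and repairs it in memory. Two differences from the script above are choices made for this sketch: it replaces only \r\n pairs (splitlines() would also turn a lone \r into a line break), and it converts only after a plain load has failed. Neither approach can repair a binary pickle, where 0x0D and 0x0A bytes may be ordinary data. crlf_to_lf and load_pickle_bytes are hypothetical helper names.

```python
import pickle


def crlf_to_lf(data: bytes) -> bytes:
    """Turn Windows CRLF line endings into LF; lone CR bytes are left alone."""
    return data.replace(b"\r\n", b"\n")


def load_pickle_bytes(data: bytes):
    """Unpickle data; if that fails, retry once after CRLF -> LF conversion."""
    try:
        return pickle.loads(data)
    except pickle.UnpicklingError:
        return pickle.loads(crlf_to_lf(data))


# Hand-written protocol-0 pickle of the string 'abc': S opcode, memo put, STOP
good = b"S'abc'\np0\n."
# The same bytes after a Windows-style \n -> \r\n conversion
damaged = good.replace(b"\n", b"\r\n")

print(load_pickle_bytes(damaged))
```

Loading damaged directly fails with "the STRING opcode argument must be quoted", because the quoted argument now ends in \r instead of a quote.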
Then the contents can be printed and saved as a .txt file like this:
import sys
import pickle
import numpy as np

# Print whole arrays instead of abbreviating them with "..."
np.set_printoptions(threshold=sys.maxsize)

path = 'new_adj_mx.pkl'
with open(path, 'rb') as file:
    # Read the pkl contents; 'iso-8859-1' (latin1) lets Python 3 load pickles written by Python 2
    inf = pickle.load(file, encoding='iso-8859-1')
print(inf)

# Save the text form to a .txt file
obj_path = 'new_adj_mx.txt'
with open(obj_path, 'w') as ft:
    ft.write(str(inf))
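str(inf) produces one long Python-literal string. If the goal is just a readable copy of the adjacency matrix, np.savetxt writes one matrix row per line. A sketch, assuming the matrix is the third element of the pickle (as the model code below unpacks it); save_adjacency_txt and the file names are hypothetical, and the demo uses a toy 3x3 matrix:

```python
import os
import tempfile

import numpy as np


def save_adjacency_txt(adj_mx, path, fmt="%.6f"):
    """Write a square adjacency matrix to a text file, one comma-separated row per line."""
    adj = np.asarray(adj_mx)
    if adj.ndim != 2 or adj.shape[0] != adj.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {adj.shape}")
    np.savetxt(path, adj, fmt=fmt, delimiter=",")


# Toy 3x3 matrix standing in for the adjacency matrix from the pickle
adj = np.array([[1.0, 0.5, 0.0], [0.5, 1.0, 0.25], [0.0, 0.25, 1.0]])
out_path = os.path.join(tempfile.mkdtemp(), "adj_mx.csv")
save_adjacency_txt(adj, out_path)

with open(out_path) as fh:
    print(fh.readline().strip())
```

With the real data: sensor_ids, sensor_id_to_ind, adj_mx = inf, then save_adjacency_txt(adj_mx, 'new_adj_mx.csv'). The file can be read back with np.loadtxt(..., delimiter=',').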
In the model's own code, the file is loaded like this:
import pickle

pickle_file = 'data/sensor_graph/adj_mx.pkl'
try:
    with open(pickle_file, 'rb') as f:
        pickle_data = pickle.load(f)
except UnicodeDecodeError:
    # Retry with latin1, which handles pickles written by Python 2
    with open(pickle_file, 'rb') as f:
        pickle_data = pickle.load(f, encoding='latin1')
except Exception as e:
    print('Unable to load data ', pickle_file, ':', e)
    raise
# The pickle holds three objects, unpacked in this order
sensor_ids, sensor_id_to_ind, adj_mx = pickle_data
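As the unpacking shows, the pickle file holds several objects: the sensor IDs, the mapping from sensor ID to index, and the adjacency matrix. Note that running this code on the original file raises an error, presumably because of the pkl problem described above.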
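Once the three objects are unpacked, a quick consistency check catches a damaged or mismatched file early. The sketch below assumes (this is not confirmed by the model code) that adj_mx is a square N×N array with one row per sensor and that sensor_id_to_ind maps each ID to its position in sensor_ids. check_sensor_graph is a hypothetical helper that also counts the non-zero off-diagonal entries as directed edges; the demo uses toy data.

```python
import numpy as np


def check_sensor_graph(sensor_ids, sensor_id_to_ind, adj_mx):
    """Sanity-check the unpacked pickle contents and return the number of directed edges.

    Assumes adj_mx is N x N with N == len(sensor_ids) and that
    sensor_id_to_ind[sensor_ids[i]] == i for every i.
    """
    adj = np.asarray(adj_mx)
    n = len(sensor_ids)
    if adj.shape != (n, n):
        raise ValueError(f"adj_mx has shape {adj.shape}, expected ({n}, {n})")
    for i, sid in enumerate(sensor_ids):
        if sensor_id_to_ind[sid] != i:
            raise ValueError(f"sensor {sid!r} maps to {sensor_id_to_ind[sid]}, expected {i}")
    # Ignore self-loops on the diagonal, count every other non-zero weight as an edge
    off_diag = adj.copy()
    np.fill_diagonal(off_diag, 0)
    return int(np.count_nonzero(off_diag))


# Toy stand-ins for sensor_ids, sensor_id_to_ind, adj_mx
ids = ["a", "b", "c"]
mapping = {sid: i for i, sid in enumerate(ids)}
adj = np.array([[1.0, 0.5, 0.0], [0.5, 1.0, 0.25], [0.0, 0.25, 1.0]])
print(check_sensor_graph(ids, mapping, adj))
```

After the model code above succeeds, check_sensor_graph(sensor_ids, sensor_id_to_ind, adj_mx) either returns the edge count or says which assumption the file breaks.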