采用tfrecord形式读写训练数据
tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好地利用内存,在tensorflow中可以快速地复制、移动、读取、存储等。尤其在面对海量数据时,使用常用的内存读取方式变得不切实际,tfrecord方式为我们带来了更大的便捷,同时还可以配合shuffle大大提高model的train效率。
# Example code
def convert_tfrecord(data, label):
    """Write the samples to a TFRecord file, then read them back in batches.

    The function first serializes every (sample, score, label) triple into
    ``./resources/train.tfrecord`` and afterwards builds a TF1 queue-based
    input pipeline that parses the file and prints 1000 shuffled batches.

    :param data: indexable pair; ``data[0]`` is an iterable of integer
        sequences ("sample") and ``data[1]`` an iterable of float sequences
        ("score"). The parser below declares both features with a fixed
        length of 9, so each sequence must contain exactly 9 values.
    :param label: iterable of integer labels, zipped with ``data[0]`` and
        ``data[1]``.
    :return: None
    """
    record_path = './resources/train.tfrecord'

    # ---- Write: wrap every record in tf.train.Example / tf.train.Features.
    # The context manager closes the writer even when serialization fails.
    with tf.python_io.TFRecordWriter(record_path) as writer:
        triples = zip(data[0], data[1], label)
        for index, (sample_values, score_values, label_value) in enumerate(triples):
            if index % 100 == 0:
                print('write example {}'.format(index))
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'sample': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=sample_values)),
                        'score': tf.train.Feature(
                            float_list=tf.train.FloatList(value=score_values)),
                        'label': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[label_value])),
                    }
                )
            )
            writer.write(example.SerializeToString())
    print('写入ok')

    # ---- Read: parse single examples from the file and fetch them in batches.
    filename_queue = tf.train.string_input_producer([record_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'sample': tf.io.FixedLenFeature([9], tf.int64),
            'score': tf.io.FixedLenFeature([9], tf.float32),
            'label': tf.io.FixedLenFeature([1], tf.int64),
        },
    )

    batch_size = 3
    min_after_dequeue = 10
    capacity = min_after_dequeue + 3 * batch_size
    samples, scores, labels = tf.train.shuffle_batch(
        [features['sample'], features['score'], features['label']],
        batch_size=batch_size,
        num_threads=3,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
    )

    with tf.compat.v1.Session() as sess:
        # tf.initialize_all_variables() is the deprecated spelling of this op.
        sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for _ in range(1000):
                # Pull one batch out of the session. Distinct local names keep
                # the ``label`` parameter from being shadowed.
                sample_batch, score_batch, _label_batch = sess.run(
                    [samples, scores, labels])
                print(sample_batch)
                print(score_batch)
                print('###########')
        finally:
            # Always stop and join the queue-runner threads, even when a
            # sess.run call raises, so no background thread is left running.
            coord.request_stop()
            coord.join(threads)
    print('ok')
时刻记着自己要成为什么样的人!