-
二进制文件
-
包含多个tf.train.Example
-
Example是protocol buffer数据标准实现,包含一系列tf.train.feature属性
-
feature是key(string)-value(bytes_list || float_list || int64_list)键值对
-
合理存储,二进制编码,加快数据读取和预处理速度
-
数据转为tfrecord文件
writer = tf.python_io.TFRecordWriter(out_file_name) # 1. 定义 writer对象 for data in dataes: context = dataes[0] question = dataes[1] answer = dataes[2] """ 2. 定义features """ example = tf.train.Example( features = tf.train.Features( feature = { 'context': tf.train.Feature( int64_list=tf.train.Int64List(value=context)), 'question': tf.train.Feature( int64_list=tf.train.Int64List(value=question)), 'answer': tf.train.Feature( int64_list=tf.train.Int64List(value=answer)) })) """ 3. 序列化,写入""" serialized = example.SerializeToString() writer.write(serialized)
-
tfrecord文件读取
从tfrecord文件创建TFRecordDataset
dataset = tf.data.TFRecordDataset('xxx.tfrecord')
解析tfrecord文件的每条记录,即序列化后的tf.train.Example;使用tf.parse_single_example来解析:
feats = tf.parse_single_example(serial_exmp, features=data_dict)
其中,data_dict是一个dict,包含的key是写入tfrecord文件时用的key,相应的value则是tf.FixedLenFeature([], tf.string)、tf.FixedLenFeature([], tf.int64)、tf.FixedLenFeature([], tf.float32),分别对应不同的数据类型,汇总即有:
def parse_exmp(serial_exmp): #label中[10]是因为一个label是一个有10个元素的列表,shape中的[x]为shape的长度 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.FixedLenFeature([], tf.string),'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([x], tf.int64)}) image = tf.decode_raw(feats['feature'], tf.float32) label = feats['label'] shape = tf.cast(feats['shape'], tf.int32) return image, label, shape
解析tfrecord文件中的所有记录,使用dataset的map方法,如下:
dataset = dataset.map(parse_exmp)
map方法可以接受任意函数以对dataset中的数据进行处理;另外,可使用repeat、shuffle、batch方法对dataset进行重复、混洗、分批;用repeat复制dataset以进行多个epoch;如下:
dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
解析完数据后,便可以取出数据进行使用,通过创建iterator来进行,如下:
iterator = dataset.make_one_shot_iterator()
batch_image, batch_label, batch_shape = iterator.get_next()
要把不同dataset的数据feed进行模型,则需要先创建iterator handle,即iterator placeholder,如下:
handle = tf.placeholder(tf.string, shape=[]) iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types, dataset_train.output_shapes) image, label, shape = iterator.get_next()
然后为各个dataset创建handle,以feed_dict传入placeholder,如下:
with tf.Session() as sess: handle_train, handle_val, handle_test = sess.run([x.string_handle() for x in [iter_train, iter_val, iter_test]]) sess.run([loss, train_op], feed_dict={handle: handle_train}
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】