TFRecord文件
对于数据进行统一的管理是很有必要的.TFRecord就是对于输入数据做统一管理的格式.加上一些多线程的处理方式,使得在训练期间对于数据管理把控的效率和舒适度都好于暴力的方法.
小的任务什么方法差别不大,但是对于大的任务,使用统一格式管理的好处就非常显著了.因此,TFRecord的使用方法很有必要熟悉.
一.生成TFrecords
Ⅰ tf.python_io.TFRecordWriter 类
把记录写入到TFRecords文件的类.
__init__(path,options=None)
作用:创建一个TFRecordWriter对象,这个对象就负责写记录到指定的文件中去了.
参数:
path: TFRecords 文件路径
options: (可选) TFRecordOptions对象
close()
作用:关闭对象.
write(record)
作用:把字符串形式的记录写到文件中去.
参数:
record: 字符串,待写入的记录
Ⅱ.tf.train.Example
这个类是非常重要的,TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的.
函数:
__init__(**kwargs)
这个函数是初始化函数,会生成一个Example对象,一般我们使用的时候,是传入一个tf.train.Features对象进去.
SerializeToString()
作用:把example序列化为一个字符串,因为在写入到TFRcorde的时候,write方法的参数是字符串的.
Ⅲ.tf.train.Features
函数:
__init__(**kwargs)
作用:初始化Features对象,一般我们是传入一个字典,字典的键是一个字符串,表示名字,字典的值是一个tf.train.Feature对象.
Ⅳ.tf.train.Feature
class tf.train.Feature
属性:
bytes_list
float_list
int64_list
函数:
__init__(**kwargs)
作用:构造一个Feature对象,一般使用的时候,传入 tf.train.Int64List, tf.train.BytesList, tf.train.FloatList对象.
Ⅴ.tf.train.Int64List, tf.train.BytesList, tf.train.FloatList
使用的时候,一般传入一个具体的值,比如学习任务中的标签就可以传进value=tf.train.Int64List,而图片就可以先转为字符串的格式之后,传入value=tf.train.BytesList中.
存入TFRecords文件需要数据先存入名为example的protocol buffer,然后将其serialize成为string才能写入。example中包含features,用于描述数据类型:bytes,float,int64。
def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) train_filename = 'train.tfrecords' with tf.python_io.TFRecordWriter(train_filename) as tfrecord_writer: for i in range(len(images)): # read in image data by tf img_data = tf.gfile.FastGFile(images[i], 'rb').read() # image data type is string label = labels[i] # get width and height of image image_shape = cv2.imread(images[i]).shape width = image_shape[1] height = image_shape[0] # create features feature = {'train/image': _bytes_feature(img_data), 'train/label': _int64_feature(label), # label: integer from 0-N 'train/height': _int64_feature(height), 'train/width': _int64_feature(width)} # create example protocol buffer example = tf.train.Example(features=tf.train.Features(feature=feature)) # serialize protocol buffer to string tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.close()
img_raw = img.tobytes()#将图片转化为二进制格式
# 为图像建Example
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
# 写入tfrecord文件
num_pic += 1
writer.write(example.SerializeToString())
二 Tensorflow读写TFRecords文件
在使用slim之类的tensorflow自带框架的时候一般默认的数据格式就是TFRecords,在训练的时候使用TFRecords中数据的流程如下:
使用input pipeline读取tfrecords文件/其他支持的格式,然后随机乱序,生成文件序列,读取并解码数据,输入模型训练。
如果有一串jpg图片地址和相应的标签:images
和labels
首先用tf.train.string_input_producer
读取tfrecords文件的list建立FIFO序列,可以申明num_epoches和shuffle参数表示需要读取数据的次数以及时候将tfrecords文件读入顺序打乱,然后定义TFRecordReader读取上面的序列返回下一个record,用tf.parse_single_example
对读取到TFRecords文件进行解码,根据保存的serialize example和feature字典返回feature所对应的值。此时获得的值都是string,需要进一步解码为所需的数据类型。把图像数据的string reshape成原始图像后可以进行preprocessing操作。此外,还可以通过tf.train.batch
或者tf.train.shuffle_batch
将图像生成batch序列。
由于tf.train
函数会在graph中增加tf.train.QueueRunner
类,而这些类有一系列的enqueue选项使一个队列在一个线程里运行。为了填充队列就需要用tf.train.start_queue_runners
来为所有graph中的queue runner启动线程,而为了管理这些线程就需要一个tf.train.Coordinator
来在合适的时候终止这些线程。
import tensorflow as tf import matplotlib.pyplot as plt data_path = 'train.tfrecords' with tf.Session() as sess: # feature key and its data type for data restored in tfrecords file feature = {'train/image': tf.FixedLenFeature([], tf.string), 'train/label': tf.FixedLenFeature([], tf.int64), 'train/height': tf.FixedLenFeature([], tf.int64), 'train/width': tf.FixedLenFeature([], tf.int64)} # define a queue base on input filenames filename_queue = tf.train.string_input_producer([data_path], num_epoches=1) # define a tfrecords file reader reader = tf.TFRecordReader() # read in serialized example data _, serialized_example = reader.read(filename_queue) # decode example by feature features = tf.parse_single_example(serialized_example, features=feature) image = tf.image.decode_jpeg(features['train/image']) image = tf.image.convert_image_dtype(image, dtype=tf.float32) # convert dtype from unit8 to float32 for later resize label = tf.cast(features['train/label'], tf.int64) height = tf.cast(features['train/height'], tf.int32) width = tf.cast(features['train/width'], tf.int32) # restore image to [height, width, 3] image = tf.reshape(image, [height, width, 3]) # resize image = tf.image.resize_images(image, [224, 224]) # create bathch images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10) # capacity是队列的最大容量,
#min_after_dequeue是dequeue后最小的队列大小,num_threads是进行队列操作的线程数。
# initialize global & local variables init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess.run(init_op) # create a coordinate and run queue runner objects coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for batch_index in range(3): batch_images, batch_labels = sess.run([images, labels]) for i in range(10): plt.imshow(batch_images[i, ...]) plt.show() print "Current image label is: ", batch_lables[i] # close threads coord.request_stop() coord.join(threads) sess.close()
tf.decode_raw与tf.cast的区别
tf.decode_raw函数的意思是将原来编码为字符串类型的变量重新变回来,这个方法在数据集dataset中很常用,
因为制作图片源数据一般写进tfrecord里用to_bytes的形式,也就是字符串。这里将原始数据取出来 必须制定原始数据的格式,原始数据是什么格式这里解析必须是什么格式,要不然会出现形状的不对应问题!
例如原始数据是tf.float64然后to_bytes,但是用tf.decode_raw解析的时候使用了tf.float32,那么形状跟值都会跟原始数据有差别,后面传入网络的时候一定会报
tensorflow : Input to reshape is a tensor with 16384 values, but the requested shape has 49152 这种错误
tf.cast
这个函数主要用于数据类型的转变,不会改变原始数据的值还有形状的,
retyped_images = tf.cast(decoded_images, tf.float32)
labels = tf.cast(features['label'],tf.int32)
这里retyped_images原来是tf.float64形状 labels是tf.uint8。tf.cast还可以用于将numpy数组转化为tensor
tf.decode_raw()解析固定长度的数据,对于数据格式有一定的要求,应为tf.uint8
import tensorflow as tf import cv2 def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) train_filename = 'train.tfrecords'with tf.python_io.TFRecordWriter(train_filename) as tfrecord_writer: for i in range(len(images)): # read in image data by tf img_data = tf.gfile.FastGFile(images[i], 'rb').read() # image data type is string label = labels[i] # get width and height of image image_shape = cv2.imread(images[i]).shape width = image_shape[1] height = image_shape[0] # create features feature = {'train/image': _bytes_feature(img_data), 'train/label': _int64_feature(label), # label: integer from 0-N'train/height': _int64_feature(height), 'train/width': _int64_feature(width)} # create example protocol buffer example = tf.train.Example(features=tf.train.Features(feature=feature)) # serialize protocol buffer to string tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.close()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 25岁的心里话
· 按钮权限的设计及实现
2019-03-21 问题 1936: [蓝桥杯][算法提高VIP]最大乘积
2019-03-21 指针 链表
2019-03-21 蓝桥杯 第九届 日志 统计
2019-03-21 子串!=子序列
2019-03-21 poj 3061