Tensorflow学习记录 --TensorFlow高效读取数据tfrecord

Tensorflow学习过程中tfrecord的简单理解

1 TFRecord的介绍:

一般使用直接将数据加载到内存的方式来存储数据量较小的数据,然后再分batch输入网络进行训练。如果数据量太大,这种方法是十分消耗内存的,这时可以使用tensorflow提供的队列queue从文件中提取数据(比如csv文件等)。还有一种较为常用的,高效的读取方法,即使用tensorflow内定标准格式——TFRecords.作者也是刚接触tensorflow,对日常学习遇到的问题做简单记录,有不对地方需要指正。

1.1 什么是TFRecord?

TFRecord是谷歌推荐的一种常用的存储二进制序列数据的文件格式,理论上它可以保存任何格式的信息。下面是Tensorflow的官网给出的文档结构,整个文件由文件长度信息,长度校验码,数据,数据校验码组成。

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

2 代码及相关简介

2.1 构建写入数据的writer

import numpy as np 
import tensorflow as tf 
writer = tf.python_io.TFRecordWriter('test.tfrecord')

2.2 TFRecord

TensorFlow经常使用 tf.Example 来写入,读取TFRecord数据。

通常tf.example有下面几种数据结构:

  • tf.train.FloatList: 可以使用的类型包括 float和double
  • tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
  • f.train.BytesList: 可以使用的类型包括 string和byte

TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature,如下所示:

#feature一般是多维数组,要先转为list
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
 
#tostring函数后feature的形状信息会丢失,把shape也写入
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape))) 
 
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

下面以一个具体的简单例子来介绍tf.example

for k in range(0, 3):
    x = 0.1712 + k
    y = [1+k, 2+k]
    z = np.array([[1,2,3],[4,5,6]]) + k
    z = z.astype(np.uint8)
    z_raw = z.tostring()
    example = tf.train.Example(
        features = tf.train.Features(
            feature = {'x':tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
                       'y':tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
                       'z':tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}))
    serialized = example.SerializeToString()
    writer.write(serialized)
writer.close()

x,y,z分别是以float,int64和string的形式存储的,注意观察下面语句:

feature = {'x':tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
           'y':tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
           'z':tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}

value的值是一个list形式,x定义的为一个数,value的值应为[x],同样y定义的格式就是一个list所以value的值直接为y即可,z_raw是由z转换过来的string形式,对应的value值与x的形式应该是一样的。

2.3 创建文件读取队列并读取其中内容(字典格式)

#output file name string to a queue
filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs = None)
#Create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
#Get feature from serialized example
features = tf.parse_single_example(serialized_example,
                features = {'x': tf.FixedLenFeature([],tf.float32),
                            'y': tf.FixedLenFeature([2],tf.int64),
                            'z': tf.FixedLenFeature([],tf.string)})

2.4 读取数据

x_out  = features['x']
y_out  = features['y']
z_raw_out = features['z']
z_out = tf.decode_raw(z_raw_out,tf.uint8)
z_out = tf.reshape(z_out, [2,3])
print(x_out)
print(y_out)
print(z_out)

显示结果为:

Tensor("ParseSingleExample_2/ParseSingleExample:0", shape=(), dtype=float32)
Tensor("ParseSingleExample_2/ParseSingleExample:1", shape=(2,), dtype=int64)
Tensor("Reshape_1:0", shape=(2, 3), dtype=uint8)

3 以存储图片为例理解TFRecord的应用

使用Tensorflow训练网络时,为提高数据的读取效率,一般都采用TFRecords格式。初学CNN我们使用了手写数字数据集学习,这些都是做好的数据集,我们可以直接使用,比如MNIST,CIFAR_10等。现在我们还不是很清楚怎样输入训练的图片,此时就要用到TFRecord来制作自己的数据集。

3.1 将图片转换成tfrecords格式

假设我们的输入的图片需要三种信息,图片的名字,图片维度以及图片的内容:name shape content
输入图片以及输出tfrecord文件:

input_photo = r'D:\Furh\jupyter code\Tensorflow Tips\data\dog.jpg'
output_file = r'D:\Furh\jupyter code\Tensorflow Tips\dog.tfrecord'
# 使用 TFRecordWriter 将信息写入到 TFRecord 文件
writer = tf.python_io.TFRecordWriter(output_file)
#读取图片进行解码
image = tf.read_file(input_photo)
image = tf.image.decode_jpeg(image)
with tf.Session() as sess:
    image_new = sess.run(image)
    shape = image_new.shape
    #将图片转换成string 
    image_data = image_new.tostring()
    print(type(image_new))
    print(len(image_data))
    name = bytes('dog',encoding = 'utf-8')
    print(type(name))
    # 创建Example对象,将所有的Features填充进去
    example = tf.train.Example(
                    features = tf.train.Features(
                        feature = {
                            'name': tf.train.Feature(bytes_list = tf.train.BytesList(value = [name])),
                            'shape': tf.train.Feature(int64_list = tf.train.Int64List(value = [shape[0],shape[1],shape[2]])),
                            'data': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data]))
                        }
                    ))
    # 将example序列化成string类型写入
    writer.write(example.SerializeToString())
writer.close()

Note:

  • Feature 中value应该是列表形式,当数据不是列表时,加上[]
  • 解码后的图片要转化成string数据,再填充
  • example需要使用SerializeToString()进行序列化

3.2 TFRecord 文件读取成图片

#解析数据 
def parse_record(example_photo):
    features = {
        'name': tf.FixedLenFeature((),tf.string),
        'shape': tf.FixedLenFeature([3],tf.int64), #这里制定维度3
        'data' : tf.FixedLenFeature((),tf.string)
    }
    #在解析example时,用现成的API: tf.parse_single_example
    parsed_features = tf.parse_single_example(example_photo,features = features)
    return parsed_features

def read_test(input_file):
    #使用dataset读取TFRecord文件
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(parse_record)
    iterator = dataset.make_one_shot_iterator()
    
    with tf.Session() as sess:
        features = sess.run(iterator.get_next())
        name = features['name']
        name = name.decode
        img_data = features['data']
        shape = features['shape']
        
        #从bytes数组中加载图片原始数据,并重新reshape,结果是ndarray数组
        img_data = np.fromstring(img_data, dtype=np.uint8) #获取解析后的string数据,并把数据还原成unit8
        image_data = np.reshape(img_data,shape)
        
        plt.figure()
        plt.imshow(image_data)
        plt.show()
        
        #将数据重新编码成jpg图片保存
        img = tf.image.encode_jpeg(image_data)
        #把图片保存到本地    
        tf.gfile.GFile('dog_encode,jpg', 'wb').write(img.eval())
read_test('dog.tfrecord')

Note:
在使用dataset进行样本解析之前,我们需要按照先定义一个解析字典,告诉dataset如何去解析每个样本,这个字典就是用来指定对于每条输入样本的每一列应该用什么的feature去解析,dataset默认提供了FixedLenFeature,VarLenFeature,FixedLenSequenceFeature等。

FixedLenFeature() 函数有三个参数:

  • shape:输入数据的shape。
  • dtype:输入的数据类型。
  • default_value:如果示例缺少此功能,则使用该值。它必须与dtype和指定shape兼容。

代码注释:

主要参考:
TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读
tensorflow学习笔记——高效读取数据的方法(TFRecord

posted @ 2020-12-13 10:10  开普勒醒醒吧  阅读(216)  评论(2编辑  收藏  举报