4.9 TF读入TFRecord

import tensorflow as tf

filelist = ['data/train.tfrecord']
file_queue = tf.train.string_input_producer(filelist,  # 定义文件队列
                                            num_epochs=None,
                                            shuffle=True)
reader = tf.TFRecordReader()  # tensoeflow文件读取器从文件队列读取
_, ex = reader.read(file_queue)  # 原图-编码-序列化-打包,现在是反
# 向解析,ex是序列化之后的数据,所以还需要解码

feature = {  # 定义序列化格式
    'image': tf.FixedLenFeature([], tf.string),  # image是byte储存的,解码则直接解析为string型
    'label': tf.FixedLenFeature([], tf.int64)  # label本身就是int型
}
# 将队列中数据打乱后再读取出来
# batch_size:从队列中提取新的批量大小.
# capacity:队列容量.
# min_after_dequeue:最小队列容量.
batchsize = 2
batch = tf.train.shuffle_batch([ex], batchsize, capacity=batchsize * 10,
                               min_after_dequeue=batchsize * 5)

# 解码方法,features(字典型)有点像解析格式,返回的是字典型
example = tf.parse_example(batch, features=feature)
image = example['image']
label = example['label']

#image是string型,需要转换为uint8
image=tf.decode_raw(image, tf.uint8)

#这里的image其实是一串数字,按我们32*32*3的数据规
#模来重排序,可以制定这样的矩阵的大小,-1表示程序自动计算矩阵个数
#输出image:Tensor("DecodeRaw:0", shape=(2, ?), dtype=uint8)
image = tf.reshape(image, [-1,32, 32, 3])
#输出image:Tensor("Reshape:0", shape=(?, 32, 32, 3), dtype=uint8)


with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    tf.train.start_queue_runners(sess=sess)
    for i in range(1):
        image_bth,label=sess.run([image,label])
        import cv2
        cv2.imshow(str(label[0,...]),image_bth[0,...])
        cv2.waitKey(0)
posted @ 2020-05-12 15:38  盐亭的森林  阅读(140)  评论(0编辑  收藏  举报