TensorFlow 2.0 —— 读取二进制文件和 TFRecords

#-*- coding: utf-8 -*-
# coding:unicode_escape
#@Time : 2021/3/3 15:10
#@Author : 杨晓
#@File : binary_read.py
#@Software: PyCharm
import tensorflow as tf
import os
tf.compat.v1.disable_eager_execution()

class Cifar(object):
    """Read CIFAR-10 binary batches and round-trip them through a TFRecord file.

    All readers use the TF1 queue-based input APIs from ``tf.compat.v1``,
    which only work in graph mode (this module disables eager execution at
    import time).
    """

    def __init__(self):
        # Image geometry. In the binary files the image bytes are laid out as
        # channels x height x width (see the reshape in read_binary).
        self.height = 32
        self.width = 32
        self.channels = 3
        self.image_bytes = self.height * self.width * self.channels
        # Each binary record starts with this many label bytes.
        self.label_bytes = 1
        # Size of one fixed-length record: label bytes followed by image bytes.
        self.all_bytes = self.image_bytes + self.label_bytes

    @staticmethod
    def _run_queue_pipeline(fetches):
        """Evaluate ``fetches`` once in a fresh session with queue runners started.

        The queue-runner threads are always asked to stop and are joined,
        even when ``sess.run`` raises, so no threads are left running.

        :param fetches: a tensor or list of tensors to evaluate.
        :return: the result of ``sess.run(fetches)``.
        """
        with tf.compat.v1.Session() as sess:
            coord = tf.compat.v1.train.Coordinator()
            threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
            try:
                return sess.run(fetches)
            finally:
                coord.request_stop()
                coord.join(threads)

    def read_binary(self, data_dir="../tmp/data/cifar-10-batches-bin", batch_size=100):
        """Read one batch of (image, label) pairs from the CIFAR-10 ``.bin`` files.

        Each fixed-length record is ``label_bytes`` label bytes followed by
        ``image_bytes`` image bytes in (channels, height, width) order; the
        image is transposed to (height, width, channels).

        :param data_dir: directory containing the ``*.bin`` batch files.
        :param batch_size: number of records to read (also the queue capacity).
        :return: ``(image_value, label_value)`` uint8 arrays of shape
            ``(batch_size, height, width, channels)`` and
            ``(batch_size, label_bytes)``.
        :raises ValueError: if ``data_dir`` contains no ``.bin`` files.
        """
        # Build the input file list; sorted so the list itself is deterministic.
        file_list = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith(".bin")
        )
        if not file_list:
            # Fail fast here rather than inside the queue-runner threads.
            raise ValueError("no .bin files found in %r" % data_dir)

        # File-name queue -> fixed-length record reader -> raw uint8 bytes.
        file_queue = tf.compat.v1.train.string_input_producer(file_list)
        reader = tf.compat.v1.FixedLengthRecordReader(self.all_bytes)
        _, value = reader.read(file_queue)
        decoded = tf.compat.v1.decode_raw(value, tf.uint8)

        # Split the record into its label prefix and image payload.
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
        # Stored as (C, H, W); convert to the usual (H, W, C) layout.
        image_reshape = tf.reshape(image, shape=[self.channels, self.height, self.width])
        image_transpose = tf.transpose(image_reshape, [1, 2, 0])

        label_batch, image_batch = tf.compat.v1.train.batch(
            [label, image_transpose],
            batch_size=batch_size,
            num_threads=1,
            capacity=batch_size,
        )
        print("image_bath:\n", image_batch)

        label_value, image_value = self._run_queue_pipeline([label_batch, image_batch])
        print("label_new:\n", label_value)
        print("image_new:\n", image_value)
        return image_value, label_value

    def write_to_tfrecords(self, image_batch, label_batch, file_name="cifar10.tfrecords"):
        """Write image/label pairs to a TFRecord file as ``tf.train.Example`` records.

        Each example has an ``"image"`` feature (the raw bytes of one image
        array) and a ``"label"`` feature (int64 taken from element ``[0]`` of
        the matching label row).

        :param image_batch: sequence of numpy image arrays, e.g. the images
            returned by ``read_binary``.
        :param label_batch: sequence of label rows, e.g. the labels returned
            by ``read_binary``; element ``[0]`` of each row is written.
        :param file_name: output TFRecord path.
        :return: None
        :raises ValueError: if the two batches have different lengths.
        """
        if len(image_batch) != len(label_batch):
            raise ValueError(
                "image_batch and label_batch must have the same length "
                "(got %d and %d)" % (len(image_batch), len(label_batch))
            )
        with tf.compat.v1.python_io.TFRecordWriter(file_name) as writer:
            # Build one Example per (image, label) pair and append it to the file.
            for image_array, label_row in zip(image_batch, label_batch):
                # tobytes() replaces ndarray.tostring(), removed in NumPy 2.0.
                image = image_array.tobytes()
                # Int64List expects a Python int, not a numpy scalar.
                label = int(label_row[0])
                print("records_image:\n", image)
                print("records_label:\n", label)
                example = tf.compat.v1.train.Example(features=tf.compat.v1.train.Features(feature={
                    "image": tf.compat.v1.train.Feature(bytes_list=tf.compat.v1.train.BytesList(value=[image])),
                    "label": tf.compat.v1.train.Feature(int64_list=tf.compat.v1.train.Int64List(value=[label])),
                }))
                writer.write(example.SerializeToString())
        return None

    def read_tfrecords(self, file_name="cifar10.tfrecords", batch_size=100):
        """Read one batch of (image, label) pairs back from a TFRecord file.

        Parses the ``"image"`` (raw bytes) / ``"label"`` (int64) features that
        ``write_to_tfrecords`` writes, and prints the intermediate tensors and
        the resulting batch.

        :param file_name: path of the TFRecord file to read.
        :param batch_size: number of examples to read (also the queue capacity).
        :return: None
        """
        file_queue = tf.compat.v1.train.string_input_producer([file_name])

        # Read one serialized Example at a time.
        reader = tf.compat.v1.TFRecordReader()
        key, value = reader.read(file_queue)
        print("key:\n", key)
        print("value:\n", value)

        # Parse the Example into its image bytes and integer label.
        feature = tf.compat.v1.parse_single_example(value, features={
            "image": tf.compat.v1.FixedLenFeature([], tf.string),
            "label": tf.compat.v1.FixedLenFeature([], tf.int64),
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)

        image_decode = tf.compat.v1.decode_raw(image, tf.uint8)
        print("image_decode:\n", image_decode)
        # Assumes the stored bytes are (H, W, C), as produced by write_to_tfrecords
        # when fed the transposed images from read_binary.
        image_reshape = tf.reshape(image_decode, [self.height, self.width, self.channels])
        print("image_reshape:\n", image_reshape)

        image_batch, label_batch = tf.compat.v1.train.batch(
            [image_reshape, label],
            batch_size=batch_size,
            num_threads=2,
            capacity=batch_size,
        )
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)

        image_value, label_value = self._run_queue_pipeline([image_batch, label_batch])
        print("image_value:\n", image_value)
        print("label_value:\n", label_value)
        return None

if __name__ == '__main__':
    cifar = Cifar()
    # read_tfrecords needs cifar10.tfrecords; build it from the binary
    # CIFAR-10 batches first if an earlier run has not already created it.
    if not os.path.exists("cifar10.tfrecords"):
        image_value, label_value = cifar.read_binary()
        cifar.write_to_tfrecords(image_value, label_value)
    cifar.read_tfrecords()
posted @ 2021-03-03 19:51  littlemelon  阅读(342)  评论(0)  编辑  收藏  举报