# tfrecords转图片存储 (convert TFRecords files into stored image files)

import os
import shutil
import tensorflow as tf
import time
import sys
import cv2


# Conversion table: each entry pairs a source TFRecord file with the
# destination prefix (a directory path ending in '/') for its images.
PATH = [
    (r'data_tfrecords/integers_tfrecords/train.tfrecords',
     r'imgs_from_tfrecords/integers/train/'),
    (r'data_tfrecords/integers_tfrecords/test.tfrecords',
     r'imgs_from_tfrecords/integers/test/'),
    (r'data_tfrecords/alphabets_tfrecords/train.tfrecords',
     r'imgs_from_tfrecords/alphabets/train/'),
    (r'data_tfrecords/alphabets_tfrecords/test.tfrecords',
     r'imgs_from_tfrecords/alphabets/test/'),
    (r'data_tfrecords/Chinese_letters_tfrecords/train.tfrecords',
     r'imgs_from_tfrecords/Chinese_letters/train/'),
    (r'data_tfrecords/Chinese_letters_tfrecords/test.tfrecords',
     r'imgs_from_tfrecords/Chinese_letters/test/'),
]

# Column views of the table: TFRecord inputs and image output locations,
# in the same order as PATH.
PATH_RES = [source for source, _ in PATH]
PATH_DES = [target for _, target in PATH]


def tfrecord2jpg(path_res, path_des):
    """Extract every record of a TFRecord file and save it as a JPEG image.

    Each record is parsed with an ``image_raw`` bytes feature and a ``label``
    int64 feature. The record read at 0-based position ``idx`` is written to
    ``path_des + "_<idx>_<label>.jpg"``, and a throttled progress line is
    printed to stdout while the file is processed.

    Args:
        path_res: Path of the ``.tfrecords`` file to read.
        path_des: Prefix prepended to every output file name (``main``
            passes a directory path ending in ``/``).
    """
    print('tfrecords_files to be transformed:', path_res)
    # Keep float timestamps: truncating them to whole seconds would turn the
    # 0.1 s progress throttle below into a 1 s throttle.
    start_time = time.time()
    prev_time = start_time
    idx = 0

    # Build the input pipeline in its own graph. start_queue_runners() starts
    # every queue runner registered in the default graph, so sharing one
    # graph across calls would restart the readers of earlier files too.
    with tf.Graph().as_default():
        reader = tf.TFRecordReader()
        filename_queue = tf.train.string_input_producer(
            [path_res], num_epochs=1)

        # Read one serialized example from the TFRecord file.
        _, serialized_example = reader.read(filename_queue)
        # Describe the layout of a serialized example.
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })

        # Decode the parsed fields.
        # NOTE(review): decode_raw yields a flat 1-D uint8 vector and nothing
        # here reshapes it to (height, width[, channels]) before imwrite --
        # confirm the image size used when the records were written.
        images = tf.decode_raw(features['image_raw'], tf.uint8)
        labels = tf.cast(features['label'], tf.int64)

        print('Extracting {} has just started.'.format(path_res))
        with tf.Session() as sess:
            # num_epochs keeps its epoch counter in a local variable, which
            # has to be initialized before the queue runners start.
            sess.run(tf.local_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while not coord.should_stop():
                    try:
                        label, img = sess.run([labels, images])
                    except tf.errors.OutOfRangeError:
                        # The single epoch is exhausted: all records read.
                        print("Turn to next folder.")
                        break

                    # Build the name once so the progress line reports
                    # exactly the file that was written.
                    filename = (path_des + "_" + str(idx) + "_" +
                                str(label) + '.jpg')
                    # imwrite signals failure through its return value.
                    if not cv2.imwrite(filename, img):
                        sys.stderr.write(
                            "\nFailed to write {}\n".format(filename))
                    idx += 1

                    now = time.time()
                    if now - prev_time >= 0.1:
                        sys.stdout.write(
                            "\rGenerating the {}-th image: {}, "
                            "lasting {} seconds".format(
                                idx, filename, int(now - start_time)))
                        # Flush after writing so the line shows up at once.
                        sys.stdout.flush()
                        prev_time = now
            finally:
                # Always stop and join the reader threads, even on error.
                coord.request_stop()
                coord.join(threads)


def main():
    """Convert every configured TFRecord file into a directory of JPEGs.

    For each ``(source, destination)`` pair in ``PATH`` the destination
    directory is emptied if it already holds files, or created if it does
    not exist, and is then filled by ``tfrecord2jpg``.
    """
    for path_res, path_des in PATH:
        if os.path.isdir(path_des):
            # Start from an empty directory so images left over from an
            # earlier run are not mixed with the new ones.
            if os.listdir(path_des):
                shutil.rmtree(path_des)
                os.makedirs(path_des)
        else:
            print(path_des)
            # makedirs rather than mkdir: the destinations are nested paths
            # whose parent directories may not exist yet.
            os.makedirs(path_des)
        tfrecord2jpg(path_res, path_des)


# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
# posted @ 2017-12-01 14:53  默盒  阅读(546)  评论(0)  编辑  收藏  举报