Building your own dataset with tfrecord

Notes:

1. The input image format

    When calling io.imread(), set the as_grey parameter according to your input images. After converting the array to a string (image.tostring()), print the length of image_raw: different image formats are read into arrays with different dtypes, and each dtype takes a different number of bytes per element, so the stored length differs.

    The grayscale JPEGs I read in came back with a 64-bit dtype, so image_raw was 8 times the number of pixels in the image. RGB images, by contrast, come back uniformly as uint8. Once the dtype is known, the later decode step (decode_raw) must be given that same type as was used for storage.

    From the length of image_raw and the original image size you can work out which dtype was used; the common ones are uint8, int32, int64, and float64 (a quick check is sketched right after these notes).

2. Converting to tfrecords takes quite a while, so be patient.
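
A minimal sketch of the check described in note 1 (the file name 'sample.jpg' is a placeholder): read one image, convert it to bytes, and divide the byte count by the pixel count; the result is the number of bytes per element, which tells you which dtype to pass to decode_raw later.

import skimage.io as io

image = io.imread('sample.jpg', as_grey=True)   # as_grey=True yields a float64 array
image_raw = image.tostring()
pixels = image.shape[0] * image.shape[1]
print(image.dtype)                   # e.g. float64
print(len(image_raw) // pixels)      # bytes per element: 8 for float64, 1 for uint8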

import os
import tensorflow as tf
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import cv2
def get_data(file_path):
    data = []
    label = []
    # each sub-directory under file_path is treated as one class;
    # its index is used as the label for every image it contains
    for class_index, dirs in enumerate(os.listdir(file_path)):
        temp_path = os.path.join(file_path, dirs)
        for filename in os.listdir(temp_path):
            data.append(os.path.join(temp_path, filename))
        num_img = len(os.listdir(temp_path))
        label = np.append(label, num_img * [class_index])
    # shuffle image paths and labels together
    temp = np.array([data, label])
    temp = temp.transpose()
    np.random.shuffle(temp)
    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(float(i)) for i in label_list]
    return image_list, label_list
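# Hedged usage sketch (not in the original post): get_data assumes file_path
# contains one sub-directory per class, e.g. E:\syn_data\class_0\*.jpg and
# E:\syn_data\class_1\*.jpg, and returns shuffled paths with integer labels.
# image_list, label_list = get_data(r'E:\syn_data')
# print(len(image_list), image_list[:2], label_list[:2])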
# helpers that wrap raw values into tf.train.Feature protos
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def convert_tfrecord(images,labels,save_filename):
    writer = tf.python_io.TFRecordWriter(save_filename)
    print("Transform start....")
    num_examples= len(labels)
    if len(images) != num_examples:
        raise ValueError('Images size %d does not match label size %d.' % (len(images), num_examples))
    for index in np.arange(0,num_examples):
        try:
            image = io.imread(images[index], as_grey=False)  # set as_grey according to note 1
            #image = tf.image.decode_jpeg(images[index])
            #print(image.shape)
            image_raw = image.tostring()
            #print(len(image_raw))
            example = tf.train.Example(features = tf.train.Features(feature={
                'label' :_int64_feature(int(labels[index])),
                'image_raw':_bytes_feature(image_raw)
            }))
            writer.write(example.SerializeToString())
        except IOError as e:
            print('Could not read:',images[index])
            print('error: %s, skipping it!\n' % e)
    writer.close()
    print("success!")

def read_and_decode(tfrecords_file,batch_size):
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    _,serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([],tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)
        }
    )
    #print(features['image_raw'])
    capacity = 1000+3*batch_size
    image = tf.decode_raw(features['image_raw'], tf.uint8)  # dtype must match how the image was stored (note 1)
    label = tf.cast(features['label'],tf.int32)
    #image = tf.image.resize_images(image,[300, 200, 1])
    image = tf.reshape(image, [200, 300, 3])  # must match the original image's height, width and channels
    image_batch,label_batch = tf.train.batch([image,label],
                                             batch_size=batch_size,
                                             capacity=capacity)
    image_batch = tf.image.resize_image_with_crop_or_pad(image_batch,100,100)
    image_batch = tf.cast(image_batch, tf.float32) * (1. / 255)
    return image_batch,label_batch
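
# Hedged variant for note 1 (an assumption, not part of the original pipeline):
# if the images were written after io.imread(..., as_grey=True), the stored bytes
# are float64 with a single channel, so decode_raw and the reshape must change.
def read_and_decode_grey(tfrecords_file, batch_size):
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)
        })
    image = tf.decode_raw(features['image_raw'], tf.float64)  # must match the stored dtype
    label = tf.cast(features['label'], tf.int32)
    image = tf.reshape(image, [200, 300, 1])                  # one channel for grayscale
    image = tf.cast(image, tf.float32)                        # as_grey values are already in [0, 1]
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              capacity=1000 + 3 * batch_size)
    return image_batch, label_batch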
def plot_images(images, labels):
    '''plot a few images from one batch'''
    for i in np.arange(0, 2):
        plt.subplot(3, 3, i + 1)
        plt.axis('off')
        # plt.title((labels[i] - 1), fontsize = 14)
        plt.subplots_adjust(top=1)
        print(labels[i])
        print(images.shape)
        # print(images[i].shape)
        plt.imshow(images[i][:,:,:])
    plt.show()
def train():
    image, label = get_data(r'E:\syn_data')
    convert_tfrecord(image,label,'1.tfrecords')
    x_batch, y_batch = read_and_decode('1.tfrecords', batch_size=2)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            i=0
            while not coord.should_stop() and i<3:
                # just plot one batch
                image, label = sess.run([x_batch, y_batch])
                plot_images(image, label)
                i+=1
        except tf.errors.OutOfRangeError:
            print('done!')
        finally:
            coord.request_stop()
        coord.join(threads)

#train()

 
