读取 TFRecord 的代码（2）——任意照片均可使用
代码
# -*- coding: utf-8 -*-
# @Time    : 2018/12/1 11:06
# @Author  : MaochengHu
# @Email   : wojiaohumaocheng@gmail.com
# @File    : read_tfrecord.py
# @Software: PyCharm
"""Read images and labels back out of a TFRecord file with TF-Slim.

Requires TensorFlow 1.x: ``tf.app.flags``, ``tf.contrib.slim`` and
queue-runner input pipelines are not available in TensorFlow 2.
"""
import os
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('tfrecord_path', '/data1/humaoc_file/classify/data/train_tfrecord/train.record',
                    'path to tfrecord file')
flags.DEFINE_integer('resize_height', 800, 'resize height of image')
flags.DEFINE_integer('resize_width', 800, 'resize width of image')
FLAG = flags.FLAGS
slim = tf.contrib.slim


def print_data(image, resized_image, label, height, width, num_images=20):
    """Dequeue ``num_images`` examples and print their shapes and metadata.

    Args:
        image: Decoded image tensor (as returned by ``read_tfrecord``).
        resized_image: Resized image tensor (as returned by ``read_tfrecord``).
        label: Scalar label tensor.
        height: Scalar tensor holding the stored ``image/height`` feature.
        width: Scalar tensor holding the stored ``image/width`` feature.
        num_images: Number of examples to fetch and print (default 20).
    """
    with tf.Session() as sess:
        # Queue-based readers may create local variables (for example when an
        # epoch limit is set), so initialize both variable collections.
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(num_images):
                print("______________________image({})___________________".format(i))
                # A single run() call so all five values describe the same record.
                fetched = sess.run([image, resized_image, label, height, width])
                print_image, print_resized_image, print_label, print_height, print_width = fetched
                print("resized_image shape is: ", print_resized_image.shape)
                print("image shape is: ", print_image.shape)
                print("image label is: ", print_label)
                print("image height is: ", print_height)
                print("image width is: ", print_width)
        finally:
            # Always stop and join the reader threads, even if sess.run raised.
            coord.request_stop()
            coord.join(threads)


def reshape_same_size(image, output_height, output_width):
    """Resize an image to fixed sides using nearest-neighbor sampling.

    Args:
        image: A 3-D image `Tensor`.
        output_height: The height of the image after preprocessing.
        output_width: The width of the image after preprocessing.

    Returns:
        resized_image: A 3-D tensor containing the resized image.
    """
    output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
    output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)

    # resize_nearest_neighbor operates on a 4-D batch, so add a batch axis...
    batched = tf.expand_dims(image, 0)
    resized_image = tf.image.resize_nearest_neighbor(
        batched, [output_height, output_width], align_corners=False)
    # ...and remove only that axis: a bare squeeze would also drop a size-1
    # height, width or channel dimension.
    return tf.squeeze(resized_image, axis=[0])


def read_tfrecord(tfrecord_path, num_samples=14635, num_classes=7, resize_height=800, resize_width=800):
    """Build a TF-Slim queue-based pipeline that decodes examples from a TFRecord file.

    Each record is expected to carry the features listed in
    ``keys_to_features`` below; missing scalar features fall back to the
    listed defaults.

    Args:
        tfrecord_path: Path (or pattern) of the TFRecord file(s) to read.
        num_samples: Number of examples recorded in the dataset metadata.
        num_classes: Number of label classes recorded in the dataset metadata.
        resize_height: Output height of ``resized_image``.
        resize_width: Output width of ``resized_image``.

    Returns:
        A tuple ``(resized_image, label, image, height, width)`` of tensors.
        ``image`` is the decoded 3-channel image, and ``resized_image`` is it
        bilinearly resized to ``(resize_height, resize_width)`` (float32,
        because that is what ``resize_bilinear`` returns). ``height`` and
        ``width`` are the values stored in the record, not measured from the
        decoded image. Queue runners must be started before evaluating these.
    """
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature([], default_value='', dtype=tf.string, ),
        'image/format': tf.FixedLenFeature([], default_value='jpeg', dtype=tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=0),
        'image/height': tf.FixedLenFeature([], tf.int64, default_value=0),
        'image/width': tf.FixedLenFeature([], tf.int64, default_value=0)
    }

    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
        'height': slim.tfexample_decoder.Tensor('image/height', shape=[]),
        'width': slim.tfexample_decoder.Tensor('image/width', shape=[])
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    # Descriptive metadata only; the label range follows num_classes.
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer between 0 and {}.'.format(num_classes - 1)}

    dataset = slim.dataset.Dataset(
        data_sources=tfrecord_path,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
    )
    provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset,
                                                              num_readers=3,
                                                              shuffle=True,
                                                              common_queue_capacity=256,
                                                              common_queue_min=128,
                                                              seed=None)
    image, label, height, width = provider.get(['image', 'label', 'height', 'width'])
    # resize_bilinear expects a 4-D batch: add a batch axis, then squeeze only
    # that axis back out so size-1 spatial/channel dimensions survive.
    resized_image = tf.squeeze(
        tf.image.resize_bilinear(tf.expand_dims(image, 0), size=[resize_height, resize_width]),
        axis=[0])
    return resized_image, label, image, height, width


if __name__ == '__main__':
    # NOTE(review): the FLAG values defined above are not consulted here; the
    # record path and resize dimensions are hard-coded.
    resized_image, label, image, height, width = read_tfrecord(tfrecord_path='train.record',
                                                               resize_height=800,
                                                               resize_width=800)
    print_data(image, resized_image, label, height, width)

    init_g = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_g)
        sess.run(init_l)
        # Track the reader threads so they can be shut down before the
        # session closes.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            print("SDDFA")
            # Fetch image and label in one run() so they come from the same
            # example; separate eval() calls would dequeue two different records.
            trX, trY = sess.run([image, label])
            print("AA")
            print(trX.shape)
        finally:
            coord.request_stop()
            coord.join(threads)