制作 TFRecord 的代码——任意照片均可使用
代码:
# -*- coding: utf-8 -*-
# @Time    : 2018/11/23 0:09
# @Author  : MaochengHu
# @Email   : wojiaohumaocheng@gmail.com
# @File    : generate_tfrecord.py
# @Software: PyCharm
"""Build a TFRecord file of image-classification examples.

Images are expected under ``<root>/<class_name>/<image_file>``; each
``class_name`` directory is mapped to an integer label via a
``word2number_dict`` lookup table.
"""

import io
import json
import os

import tensorflow as tf
from PIL import Image


def get_annotation_dict(input_folder_path, word2number_dict):
    """Map every image path under ``input_folder_path`` to its class label.

    Args:
        input_folder_path: root directory whose immediate sub-directories
            are class names.
        word2number_dict: mapping from class-directory name to integer label.

    Returns:
        dict mapping ``<root>/<class>/<file>`` paths to integer labels.
        Entries are collected in sorted order so runs are reproducible.

    Raises:
        KeyError: if a class directory has no entry in ``word2number_dict``.
    """
    label_dict = {}
    for father_file in sorted(os.listdir(input_folder_path)):
        full_father_file = os.path.join(input_folder_path, father_file)
        # Stray files at the root (e.g. .DS_Store) are not class folders;
        # calling os.listdir on them would raise NotADirectoryError.
        if not os.path.isdir(full_father_file):
            continue
        label = word2number_dict[father_file]
        for image_name in sorted(os.listdir(full_father_file)):
            image_path = os.path.join(full_father_file, image_name)
            # Nested directories cannot be read as images; skip them.
            if os.path.isfile(image_path):
                label_dict[image_path] = label
    return label_dict


def save_json(label_dict, json_path):
    """Write ``label_dict`` to the file at ``json_path`` as JSON."""
    with open(json_path, 'w') as json_file:
        json.dump(label_dict, json_file)
    print("label json file has been generated successfully!")


def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def process_image_channels(image):
    """Return ``(image_in_RGB, converted_flag)``.

    ``converted_flag`` is True when the input was not already RGB, i.e. the
    returned image differs from the encoded source bytes.
    """
    process_flag = False
    if image.mode == 'RGBA':
        # Drop the alpha channel of 4-channel images (typically PNG).
        r, g, b, _ = image.split()
        image = Image.merge("RGB", (r, g, b))
        process_flag = True
    elif image.mode != 'RGB':
        # Grayscale, palette, CMYK, ... -> 3-channel RGB.
        image = image.convert("RGB")
        process_flag = True
    return image, process_flag


def process_image_reshape(image, resize):
    """Scale ``image`` so its shorter side is ``resize`` pixels.

    The aspect ratio is preserved. When ``resize`` is None the image is
    returned unchanged.
    """
    if resize is None:
        return image
    width, height = image.size
    # Compute both new sides from the ORIGINAL size; the previous code
    # overwrote `width` first and then reused it, leaving height unscaled.
    if width > height:
        new_width, new_height = int(width * resize / height), resize
    else:
        new_width, new_height = resize, int(height * resize / width)
    # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
    return image.resize((new_width, new_height), Image.LANCZOS)


def create_tf_example(image_path, label, resize=None):
    """Build a ``tf.train.Example`` for a single image.

    The stored ``image/encoded`` bytes are always JPEG. The image is
    re-encoded when it was not JPEG to begin with, needed a channel
    conversion, or was resized. ``image/height`` / ``image/width`` describe
    the stored bytes.

    Args:
        image_path: path readable by ``tf.io.gfile``.
        label: integer class label.
        resize: optional target length of the shorter image side.
    """
    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    image = Image.open(io.BytesIO(encoded_jpg))
    # Capture the source format before conversions, which reset `.format`.
    source_format = image.format
    image, process_flag = process_image_channels(image)
    image = process_image_reshape(image, resize)
    if process_flag or resize is not None or source_format != 'JPEG':
        bytes_io = io.BytesIO()
        image.save(bytes_io, format='JPEG')
        # Store the re-encoded bytes (previously they were computed and discarded).
        encoded_jpg = bytes_io.getvalue()
    width, height = image.size
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded': bytes_feature(encoded_jpg),
                'image/format': bytes_feature(b'jpg'),
                'image/class/label': int64_feature(label),
                'image/height': int64_feature(height),
                'image/width': int64_feature(width),
            }
        ))


def generate_tfrecord(annotation_dict, record_path, resize=None):
    """Serialize every ``(image_path, label)`` pair into ``record_path``.

    Paths that do not exist are reported and skipped.
    """
    num_tf_example = 0
    # The context manager closes the writer even if an image fails to decode.
    with tf.io.TFRecordWriter(record_path) as writer:
        for image_path, label in annotation_dict.items():
            if not tf.io.gfile.exists(image_path):
                print("{} does not exist".format(image_path))
                continue
            tf_example = create_tf_example(image_path, label, resize)
            writer.write(tf_example.SerializeToString())
            num_tf_example += 1
            if num_tf_example % 100 == 0:
                print("Create %d TF_Example" % num_tf_example)
    print("{} tf_examples has been created successfully, which are saved in {}".format(
        num_tf_example, record_path))


if __name__ == '__main__':
    word2number_dict = {
        "combinations": 0,
        "details": 1,
        "sizes": 2,
        "tags": 3,
        "models": 4,
        "tileds": 5,
        "hangs": 6,
    }
    images_dir = '../images_root'
    record_path = 'train.record'
    annotation_dict = get_annotation_dict(images_dir, word2number_dict)
    print("{} images found under {}".format(len(annotation_dict), images_dir))
    generate_tfrecord(annotation_dict, record_path)