制作 TFRecord 的代码——任意照片均可使用

代码:

  1 # -*- coding: utf-8 -*-
  2 # @Time    : 2018/11/23 0:09
  3 # @Author  : MaochengHu
  4 # @Email   : wojiaohumaocheng@gmail.com
  5 # @File    : generate_tfrecord.py
  6 # @Software: PyCharm
  7 
  8 import os
  9 import tensorflow as tf
 10 import io
 11 from PIL import Image
 12 import json
 13 def get_annotation_dict(input_folder_path, word2number_dict):
 14     label_dict = {}
 15     father_file_list = os.listdir(input_folder_path)
 16     for father_file in father_file_list:
 17         full_father_file = os.path.join(input_folder_path, father_file)
 18         son_file_list = os.listdir(full_father_file)
 19         for image_name in son_file_list:
 20             label_dict[os.path.join(full_father_file, image_name)] = word2number_dict[father_file]
 21     return label_dict
 22 
 23 
 24 def save_json(label_dict, json_path):
 25     with open(json_path, 'w') as json_path:
 26         json.dump(label_dict, json_path)
 27     print("label json file has been generated successfully!")
 28 
 29 
 30 
 31 def int64_feature(value):
 32     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 33 
 34 
 35 def bytes_feature(value):
 36     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 37 
 38 
 39 def process_image_channels(image):
 40     process_flag = False
 41     # process the 4 channels .png
 42     if image.mode == 'RGBA':
 43         r, g, b, a = image.split()
 44         image = Image.merge("RGB", (r,g,b))
 45         process_flag = True
 46     # process the channel image
 47     elif image.mode != 'RGB':
 48         image = image.convert("RGB")
 49         process_flag = True
 50     return image, process_flag
 51 
 52 
 53 def process_image_reshape(image, resize):
 54     width, height = image.size
 55     if resize is not None:
 56         if width > height:
 57              width = int(width * resize / height)
 58              height = resize
 59         else:
 60             width = resize
 61             height = int(height * resize / width)
 62         image = image.resize((width, height), Image.ANTIALIAS)
 63     return image
 64 
 65 
 66 def create_tf_example(image_path, label, resize=None):
 67     with tf.gfile.GFile(image_path, 'rb') as fid:
 68         encode_jpg = fid.read()
 69     encode_jpg_io = io.BytesIO(encode_jpg)
 70     image = Image.open(encode_jpg_io)
 71     # process png pic with four channels
 72     image, process_flag = process_image_channels(image)
 73     # reshape image
 74     image = process_image_reshape(image, resize)
 75     if process_flag == True or resize is not None:
 76         bytes_io = io.BytesIO()
 77         image.save(bytes_io, format='JPEG')
 78         encoded_jpg = bytes_io.getvalue()
 79     width, height = image.size
 80     tf_example = tf.train.Example(
 81         features=tf.train.Features(
 82             feature={
 83                 'image/encoded': bytes_feature(encode_jpg),
 84                 'image/format': bytes_feature(b'jpg'),
 85                 'image/class/label': int64_feature(label),
 86                 'image/height': int64_feature(height),
 87                 'image/width': int64_feature(width)
 88             }
 89         ))
 90     return tf_example
 91 
 92 
 93 def generate_tfrecord(annotation_dict, record_path, resize=None):
 94     num_tf_example = 0
 95     writer = tf.io.TFRecordWriter(record_path)
 96     for image_path, label in annotation_dict.items():
 97         if not tf.gfile.GFile(image_path):
 98             print("{} does not exist".format(image_path))
 99         tf_example = create_tf_example(image_path, label, resize)
100         writer.write(tf_example.SerializeToString())
101         num_tf_example += 1
102         if num_tf_example % 100 == 0:
103             print("Create %d TF_Example" % num_tf_example)
104     writer.close()
105     print("{} tf_examples has been created successfully, which are saved in {}".format(num_tf_example, record_path))
106 
107 
108 
109 
if __name__ == '__main__':
    import argparse

    # Class-folder name -> integer label. Folder names under images_dir must
    # match these keys exactly (get_annotation_dict looks them up by name).
    word2number_dict = {
        "combinations": 0,
        "details": 1,
        "sizes": 2,
        "tags": 3,
        "models": 4,
        "tileds": 5,
        "hangs": 6
    }
    parser = argparse.ArgumentParser(
        description="Generate a TFRecord file from class-named image folders.")
    parser.add_argument('--images_dir', default='../images_root',
                        help="Root folder containing one sub-folder per class.")
    parser.add_argument('--record_path', default='train.record',
                        help="Output TFRecord file.")
    parser.add_argument('--resize', type=int, default=None,
                        help="Optional length of the shorter image side; "
                             "images keep their size when omitted.")
    parser.add_argument('--label_json', default=None,
                        help="Optional path to also save the path->label mapping as JSON.")
    args = parser.parse_args()

    annotation_dict = get_annotation_dict(args.images_dir, word2number_dict)
    print("Found {} images under {}".format(len(annotation_dict), args.images_dir))
    if args.label_json:
        save_json(annotation_dict, args.label_json)
    generate_tfrecord(annotation_dict, args.record_path, args.resize)

 

posted @ 2020-03-07 22:21  博二爷  阅读(243)  评论(0)  编辑  收藏  举报