# TensorFlow TFRecords: read and write
# Write captcha images and their labels into a TFRecords file.
import os

import tensorflow as tf

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("tfrecords_dir", "./tfrecords/captcha.tfrecords", "验证码tfrecords文件")
tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "验证码图片路径")
tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "验证码字符的种类")

# Default number of captcha images (0.jpg, 1.jpg, ...) and label rows that
# are read in one batch and written to the TFRecords file.
SAMPLE_COUNT = 6000


def dealwithlabel(label_str):
    """
    Convert a batch of captcha label strings into a tensor of letter indices.

    :param label_str: iterable of UTF-8 encoded byte strings, e.g. [b"NZPP", ...]
    :return: constant tensor with one row per label; each entry is the
             position of the corresponding character inside FLAGS.letter
    :raises KeyError: if a label contains a character missing from FLAGS.letter
    """
    # Character -> index lookup, e.g. {'A': 0, 'B': 1, ...}
    letter_num = {letter: index for index, letter in enumerate(FLAGS.letter)}

    print(letter_num)

    # Labels arrive as bytes, so decode them before looking up each character.
    # Result looks like [[13, 25, 15, 15], [22, 10, 7, 10], ...]
    array = [
        [letter_num[letter] for letter in string.decode('utf-8')]
        for string in label_str
    ]

    print(array)

    # Turn the nested list into a tensor.
    label = tf.constant(array)

    return label


def get_captcha_image(sample_count=SAMPLE_COUNT):
    """
    Build the graph ops that read and decode the captcha images.

    The images are expected at FLAGS.captcha_dir as 0.jpg, 1.jpg, ... and are
    read in that order (no shuffling, single reader thread).

    :param sample_count: number of image files to read into the batch
    :return: image batch tensor of shape [sample_count, 20, 80, 3]
    """
    # Build "<captcha_dir>/<i>.jpg" for every sample.
    file_list = [
        os.path.join(FLAGS.captcha_dir, str(i) + ".jpg")
        for i in range(sample_count)
    ]

    # File-name queue; order is kept so images line up with the label rows.
    file_queue = tf.train.string_input_producer(file_list, shuffle=False)

    # Read each file as a whole.
    reader = tf.WholeFileReader()
    key, value = reader.read(file_queue)

    # Decode the JPEG bytes and pin the static shape required by batching.
    image = tf.image.decode_jpeg(value)
    image.set_shape([20, 80, 3])

    # Collect all images into one batch.
    image_batch = tf.train.batch(
        [image], batch_size=sample_count, num_threads=1, capacity=sample_count)

    return image_batch


def get_captcha_label(sample_count=SAMPLE_COUNT):
    """
    Build the graph ops that read the captcha labels from labels.csv.

    The CSV file lives next to the images (FLAGS.captcha_dir) and each row is
    parsed as "<number>,<label string>".

    :param sample_count: number of label rows to read into the batch
    :return: string tensor batch of labels, e.g. [b"NZPP", b"WKHK", ...]
    """
    # Derive the path from the same flag as the images so both stay in sync;
    # with the default flag value this is "../data/Genpics/labels.csv".
    label_path = os.path.join(FLAGS.captcha_dir, "labels.csv")
    file_queue = tf.train.string_input_producer([label_path], shuffle=False)

    reader = tf.TextLineReader()
    key, value = reader.read(file_queue)

    # Column defaults also fix the column types: int index, string label.
    records = [[1], ["None"]]
    number, label = tf.decode_csv(value, record_defaults=records)

    label_batch = tf.train.batch(
        [label], batch_size=sample_count, num_threads=1, capacity=sample_count)

    return label_batch


def write_to_tfrecords(image_batch, label_batch):
    """
    Serialize every image/label pair into the TFRecords file FLAGS.tfrecords_dir.

    Must be called with a default session active and the queue runners
    started, because the tensors are evaluated with ``.eval()``.

    :param image_batch: image tensor batch (feature values)
    :param label_batch: integer label tensor batch (target values)
    :return: None
    :raises ValueError: if the two batches contain a different number of samples
    """
    # Labels are stored as raw uint8 bytes.
    label_batch = tf.cast(label_batch, tf.uint8)

    print(label_batch)

    # Evaluate each batch exactly once.  Calling ``image_batch[i].eval()``
    # inside the loop would add a new slice op per iteration and re-run the
    # whole batching op (dequeuing a full batch of images) for every sample.
    images = image_batch.eval()
    labels = label_batch.eval()

    if len(images) != len(labels):
        raise ValueError(
            "image/label count mismatch: %d images, %d labels"
            % (len(images), len(labels)))

    # The context manager closes the file even if serialization fails.
    with tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir) as writer:
        for image, label in zip(images, labels):
            # Both image and label are stored as raw byte strings.
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.tobytes()])),
            }))

            writer.write(example.SerializeToString())

    return None


def main():
    """Read all captcha images and labels and write them to the TFRecords file."""
    # Graph ops producing the image batch and the raw string labels.
    image_batch = get_captcha_image()
    label = get_captcha_label()

    print(image_batch, label)

    with tf.Session() as sess:

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # Fetch the label strings so they can be mapped to indices in Python.
            label_str = sess.run(label)

            print(label_str)

            # String labels -> integer tensor.
            label_batch = dealwithlabel(label_str)

            print(label_batch)

            write_to_tfrecords(image_batch, label_batch)
        finally:
            # Always shut the queue threads down, even when writing fails.
            coord.request_stop()
            coord.join(threads)


if __name__ == "__main__":
    main()
# read tfrecords
def read_and_decode():
    """
    Build the input pipeline that reads captcha samples back from TFRecords.

    The file path is taken from FLAGS.captcha_dir and the batch size from
    FLAGS.batch_size.

    :return: image_batch, label_batch
    """
    # Every record stores the image and the label as one raw byte string each.
    feature_spec = {
        "image": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.string),
    }

    # 1. Queue holding the TFRecords file name.
    filename_queue = tf.train.string_input_producer([FLAGS.captcha_dir])

    # 2. Reader that yields one serialized example per read.
    record_reader = tf.TFRecordReader()
    _, serialized = record_reader.read(filename_queue)

    # 3. Parse the Example protocol buffer.
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # 4. Decode the raw bytes: image first, then label.
    raw_image = tf.decode_raw(parsed["image"], tf.uint8)
    raw_label = tf.decode_raw(parsed["label"], tf.uint8)

    # 5. Restore the shapes.
    image_reshape = tf.reshape(raw_image, [20, 80, 3])
    label_reshape = tf.reshape(raw_label, [4])

    print(image_reshape, label_reshape)

    # 6. Batch the samples; one batch is what a training step consumes.
    per_batch = FLAGS.batch_size
    image_batch, label_batch = tf.train.batch(
        [image_reshape, label_reshape],
        batch_size=per_batch,
        num_threads=1,
        capacity=per_batch,
    )

    print(image_batch, label_batch)
    return image_batch, label_batch
# Flag definitions, listed for reference only (this listing is a comment and
# defines nothing at runtime).
# NOTE(review): `captcha_dir` appears in both groups with different defaults;
# presumably each group belongs to its own script (writer vs. reader) --
# confirm before enabling both in one module, since the flags library
# rejects a second definition of the same flag name.
#
# write flags
#   FLAGS = tf.app.flags.FLAGS
#   tf.app.flags.DEFINE_string("tfrecords_dir", "./tfrecords/captcha.tfrecords", "验证码tfrecords文件")
#       (help text: captcha TFRecords file)
#   tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "验证码图片路径")
#       (help text: captcha image path)
#   tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "验证码字符的种类")
#       (help text: kinds of captcha characters)
#
# read flags
#   tf.app.flags.DEFINE_string("captcha_dir", "./tfrecords/captcha.tfrecords", "验证码数据的路径")
#       (help text: path of the captcha data)
#   tf.app.flags.DEFINE_integer("batch_size", 100, "每批次训练的样本数")
#       (help text: number of samples per training batch)
#   tf.app.flags.DEFINE_integer("label_num", 4, "每个样本的目标值数量")
#       (help text: number of target values per sample)
#   tf.app.flags.DEFINE_integer("letter_num", 26, "每个目标值取的字母的可能心个数")
#       (help text: number of possible letters for each target value)