TFRecord and a First Program on PAI
How TFRecord Works
TFRecord is TensorFlow's standard file format, and arbitrary data can be converted into it. The format works well with TensorFlow's queue-based input pipeline: data can be processed in parallel by multiple threads, which makes reading fast. A TFRecords file holds tf.train.Example protocol buffers (each of which contains a Features field). To write data, fill it into an Example protocol buffer, serialize the protocol buffer to a string, and write it to the TFRecords file through the tf.python_io.TFRecordWriter class.
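As a minimal sketch of the write path (the filename demo.tfrecords and the feature key 'value' below are placeholders, not part of the MNIST code that follows):

import tensorflow as tf

# Pack one value into an Example and serialize it into the file
example = tf.train.Example(features=tf.train.Features(feature={
    'value': tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))}))
with tf.python_io.TFRecordWriter('demo.tfrecords') as writer:
    writer.write(example.SerializeToString())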
To read data from a TFRecords file, use tf.TFRecordReader together with the tf.parse_single_example parser. parse_single_example parses an Example protocol buffer into tensors.
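And a matching sketch of the read path, assuming the same placeholder file and feature key as above:

import tensorflow as tf

reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(['demo.tfrecords'])
_, serialized_example = reader.read(filename_queue)
# The feature spec must mirror what was written above
features = tf.parse_single_example(serialized_example, features={
    'value': tf.FixedLenFeature([], tf.int64)})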
Converting the MNIST Dataset to TFRecord and Reading It Back
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from PIL import Image

# Convert the given value to an int64 feature; int64_list matches the tf.train.Example definition
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Convert the given value to a bytes feature; bytes_list matches the tf.train.Example definition
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def mnist_to_tfrecords(images, labels, filename, num_examples):
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        # Convert the image to a byte string
        image_raw = images[index].tostring()
        # Pack one image into an Example protocol buffer and write all of its fields
        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(image_raw),
            'label': _int64_feature(np.argmax(labels[index]))}))
        writer.write(example.SerializeToString())
    writer.close()

def read_image(filename):
    reader = tf.TFRecordReader()
    # Create the input queue with tf.train.string_input_producer
    filename_queue = tf.train.string_input_producer([filename])
    # Read one example from the file
    _, serialized_example = reader.read(filename_queue)
    # Parse the example that was just read
    features = tf.parse_single_example(
        serialized_example,
        features={
            # The format parsed here must match the format the data was written with above
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    # tf.decode_raw parses the byte string into the image's pixel array
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [28, 28, 1])
    # tf.cast converts the tensor to the desired data type
    label = tf.cast(features['label'], tf.int32)
    return image, label

def read_image_batch(filename):
    image, label = read_image(filename)
    num_preprocess_threads = 1
    batch_size = 128
    min_queue_examples = 100
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + batch_size,
        min_after_dequeue=min_queue_examples)
    return image_batch, label_batch

# Read data back out of the TFRecord file
def read_tfrecords(filename):
    image_batch, label_batch = read_image_batch(filename)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # Start the threads that fill the queues
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(5):
            image, label = sess.run([image_batch, label_batch])
            result = Image.fromarray(image[0].reshape([28, 28]))
            result.save(str(i) + '.png')
        coord.request_stop()
        coord.join(threads)

if __name__ == '__main__':
    mnist = input_data.read_data_sets("D:/Git/MyCode/GAN/Datas/mnist", dtype=tf.uint8, one_hot=True)

    images = mnist.train.images
    labels = mnist.train.labels
    num_examples = mnist.train.num_examples
    filename = "./train.tfrecords"
    mnist_to_tfrecords(images, labels, filename, num_examples)

    images = mnist.test.images
    labels = mnist.test.labels
    num_examples = mnist.test.num_examples
    filename = "./test.tfrecords"
    mnist_to_tfrecords(images, labels, filename, num_examples)

    read_tfrecords("./test.tfrecords")
Note that tf.train.batch reads data in order: the queue is always an ordered FIFO, continuously refilled at the head, whereas tf.train.shuffle_batch shuffles the queue's contents before dequeuing, so what remains in the queue is also unordered. capacity is the queue length, and data is drawn from within this window; inside that window, the larger min_after_dequeue is, the more thoroughly the data is shuffled.
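A side-by-side sketch of the two batching ops (assuming image and label come from a read_image call like the one above; the sizes are illustrative):

# FIFO: batches preserve the order of the input files
image_batch, label_batch = tf.train.batch(
    [image, label], batch_size=128, capacity=228)

# Shuffled: batches are drawn randomly from a window of up to `capacity` elements,
# of which at least `min_after_dequeue` stay queued to keep the mixing thorough
image_batch, label_batch = tf.train.shuffle_batch(
    [image, label], batch_size=128, capacity=228, min_after_dequeue=100)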
A First Program on PAI: MNIST Classification
With the MNIST dataset saved in TensorFlow's standard format, it can also be read directly with TensorFlow from PAI's OSS storage, which is convenient.
import os
import tensorflow as tf
import argparse

FLAGS = None

def read_image(filename):
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([filename])
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [784])
    # Scale pixel values from [0, 255] to [-0.5, 0.5]
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    return image, label

def read_image_batch(filename, batch_size=128):
    image, label = read_image(filename)
    num_preprocess_threads = 10
    min_queue_examples = 100
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + batch_size,
        min_after_dequeue=min_queue_examples)
    one_hot_labels = tf.to_float(tf.one_hot(label_batch, 10, 1, 0))
    return image_batch, one_hot_labels

def main(_):
    # Input data and checkpoint paths are passed in by PAI via command-line flags
    train_file_path = os.path.join(FLAGS.buckets, "train.tfrecords")
    test_file_path = os.path.join(FLAGS.buckets, "test.tfrecords")
    ckpt_path = os.path.join(FLAGS.checkpointDir, "model.ckpt")

    train_images, train_labels = read_image_batch(train_file_path)
    test_images, test_labels = read_image_batch(test_file_path)

    # A single-layer softmax classifier: y_ = xW + B
    W = tf.get_variable('weights', [784, 10], initializer=tf.random_normal_initializer(stddev=0.02))
    B = tf.get_variable('biases', [10], initializer=tf.constant_initializer(0.0))

    x = tf.reshape(train_images, [-1, 784])
    y = tf.to_float(train_labels)
    y_ = tf.matmul(x, W) + B
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

    # Evaluate accuracy on the test batches with the same weights
    x_test = tf.reshape(test_images, [-1, 784])
    y_pred = tf.matmul(x_test, W) + B
    y_test = tf.to_float(test_labels)
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_test, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(1000):
            sess.run(train_step)
            if (i + 1) % 10 == 0:
                print("step:", i + 1, "accuracy:", sess.run(accuracy))
        print("accuracy: ", sess.run(accuracy))
        save_path = saver.save(sess, ckpt_path)
        print("Model saved in file: %s" % save_path)
        coord.request_stop()
        coord.join(threads)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--buckets', type=str, default='', help='input data path')
    parser.add_argument('--checkpointDir', type=str, default='', help='output model path')
    FLAGS, _ = parser.parse_known_args()
    tf.app.run(main=main)
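On PAI, the --buckets and --checkpointDir flags are filled in by the platform with OSS paths. For local debugging, the script can be launched the same way by hand; the script name and OSS paths below are hypothetical:

python mnist_pai.py --buckets oss://my-bucket/mnist/ --checkpointDir oss://my-bucket/ckpt/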
* Pitfalls: paths on PAI must not contain Chinese characters, and I should have debugged locally before running on PAI. Also, tf.train.match_filenames_once, used to find the TFRecords files, stores the matched filenames in a local variable, so local variables must be initialized before reading, otherwise the files cannot be read; alternatively, pass the filenames directly and skip the local-variable storage entirely. A sketch of the fix follows below.
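A hedged sketch of that last pitfall (the glob pattern here is hypothetical): tf.train.match_filenames_once keeps its matched file list in a local variable, so tf.local_variables_initializer() must be run as well, or no filenames ever reach the queue.

import tensorflow as tf

files = tf.train.match_filenames_once('./data/*.tfrecords')
filename_queue = tf.train.string_input_producer(files)

with tf.Session() as sess:
    # Initializing only the global variables is not enough here
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])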