TFRecord及PAI上的第一个程序

TFRcord的原理

TFRecord是一种标准的Tensorflow格式,可以将任意的数据转换为TFRecord格式, 这种格式与网络应用架构相匹配,多线程的并行处理数据,速度快。TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。可以 将数据填入到Example协议内存块(protocol buffer),并将协议内存块序列化为一个字符串, 通过tf.python_io.TFRecordWriter class写入到TFRecords文件。

读取TFRecords文件的数据, 使用tf.TFRecordReadertf.parse_single_example解析器。parse_single_exampleExample协议内存块(protocol buffer)解析为张量。

 

MNIST数据集转化为TFRcord并进行读取

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from PIL import Image

#把传入的value转化为整数型的属性,int64_list对应着 tf.train.Example 的定义
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#把传入的value转化为字符串型的属性,bytes_list对应着 tf.train.Example 的定义
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def mnsit_tfreords(images,labels,filename,num_examples):
    
    writer = tf.python_io.TFRecordWriter(filename)

    for index in range(num_examples):
        #把图像转化为字符串
        image_raw = images[index].tostring()
        #将一个图像转化为Example Protocol Buffer,并将所有的信息写入这个数据结构
        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(image_raw),
            'label': _int64_feature(np.argmax(labels[index]))}))
        writer.write(example.SerializeToString())

    writer.close()

def read_image(filename):
    reader = tf.TFRecordReader()
    #通过 tf.train.string_input_producer 创建输入队列
    filename_queue = tf.train.string_input_producer([filename])
    #从文件中读取一个样例
    _, serialized_example = reader.read(filename_queue)
    #解析读入的一个样例
    features = tf.parse_single_example(
        serialized_example,
        features={
            #这里解析数据的格式需要和上面程序写入数据的格式一致
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    #tf.decode_raw可以将字符串解析成图像对应的像素数组
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [28, 28, 1])
    #tf.cast可以将传入的数据转化为想要改成的数据类型
    label = tf.cast(features['label'], tf.int32)
    return image,label
def read_image_batch(filename):
    image,label = read_image(filename)
    num_preprocess_threads = 1
    batch_size = 128
    min_queue_examples = 100
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples +  batch_size,
        min_after_dequeue=min_queue_examples)
    return image_batch,label_batch

#读取TFRecord文件中的数据
def read_tfrecords(filename):
    image_batch,label_batch = read_image_batch(filename)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        #启动多线程处理数据
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(5):
            image, label = sess.run([image_batch, label_batch])
            result = Image.fromarray(image[0].reshape([28,28]))
            result.save(str(i) + '.png')
        coord.request_stop()
        coord.join(threads)
if __name__ == '__main__':
    mnist = input_data.read_data_sets("D:/Git/MyCode/GAN/Datas/mnist", dtype=tf.uint8, one_hot=True)
    images = mnist.train.images
    labels = mnist.train.labels
    num_examples = mnist.train.num_examples
    filename = "./train.tfrecords"
    mnsit_tfreords(images,labels,filename,num_examples)
    images = mnist.test.images
    labels = mnist.test.labels
    num_examples = mnist.test.num_examples
    filename = "./test.tfrecords"
    mnsit_tfreords(images,labels,filename,num_examples)
    read_tfrecords("./test.tfrecords")
View Code

其中tf.train.batch是按顺序读取数据,队列中的数据始终是一个有序的队列,对头一直在补充,而tf.train.shuffle_batch是将队列中数据打乱后,再读取出来,因此队列中剩下的数据也是乱序的,capacity队列长度,读取的数据是基于这个范围的,在这个范围内,min_after_dequeue越大,数据越乱。

 

PAI第一个程序Mnist分类

将mnist数据集保存为tensorflow的标准形式,在PAI的OSS存储中也可以直接盗用tensorflow进行读取,比较方便。

import os
import tensorflow as tf
import argparse

FLAGS = None;


def read_image(filename):
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([filename])
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [784])
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    return image,label

def read_image_batch(filename,batch_size = 128):
    image,label = read_image(filename)
    num_preprocess_threads = 10
    batch_size = batch_size
    min_queue_examples = 100
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples +  batch_size,
        min_after_dequeue=min_queue_examples)
    one_hot_labels = tf.to_float(tf.one_hot(label_batch, 10, 1, 0))
    return image_batch,one_hot_labels

def main(_):
    train_file_path = os.path.join(FLAGS.buckets, "train.tfrecords")
    test_file_path = os.path.join(FLAGS.buckets, "test.tfrecords")
    ckpt_path = os.path.join(FLAGS.checkpointDir, "model.ckpt")

    train_images,train_labels = read_image_batch(train_file_path)
    test_images,test_labels = read_image_batch(test_file_path)

    W = tf.get_variable('weights', [784, 10],initializer = tf.random_normal_initializer(stddev = 0.02))
    B = tf.get_variable('biases', [10],initializer = tf.constant_initializer(0.0))

    x = tf.reshape(train_images,[-1,784])
    y = tf.to_float(train_labels)
    y_ = tf.matmul(x, W) + B
    
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    x_test = tf.reshape(test_images, [-1, 784])
    y_pred = tf.matmul(x_test, W) + B
    y_test = tf.to_float(test_labels)
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_test, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
        for i in range(1000):
            sess.run(train_step)
            if ((i+1) % 10 == 0):
                    print("step:", i + 1, "accuracy:", sess.run(accuracy))

        print("accuracy: " , sess.run(accuracy))

        save_path = saver.save(sess, ckpt_path)
        print("Model saved in file: %s" % save_path)

        coord.request_stop()
        coord.join(threads)
if __name__ == '__main__':
    parser = argparse.ArgumentParser();
    parser.add_argument('--buckets', type=str, default='',help='input data path')
    parser.add_argument('--checkpointDir', type=str, default='',help='output model path')
    FLAGS, _ = parser.parse_known_args()
    tf.app.run(main=main)
View Code

*踩过的坑,PAI不能有中文,没有在本地调好在运行,还有tf.train.match_filenames_once这个函数去读取tfrecords文件,需要本地变量保存filenames,所以不能存在零时变量上,不然无法读取文件,可以之间使用不进行局部存储。

posted @ 2018-03-17 20:34  雨婷墨染  阅读(470)  评论(1编辑  收藏  举报