day21 TFRecord格式转换MNIST并显示
首先简要介绍了下TFRecord格式以及内部实现protobuf协议,然后基于TFRecord格式,对MNIST数据集转换成TFRecord格式,写入本地磁盘文件,再从磁盘文件读取,通过pyplot模块现实在界面上,效果图如下:
TFRecord和Protobuf协议简介
TFRecord是谷歌专门为Tensorflow打造的一种存储格式,基于protobuf协议实现,也是谷歌推荐的,一个主要原因是做到训练和验证数据格式的统一,有助于不同开发者快速迁移模型。Google Protocol Buffer( 简称 Protobuf) 是 Google 公司内部的混合语言数据标准,目前已经正在使用的有超过 48,162 种报文格式定义和超过 12,183 个 .proto 文件。他们用于 RPC 系统和持续数据存储系统。Protocol Buffers 是一种轻便高效的结构化数据存储格式,可以用于结构化数据串行化,或者说序列化。它很适合做数据存储或 RPC 数据交换格式。可用于通讯协议、数据存储等领域的语言无关、平台无关、可扩展的序列化结构数据格式。目前提供了 C++、Java、Python 三种语言的 API。
|
优点 |
缺点 |
Protobuf |
1、Protobuf 有如 XML,不过它更小、更快(几十倍于XML和JOSON)、也更简单 2、“向后”兼容性好 3、 Protobuf 语义更清晰,无需类似 XML 解析器的东西 4、使用 Protobuf 无需学习复杂的文档对象模型 |
1、 功能简单,无法用来表示复杂的概念 2、Protobuf 只是 Google 公司内部使用的工具,在通用性上还差很多 3、由于文本并不适合用来描述数据结构,所以 Protobuf 也不适合用来对基于文本的标记文档(如 HTML) 4、除非你有 .proto 定义,否则你没法直接读出 Protobuf 的任何内容 |
下面举个简单的例子,从数据的存储格式的角度进行对比,假如要存储一个键值对:{price:150}
protobuf的表示方式如下,protobuf的物理存储:08 96 01,就3个字节。采用key-value的方式存放,第一个字节是key,它是field_number << 3 | wire_type构成。所以field number是1,wire type是0,即varint,有了这个wire type就可以用来解析96 01了。
message Test { optional int32 price = 1; }
xml的存储表示如下,大约需要36字节。
<some> <name>price</name> <value>150</value> </some>
json的存储表示如下,大约需要11字节。
{price:150}
综上所述,protobuf相比于json和xml,对象序列化时可以节省非常大的空间,从而带来非常快的传输速度。
利用TFRecord格式存储、读取和现实MNIST数据集
在TensorFlow中,TFRecord格式是通过tf.train.Example Protocol Buffer协议的存储的,以下代码给出了tf.train.Example的定义:
message Example { Features features = 1; }; message Features { // Map from feature name to feature. map<string, Feature> feature = 1; }; message Feature { // Each feature can be exactly one kind. oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } }; message BytesList { repeated bytes value = 1; } message FloatList { repeated float value = 1 [packed = true]; } message Int64List { repeated int64 value = 1 [packed = true]; }
下面给出两个代码实例:一个程序ToTFRecord.py从MNIST数据集中读取图像和标签集,然后通过TFRecord格式文件中,另一个程序FromTFRecord.py从文件中读取TFRecord格式图像,然后通过pylot模块显示在界面上。
ToTFRecord.py:
import tensorflow as tf from tensorflow.contrib.learn.python.learn.datasets import mnist import numpy as np def __int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value=[value])) def __bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value])) # 读取图像和标签 mnist_data = mnist.read_data_sets(train_dir='MNIST_data/',dtype=tf.uint8,one_hot=True) images = mnist_data.train.images labels = mnist_data.train.labels pixels = images.shape[1] num_examples = mnist_data.train.num_examples filename = "MNIST_TFRecord/output.tfrecords" writer = tf.python_io.TFRecordWriter(filename) for index in range(num_examples): image_raw = images[index].tostring() example = tf.train.Example(features=tf.train.Features(feature={ 'pixels':__int64_feature(pixels), 'label':__int64_feature(np.argmax(labels[index])), 'image_raw':__bytes_feature(image_raw) })) writer.write(example.SerializeToString()) writer.close()
FromTFRecord.py:
import tensorflow as tf # 读取一个样例 reader = tf.TFRecordReader() filename_queue = tf.train.string_input_producer(["MNIST_TFRecord/output.tfrecords"]) _,serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized=serialized_example,features={ 'image_raw':tf.FixedLenFeature([],tf.string), 'pixels':tf.FixedLenFeature([],tf.int64), 'label':tf.FixedLenFeature([],tf.int64) }) # 从样例中解析数据 images = tf.decode_raw(features['image_raw'],tf.uint8) labels = tf.cast(features['label'],tf.int32) pixels = tf.cast(features['pixels'],tf.int32) from matplotlib import pyplot as plt import time import datetime from six.moves import xrange # pylint: disable=redefined-builtin fig, ax = plt.subplots(2, 5, figsize=[3, 3]) plt.ion() plt.axis('off') print("%s :创建10个窗口成功..."%(datetime.datetime.now())) with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try: for step in xrange(10): if coord.should_stop(): break for i in range(2): for j in range(5): cur_pic = i * 5 + j image, label, pixel = sess.run([images, labels, pixels]) image = image.reshape([28, 28]) print(image, label, pixel) ax[i, j].imshow(image, cmap=plt.cm.gray) plt.show() plt.pause(2) except Exception: # Report exceptions to the coordinator. coord.request_stop() # Terminate as usual. It is innocuous to request stop twice. coord.request_stop() coord.join(threads)