import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 定义函数转化变量类型。 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 读取mnist训练数据。 mnist = input_data.read_data_sets("F:\\TensorFlowGoogle\\201806-github\\datasets\\MNIST_data\\",dtype=tf.uint8, one_hot=True) images = mnist.train.images labels = mnist.train.labels pixels = images.shape[1]#训练数据的图像分辨率,可作为Example的一个属性 print(pixels)
num_examples = mnist.train.num_examples print(type(num_examples)) print(num_examples)#训练图片的张数
print(type(images)) print(images[0].shape) print(images.shape)
print(type(labels)) print(labels[0].shape) print(labels.shape)
# 将数据转化为tf.train.Example格式。 def _make_example(pixels, label, image): image_raw = image.tostring() example = tf.train.Example(features=tf.train.Features(feature={ 'pixels': _int64_feature(pixels), 'label': _int64_feature(np.argmax(label)), 'image_raw': _bytes_feature(image_raw) })) return example # 输出包含训练数据的TFRecord文件。 with tf.compat.v1.python_io.TFRecordWriter("F:\\output.tfrecords") as writer: for index in range(num_examples): example = _make_example(pixels, labels[index], images[index]) writer.write(example.SerializeToString()) print("TFRecord训练文件已保存。")