TensorFlow加载MNIST数据集¶
作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/
所用版本:python3.5.2,tensorflow1.8.0,tensorboard1.8.0
首先,在与Python代码相同路径下新建一个文件夹“MNIST_data”。
然后从MNIST数据集官网上http://yann.lecun.com/exdb/mnist/ 下载以下四个文件到“MNIST_data”文件夹中。
注意,不要解压,文件夹只保留这四个文件。
train-images-idx3-ubyte.gz: 训练集图片,包含55000张训练图片与5000张验证图片。
train-labels-idx1-ubyte.gz: 训练集图片对应的数字标签。
t10k-images-idx3-ubyte.gz: 测试集图片,包含10000张测试图片。
t10k-labels-idx1-ubyte.gz: 测试集图片对应的数字标签。
然后运行下面代码即可加载MNIST数据集。
In [1]:
# 导入TensorFlow中input_data.py文件
In [2]:
from tensorflow.examples.tutorials.mnist import input_data
In [3]:
# 从MNIST_data数据集中读取MNIST数据
In [4]:
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
In [5]:
# 进一步分析MNIST内容
In [6]:
# 加载数据
train_X = mnist.train.images #训练集样本
validation_X = mnist.validation.images #验证集样本
test_X = mnist.test.images #测试集样本
# 加载标签
train_Y = mnist.train.labels #训练集标签
validation_Y = mnist.validation.labels #验证集标签
test_Y = mnist.test.labels #测试集标签
In [7]:
print('训练集样本的大小:', train_X.shape)
print('训练集标签的大小:', train_Y.shape)
In [8]:
print('测试集样本的大小:', test_X.shape)
print('测试集标签的大小:', test_Y.shape)
In [9]:
print('验证集样本的大小:', validation_X.shape)
print('验证集标签的大小:', validation_Y.shape)
In [10]:
import matplotlib.pyplot as plt
In [11]:
# 显示出一张RGB图片看看
im = train_X[1]
im = im.reshape(-1, 28)
plt.imshow(im) # RGB图像
plt.show()
In [12]:
# 显示出一张灰度图片看看
im = train_X[1]
im = im.reshape(-1, 28)
plt.imshow(im,cmap='Greys')
plt.show()
In [13]:
#可视化样本,下面是输出了训练集中前20个样本
fig, ax = plt.subplots(nrows=4,ncols=5,sharex='all',sharey='all')
ax = ax.flatten()
for i in range(20):
img = train_X[i].reshape(28, 28)
ax[i].imshow(img,cmap='Greys')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
In [14]:
#查看数据,例如训练集中第一个样本的内容和标签
print(train_X[0]) #是一个包含784个元素且值在[0,1]之间的向量
print(train_Y[0])