加载 MNIST 数据集
一、MNIST 数据集简介
- MNIST 数据集是机器学习领域中非常经典的一个数据集,由 60000 个训练样本和 10000 个测试样本组成,每个样本都是一张 28 * 28 像素的灰度手写数字图片,如下图所示:
- MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,它包含了四个部分:训练集、训练集标签、测试集、测试集标签。
- MNIST 中的每个图像都具有相应的标签,0 到 9 之间的数字表示图像中绘制的数字, 用的是 one-hot 编码 nn[0,0,0,0,0,0,1,0,0,0],mnist.train.labels[55000,10]。
二、加载 MNIST 数据集
- 直接下载下来的数据是无法通过解压或者应用程序打开的,因为这些文件不是任何标准 的图像格式而是以字节的形式进行存储的,所以必须编写程序来打开它。
- 使用 TensorFlow 来读取数据及标签
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# 加载数据集
mnist = input_data.read_data_sets('e:/soft/MNIST_DATA',one_hot=True)
# 加载训练集样本
train_x = mnist.train.images
# 加载验证集样本
validation_x = mnist.validation.images
# 加载测试集样本
test_x = mnist.test.images
# 加载训练集标签
train_y = mnist.train.labels
# 加载验证集标签
validation_y = mnist.validation.labels
# 加载测试集标签
test_y =mnist.test.labels
print('train_x.shape:',train_x.shape,'train_y.shape:',train_y.shape)
# 查看训练集中第一个样本的内容和标签
print(train_x[0])
print(train_y[0])
# 获取训练集数据的前200个
images,labels = mnist.train.next_batch(200)
print('images.shape:',images.shape,'labels.shape:',labels.shape)
import matplotlib.pyplot as plt
# 绘制训练集前20个样本
fig,ax = plt.subplots(nrows=4,ncols=5)
ax = ax.flatten()
for i in range(20):
img = train_x[i].reshape(28,28)
ax[i].imshow(img,cmap='Greys')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.show()
三、训练结果
- 训练集中前 40 个样本图形如下:
正是江南好风景