python: MNIST的图像显示
1 import sys, os 2 sys.path.append(os.pardir) 3 import numpy as np 4 from dataset.mnist import load_mnist 5 from PIL import Image 6 7 def img_show(img): 8 pil_img = Image.fromarray(np.uint8(img)) 9 pil_img.show() 10 11 (x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, 12 normalize=False) 13 img = x_train[0] 14 label = t_train[0] 15 print(label) # 5 16 17 print(img.shape) # (784,) 18 img = img.reshape(28, 28) # 把图像的形状变成原来的尺寸 19 print(img.shape) # (28, 28) 20 21 img_show(img)
显示mnist图像,执行上述代码后,训练图像的第一张会显示出来。sys.path.append(os.pardir)导入父目录,第一次调用load_mnist函数时,因为要下载MNIST数据集,所以需要联网进行。第2次及以后的调用只需要读入保存在本地的文件(pickle文件)即可,因此处理所需时间都非常短。
load_mnist函数以“(训练图像,训练标签),(测试图像,测试标签)”的形式返回读入的MNST数据。此外,还可以像
load_mnist(normalize=True, flatten=True, one_hot_label=False)
这样,设置3个参数。第1个参数normalize设置是否将输入图像正规化为0.0~1.0的值。如果将该参数设置为False,则输入图像的像素会保持原来的0~255。第2个参数flatten设置是否展开输入图像(变成一维数据)。如果将该参数设置为False,则输入图像为1 × 28 × 28的三维数组;若设置为True,则输入图像会保存为由784个元素构成的一位数组。第3个参数one_hot_label设置是否将标签保存为one-hot表示(one-hot representation)。one-hot表示是仅正确解标签为1,其余皆为0的数组,就像[0,0,1,0,0,0,0,0,0,0]这样。当one_hot_label为False时,知识想7,2这样简单保存正确解标签;当one_hot_label为True时,标签则保存为one-hot表示。
Python 有 pickle 这个便利的功能。这个功能可以将程序运行中的对象保存为文件。如果加载保存过的 pickle 文件,可以立刻复原之前程序运行中的对象。用于读入 MNIST 数据集的
load_mnist()
函数内部也使用了 pickle 功能(在第 2 次及以后读入时)。利用 pickle 功能,可以高效地完成 MNIST 数据的准备工作。
这里需要注意的是,flatten=True
时读入的图像是以一列(一维)NumPy 数组的形式保存的。因此,显示图像时,需要把它变为原来的 28 像素 × 28 像素的形状。可以通过 reshape()
方法的参数指定期望的形状,更改 NumPy 数组的形状。此外,还需要把保存为 NumPy 数组的图像数据转换为 PIL 用的数据对象,这个转换处理由 Image.fromarray()
来完成。
img = x_train[0] #x_train的形状是(6000,784),即6000行728列的矩阵,所以x_train[0]表示第一列的784个数据
label = t_train[0] #t_train的形状是(6000,),即一行或者一列数据6000个,所以t_train[0]是第一个数据,这里它的值是5