Exploring the MNIST dataset
1. Attributes and methods of mnist
For convenience, I only inspected the last 20 attributes and methods.
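    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets(r'G:\MNIST DATABASE\MNIST_data', one_hot=True)
    print(dir(mnist)[-20:])

The import line brings in the input_data module from the tensorflow.examples.tutorials.mnist package.
The read_data_sets call takes two arguments here. The first is a string naming the folder the data is read from; it is written as a raw string so that the backslashes in the Windows path are kept literally. The second, the keyword argument one_hot, is a bool; setting it to True returns the labels in one-hot encoded form.
The last line prints the last 20 attributes and methods of mnist.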
Output:
Extracting G:\MNIST DATABASE\MNIST_data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting G:\MNIST DATABASE\MNIST_data\t10k-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Extracting G:\MNIST DATABASE\MNIST_data\t10k-labels-idx1-ubyte.gz
['__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '_asdict', '_fields', '_make', '_replace', '_source', 'count', 'index', 'test', 'train', 'validation']
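The tail of that list (_asdict, _fields, _make, _replace, count, index, plus the three names test, train and validation) is what a Python namedtuple exposes. The sketch below rebuilds the same shape with a stand-in object, assuming that what read_data_sets returns behaves like a namedtuple with those three fields. The name Datasets and the placeholder strings are illustrative only and are not taken from the TensorFlow source.

```python
from collections import namedtuple

# Hypothetical stand-in for the object returned by read_data_sets.
Datasets = namedtuple("Datasets", ["train", "validation", "test"])

# Placeholder strings instead of real dataset objects.
demo = Datasets(train="train split", validation="validation split", test="test split")

# Names without a leading underscore: the tuple methods plus the three fields.
public_names = [name for name in dir(demo) if not name.startswith("_")]
print(public_names)

# Fields can be read by name or unpacked by position, like any tuple.
train, validation, test = demo
print(demo._fields)
print(demo.train == train)
```

If the assumption holds, mnist.train, mnist.validation and mnist.test are simply the three fields, and count and index are the ordinary tuple methods.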
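2. How many images the training, validation and test sets contain
The train set has 55,000 images and the validation set has 5,000; together these two sets make up the 60,000 training images that MNIST itself provides.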
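    print('Number of training examples', mnist.train.num_examples)
    print('Number of validation examples', mnist.validation.num_examples)
    print('Number of test examples', mnist.test.num_examples)

Output:

    Number of training examples 55000
    Number of validation examples 5000
    Number of test examples 10000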
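The arithmetic behind these counts is 55,000 + 5,000 = 60,000 training images, plus 10,000 test images. A minimal sketch of carving a validation set out of a training set follows. It assumes the validation examples are simply sliced off the front, which may differ from what read_data_sets actually does, and it uses plain indices rather than real image data. split_validation is a hypothetical helper.

```python
def split_validation(examples, validation_size):
    """Return (train, validation), taking the validation part from the front."""
    if not 0 <= validation_size <= len(examples):
        raise ValueError("validation_size must be between 0 and len(examples)")
    validation = examples[:validation_size]
    train = examples[validation_size:]
    return train, validation

# Stand-in for the 60,000 original training examples (indices only).
all_training = list(range(60000))
train, validation = split_validation(all_training, 5000)

print("train:", len(train))
print("validation:", len(validation))
```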
3. The mnist.train.next_batch() function
The object returned by input_data.read_data_sets provides mnist.train.next_batch(), which reads a small portion of the full training data to use as one training batch.
    batch_size = 100
    # Take 100 training examples and their 100 labels from the train set
    xs, ys = mnist.train.next_batch(batch_size)
    print('xs shape', xs.shape)
    print('ys shape', ys.shape)

Output:

    xs shape (100, 784)
    ys shape (100, 10)
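To make those two shapes concrete, here is a minimal NumPy sketch of mini-batch sampling. It is not TensorFlow's implementation: random pixel values and random digit labels stand in for MNIST, sample_batch is a hypothetical helper, and each call draws a fresh random batch instead of stepping through a shuffled epoch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in data: 1,000 fake "images" of 784 pixels and integer digit labels.
num_examples, num_pixels, num_classes = 1000, 784, 10
images = rng.random((num_examples, num_pixels), dtype=np.float32)
digits = rng.integers(0, num_classes, size=num_examples)

# One-hot encode: row i is all zeros except for a 1 at column digits[i].
labels = np.eye(num_classes, dtype=np.float32)[digits]

def sample_batch(images, labels, batch_size, rng):
    """Pick batch_size distinct examples at random (hypothetical helper)."""
    chosen = rng.choice(len(images), size=batch_size, replace=False)
    return images[chosen], labels[chosen]

xs, ys = sample_batch(images, labels, 100, rng)
print("xs shape", xs.shape)
print("ys shape", ys.shape)
```

The second dimension of ys is 10 because each label is a one-hot vector over the ten digit classes, which is what one_hot=True asks for.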
4. Inspecting mnist.train.images
mnist.train.images is a NumPy array. Each example in it is a one-dimensional array of length 784, which is the number of pixels in one image.
    print('Type of the train images:', type(mnist.train.images), 'shape of the train image matrix:', mnist.train.images.shape)
    print('Type of the train labels:', type(mnist.train.labels), 'shape of the train label matrix:', mnist.train.labels.shape)

Output:

    Type of the train images: <class 'numpy.ndarray'> shape of the train image matrix: (55000, 784)
    Type of the train labels: <class 'numpy.ndarray'> shape of the train label matrix: (55000, 10)

    print('Length and content of the first train image:', len(mnist.train.images[0]), mnist.train.images[0])
    print('Length and content of the first train label:', len(mnist.train.labels[0]), mnist.train.labels[0])

Output:

    Length and content of the first train image: 784 [ 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0.38039219 0.37647063
    0.3019608 0.46274513 0.2392157 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0.35294119 0.5411765
    0.92156869 0.92156869 0.92156869 0.92156869 0.92156869 0.92156869
    0.98431379 0.98431379 0.97254908 0.99607849 0.96078438 0.92156869
    0.74509805 0.08235294 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0.
    0.54901963 0.98431379 0.99607849 0.99607849 0.99607849 0.99607849
    0.99607849 0.99607849 0.99607849 0.99607849 0.99607849 0.99607849
    0.99607849 0.99607849 0.99607849 0.99607849 0.74117649 0.09019608
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0.88627458 0.99607849 0.81568635
    0.78039223 0.78039223 0.78039223 0.78039223 0.54509807 0.2392157
    0.2392157 0.2392157 0.2392157 0.2392157 0.50196081 0.8705883
    0.99607849 0.99607849 0.74117649 0.08235294 0. 0. 0.
    0. 0. 0. 0. 0. 0.
    0.14901961 0.32156864 0.0509804 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0.13333334 0.83529419 0.99607849 0.99607849 0.45098042 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0.32941177 0.99607849 0.99607849 0.91764712 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0.32941177 0.99607849 0.99607849 0.91764712 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0.41568631 0.6156863 0.99607849 0.99607849 0.95294124 0.20000002
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0.09803922 0.45882356 0.89411771
    0.89411771 0.89411771 0.99215692 0.99607849 0.99607849 0.99607849
    0.99607849 0.94117653 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0.26666668 0.4666667 0.86274517
    0.99607849 0.99607849 0.99607849 0.99607849 0.99607849 0.99607849
    0.99607849 0.99607849 0.99607849 0.55686277 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0.14509805 0.73333335 0.99215692
    0.99607849 0.99607849 0.99607849 0.87450987 0.80784321 0.80784321
    0.29411766 0.26666668 0.84313732 0.99607849 0.99607849 0.45882356
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0.44313729
    0.8588236 0.99607849 0.94901967 0.89019614 0.45098042 0.34901962
    0.12156864 0. 0. 0. 0. 0.7843138
    0.99607849 0.9450981 0.16078432 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0.66274512 0.99607849 0.6901961 0.24313727 0. 0.
    0. 0. 0. 0. 0. 0.18823531
    0.90588242 0.99607849 0.91764712 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0.07058824 0.48627454 0. 0. 0.
    0. 0. 0. 0. 0. 0.
    0.32941177 0.99607849 0.99607849 0.65098041 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0.54509807 0.99607849 0.9333334 0.22352943 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0.
    0.82352948 0.98039222 0.99607849 0.65882355 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0.94901967 0.99607849 0.93725497 0.22352943 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0.
    0.34901962 0.98431379 0.9450981 0.33725491 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0.
    0.01960784 0.80784321 0.96470594 0.6156863 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0.01568628 0.45882356 0.27058825 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. ]
    Length and content of the first train label: 10 [ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
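The output above shows that mnist.train holds 55,000 samples in total, each with 784 features.
The original images are 28 × 28 pixels, and 28 × 28 = 784, so each image sample has 784 features once it is flattened.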
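The link between a position in the flat array and a pixel position can be checked with a small sketch. A zero-filled array stands in for a real image. The one pixel it sets, flat index 207, is where the first non-zero value (0.38039219) appears in the printout above, after 207 zeros. The sketch assumes the row-major order that NumPy's reshape uses by default.

```python
import numpy as np

flat = np.zeros(784, dtype=np.float32)
flat[207] = 0.38039219  # first non-zero pixel in the printed example

image = flat.reshape(28, 28)  # same result as reshape(-1, 28) for 784 values

# Row-major layout: flat index = row * 28 + column.
row, col = divmod(207, 28)
print(row, col)                       # 7 11
print(image[row, col] == flat[207])   # True

# The one-hot label from the printout decodes back to a digit with argmax.
label = np.array([0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.float32)
print(int(np.argmax(label)))          # 7
```

So the first inked pixel of that sample sits in row 7, column 11 of the 28 × 28 image, and its label is the digit 7.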
Take one sample and view its image with three different plotting styles. The code is as follows:
    import matplotlib.pyplot as plt

    # Reshape the flat array into a 28 x 28 image
    image = mnist.train.images[1].reshape(-1, 28)
    fig = plt.figure("Image display")
    ax0 = fig.add_subplot(131)
    ax0.imshow(image)
    ax0.axis('off')  # hide the axes

    plt.subplot(132)
    plt.imshow(image, cmap='gray')
    plt.axis('off')  # hide the axes

    plt.subplot(133)
    plt.imshow(image, cmap='gray_r')
    plt.axis('off')
    plt.show()
Output:
The result shows that when calling plt.imshow, setting the cmap parameter to gray or gray_r gives an image that looks natural to the eye.
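gray_r is Matplotlib's reversed version of gray, so strokes (values near 1) come out dark on a light background, like ink on paper. For pixel values in [0, 1], one way to picture it is that gray_r shows the same picture as gray applied to 1 - image. The sketch below only does that arithmetic on a tiny made-up array and does not call Matplotlib.

```python
import numpy as np

# Tiny made-up "image": 0 is background, values near 1 are the stroke.
image = np.array([[0.0, 0.5, 1.0],
                  [0.0, 1.0, 0.0]], dtype=np.float32)

# Under cmap='gray', the lowest value is drawn black and the highest white.
# Inverting the values gives the picture that cmap='gray_r' would display.
inverted = 1.0 - image
print(inverted)
```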
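5. Viewing handwritten digit images
Pick some samples from the training set mnist.train and look at the images, that is, call the next_batch method of mnist.train to get a random batch of samples. The code is as follows: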
    from tensorflow.examples.tutorials.mnist import input_data
    import math
    import matplotlib.pyplot as plt
    import numpy as np

    mnist = input_data.read_data_sets(r'G:\MNIST DATABASE\MNIST_data', one_hot=True)

    # Draw a single image from the MNIST dataset
    def drawdigit(position, image, title):
        plt.subplot(*position)
        plt.imshow(image, cmap='gray_r')
        plt.axis('off')
        plt.title(title)

    # Fetch one batch, then draw batch_size subplots on a single canvas
    def batchDraw(batch_size):
        images, labels = mnist.train.next_batch(batch_size)
        row_num = math.ceil(batch_size ** 0.5)
        column_num = row_num
        plt.figure(figsize=(row_num, column_num))
        for i in range(row_num):
            for j in range(column_num):
                index = i * column_num + j
                if index < batch_size:
                    position = (row_num, column_num, index + 1)
                    image = images[index].reshape(-1, 28)
                    title = 'actual:%d' % (np.argmax(labels[index]))
                    drawdigit(position, image, title)


    if __name__ == '__main__':
        batchDraw(196)
        plt.show()
Output:
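As a side note, the layout arithmetic in batchDraw can be checked on its own. This sketch assumes the same square grid, with a side equal to the ceiling of the square root of batch_size. grid_side and subplot_positions are hypothetical helpers, and math.isqrt (Python 3.8+) replaces ** 0.5 to keep the computation in integers.

```python
import math

def grid_side(batch_size):
    """Smallest n such that an n x n grid holds batch_size images."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return math.isqrt(batch_size - 1) + 1

def subplot_positions(batch_size):
    """Yield (rows, cols, index) triples in the form plt.subplot accepts."""
    side = grid_side(batch_size)
    for index in range(batch_size):
        yield (side, side, index + 1)  # subplot indices start at 1

print(grid_side(196))                      # 14
positions = list(subplot_positions(10))
print(positions[0], positions[-1])         # (4, 4, 1) (4, 4, 10)
```

For batchDraw(196) this gives a 14 × 14 grid with every cell filled; for a batch size that is not a perfect square, the last cells of the grid stay empty.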