一些常用的数据集操作技巧
#查看当前tensorflow的版本 tf.__version__ #标签可以用数组做映射 class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] #查看数据的形状,对数据有直观的了解。 train_images.shape (60000,28,28) #查看训练集总数 len(train_labels) 60000 #如果是图像,可以把数据集的图像拿出一个看看 import matplotlib.pyplot as plt plt.figure() plt.imshow(train_images[0]) plt.show() #注意代码的层级性,figure是大容器,下面有2个元素,imshow和show,其中imshow又是一个子容器,里面有元素train_images[0]。分清层级,有助于理解细节。下面是更详细的写法 import matplotlib.pyplot as plt plt.figure(figsize=(1,1)) # 添加figsize参数,缩小图像尺寸 plt.imshow(train_images[0]) plt.colorbar() #增加颜色条 plt.grid(False) #去掉网格,黑色背景网格不明显 plt.show() #可以看到,大容器figure缩小后,里面的图像也跟着缩小了。 #批量查看缩小图 plt.figure(figsize=(10,10)) for i in range(25): plt.subplot(5,5,i+1) plt.xticks([]) plt.yticks([]) plt.grid(False) plt.imshow(train_images[i], cmap=plt.cm.binary) plt.xlabel(class_names[train_labels[i]]) plt.show() #挑选单个数据测试 import numpy as np img = test_images[5] img = np.expand_dims(img,0) single = model.predict(img) single_index = np.argmax(single[0]) class_names[single_index] #输出:Trouser 裤子 #查看具体的图像,看看是否正确 import matplotlib.pyplot as plt plt.figure(figsize=(1,1)) plt.imshow(test_images[5]) plt.show()