《TensorFlow之tf.keras的基础分类》的代码注释***(学习笔记一)***(Basic classification: Classify images of clothing )
本文源代码来自于tensorflow官网 https://www.tensorflow.org/tutorials/keras/classification
学习中参考了若干博客 如https://blog.csdn.net/qq_20989105/article/details/82760815 等
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
#********************预处理数据***********************
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
plt.figure() #创建一个图框,figure(num=None 编号/名称, figsize=None 图框尺寸, dpi=None 分辨率, facecolor=None 背景颜色, edgecolor=None 边框颜色, frameon=True 是否显示边框)
plt.imshow(train_images[0]) #展现数据
plt.colorbar() #添加颜色线
plt.grid(False) #图片网格=True/False
plt.show()
train_images = train_images / 255.0 #归一化
test_images = test_images / 255.0
plt.figure(figsize=(10,10)) #open a 10*10 windows
for i in range(25):
plt.subplot(5,5,i+1) #plt.subplot(a,b,i+1)中 a表示图片行数,b表示图片列数,i为索引值,i+1即可表示图片具体位置,从左往右,从上往下分布,这边i=0,i+1则表示第一个左上角第一个位置
plt.xticks([]) #plt.xticks() 表达的是x轴的刻度内容的范围
plt.yticks([])
plt.grid(False) #图片网格
plt.imshow(train_images[i], cmap=plt.cm.binary) #plt.cm.binary(黑白)/plt.cm.gray(灰度图)/plt.cm.bone/plt.cm.hot
plt.xlabel(class_names[train_labels[i]])
plt.show()
#********************建立模型***********************
#***设置图层***
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)), #该网络中的第一层tf.keras.layers.Flatten将图像的格式从2d阵列(28乘28像素)转换为28 * 28 = 784像素的1d阵列。
#可以将此图层视为图像中未堆叠的像素行并将其排列。该层没有要学习的参数; 它只重新格式化数据。
tf.keras.layers.Dense(128, activation='relu'), #在像素被展平之后,网络由tf.keras.layers.Dense两层序列组成。这些是密集连接或完全连接的神经层,第一Dense层有128个节点(或神经元)。
tf.keras.layers.Dense(10) #第二(和最后)层是10节点softmax层 - 这返回10个概率分数的数组,其总和为1.每个节点包含指示当前图像属于10个类之一的概率的分数。
])
#***编译模型***
model.compile(optimizer='adam', #optimizer = 优化器,optimizer可以是字符串形式给出的优化器名字,adam收敛最快,BGD梯度下降法(最原始,也是最基础),SGD随机梯度下降法
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), #交叉熵损失函数,from_logits: 为True时,会将y_pred转化为概率(用softmax),否则不进行转换,通常情况下用True结果更稳定
metrics=['accuracy']) #度量标准 - 用于监控培训和测试步骤。accuracy :使用精度,即正确分类的图像的分数
#***训练模型***
model.fit(train_images, train_labels, epochs=10) #调用model.fit方法 - 模型“适合”训练数据,将训练数据提供给模型 - 在此示例中为train_images和train_labels数组。
#***评估模型准确性***
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2) #比较模型在测试数据集上的执行情况。测试数据集的准确性略低于训练数据集的准确性。
print('\nTest accuracy:', test_acc) #训练精度和测试精度之间的差距是过度拟合的一个例子。过度拟合是指机器学习模型在新数据上的表现比在训练数据上表现更差。
#***做出预测***
predictions = model.predict(test_images)
#***图表来查看全部的10个类别***
def plot_image(i, predictions_array, true_label, img): #设置图片展示,x轴下标 预测类别,概率,识别类别
true_label, img = true_label[i], img[i]
plt.grid(False) #图片网格
plt.xticks([]) #x刻度设置为0
plt.yticks([])
plt.imshow(img, cmap=plt.cm.binary) #plt.cm.binary(二值图/黑白图)
predicted_label = np.argmax(predictions_array)
if predicted_label == true_label:
color = 'blue'
else:
color = 'red'
plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label], #对图框的xlabel进行标记, .format 转换字符串格式,输出类标 如 ankle boot
100*np.max(predictions_array), #100*np.max(predictions_array)={:2.0f}% ,(以至少两位数的float格式输出*100)%。如37%
class_names[true_label]), #输出 (类标) ,如 (ankle boot)
color=color) #从上面的color传递颜色,分类正确为蓝色,分类错误为红色
def plot_value_array(i, predictions_array, true_label): #设置柱状图
true_label = true_label[i]
plt.grid(False)
plt.xticks(range(10))
plt.yticks([])
thisplot = plt.bar(range(10), predictions_array, color="#777777") #横坐标0-9,颜色灰色
plt.ylim([0, 1]) #纵坐标0-1(概率)
predicted_label = np.argmax(predictions_array) #numpy.argmax(array, axis) 用于返回一个numpy数组中最大值的索引值,这边返回预测概率最高的分类
thisplot[predicted_label].set_color('red') #将预测的颜色设置为红色
thisplot[true_label].set_color('blue') #将实际颜色设置为蓝色,如果预测=实际,则覆盖掉,否则则会显示出两条颜色不一样的柱形,也代表着分类错误
i = 0
plt.figure(figsize=(6,3)) #matplotlib下, 一个 Figure 对象可以包含多个子图(Axes), 可以使用 subplot() 快速定位=>绘制
plt.subplot(1,2,1) #subplot(numRows, numCols, plotNum)图表的整个绘图区域被分成 numRows 行和 numCols 列,左上的子区域的编号为1,plotNum 参数指定创建的 Axes 对象所在区域的索引
plot_image(i, predictions[i], test_labels, test_images) #调用定义好的plot_image 函数,传入四个参数
plt.subplot(1,2,2)
plot_value_array(i, predictions[i], test_labels)
plt.show()
i = 12
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i], test_labels)
plt.show()
# Plot the first X test images, their predicted labels, and the true labels.
# Color correct predictions in blue and incorrect predictions in red.
num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows)) #这边2*2*num_cols 实际上是num_col*2(每个模块有两张图,占一行的两列),再放大两倍(否则下标太小看不见)
for i in range(num_images):
plt.subplot(num_rows, 2*num_cols, 2*i+1) #定位画图的位置
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(num_rows, 2*num_cols, 2*i+2)
plot_value_array(i, predictions[i], test_labels)
plt.tight_layout() #plt.tight_layout会自动调整子图参数(坐标轴标签、刻度标签以及标题的部分),使之填充整个图像区域
plt.show()