【3】TensorFlow光速入门-训练及评估

本文地址:https://www.cnblogs.com/tujia/p/13862357.html

 

系列文章:

【0】TensorFlow光速入门-序

【1】TensorFlow光速入门-tensorflow开发基本流程

【2】TensorFlow光速入门-数据预处理(得到数据集)

【3】TensorFlow光速入门-训练及评估

【4】TensorFlow光速入门-保存模型及加载模型并使用

【5】TensorFlow光速入门-图片分类完整代码

【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

【7】TensorFlow光速入门-总结

 

一、导入需要的包

import tensorflow as tf
from tensorflow import keras
import numpy as np

 

二、初始化模型并配置神经网络层

model = keras.Sequential([
    # 展平数据,输入类型要和数据集保持一致,我这里是100*100的灰图
    keras.layers.Flatten(input_shape=(100, 100, 1)),
    # 第二层是神经元
    keras.layers.Dense(128, activation='relu'),
    # 第三层的参数很重要,2表示分两类,如果要分5类就传5,10类就传10
    keras.layers.Dense(2, activation='softmax')
])

注:如果是图片分类,这三层网络是固定搭配,需要注意的是,input_shape要和数据集数据保持一致,第三层分几类就传几;其他模型的层选择和配置,我们后面再慢慢了解

 

三、模型编译

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

注:同样,图片分类的优化器、损失函数及指标也是固定搭配,其他类型的模型我们后面再慢慢了解

 

四、训练

model.fit(ds, epochs=100, steps_per_epoch=10)

注1:ds 是上一节准备好的数据集;epochs 代表要训练多少次,steps_per_epoch 代表每一次分几步训练;因为我准备的数据比较少,所以设置的训练100次。数据多的话,不用训练那么多次。

注2:使用 ZipDataset 类型的数据集时,steps_per_epoch 参数为必填,其他情况,根据自己的情况可以不传。

 

五、评估(评估训练效果)

test_loss, test_acc = model.evaluate(ds, verbose=2, steps=10)

注1:正常情况下,训练要用训练集,评估要用测试集。因为偷懒的原故,这里我都用的同一个数据集。

注2:使用 ZipDataset 类型的数据集时,steps 参数为必填,其他情况,根据自己的情况可以不传。

 

六、预测

预测即使用的意思,评估通过的模型,可以直接使用了

predictions = model.predict(ds, steps=10)
label = np.argmax(predictions[0])
print(label_names[label])

注:这里批量预测,对整个数据集都进行预测,正式使用的时候,一般只预测一张图片就可以了,下一节会说。

 

重点 Api :

keras.Sequential        https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential

model.compile            https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential#compile

model.fit                       https://tensorflow.google.cn/api_docs/python/tf/keras/Model#fit

model.evaluate            https://tensorflow.google.cn/api_docs/python/tf/keras/Model#evaluate

model.predict               https://tensorflow.google.cn/api_docs/python/tf/keras/Model#predict

 

至此,我们的图片分类模型已经训练好了。可以使用了模型来做图片分类预测了。

下一节,让我们来说一下,怎么保存这个训练好的模型。以及如何加载已保存的模型并使用:

【4】TensorFlow光速入门-保存模型及加载模型并使用

 

本文链接:https://www.cnblogs.com/tujia/p/13862357.html


 完。

posted @ 2020-10-23 17:26  Tiac  阅读(640)  评论(0编辑  收藏  举报