【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html
系列文章:
【1】TensorFlow光速入门-tensorflow开发基本流程
【2】TensorFlow光速入门-数据预处理(得到数据集)
【4】TensorFlow光速入门-保存模型及加载模型并使用
【6】TensorFlow光速入门-python模型转换为tfjs模型并使用
一、保存模型
创建一个目录
!mkdir /tf/saved_model
注:jupyter 代码块前面加一个!号表示,这是shell命令,不是代码;
保存模型
model.save('/tf/saved_model/wnw')
保存模型的其他参数及操作,看这里 https://tensorflow.google.cn/api_docs/python/tf/keras/Model#save
二、加载模型
import tensorflow as tf from tensorflow import keras import numpy as np from IPython import display import random # 加载模型 model = keras.models.load_model('/tf/saved_model/wnw') # 看一下模型的结构 model.summary() # 随便找点图片 all_image_paths = [] data_root = pathlib.Path('/tf/datasets/wnw') for item in data_root.rglob('*.jpg'): all_image_paths.append(str(item)) print(len(all_image_paths)) # 随机选取一张图片 img_path = random.choice(all_image_paths) print(img_path) # 把图片处理成需要的tensor image = tf.io.read_file(img_path) image = tf.image.decode_image(image, channels=1) image = tf.image.resize(image, (100, 100)) image /= 255 print(image.shape) # 预测只支持批量操作,我们给单张图片再加一维 images = (np.expand_dims(image, 0)) print(images.shape) # 预测 predictions = model.predict(images) # 打印结果 label_names = ['other', 'watch'] label = np.argmax(predictions[0]) print(label_names[label]) # 把图片也打印出来,看一下预测效果对不对 display.display(display.Image(img_path, width=200, height=200))
注:
用于预测的图片数据要和训练的图片数据保持一致:
简单来说,训练不一定要100*100的灰图,我可以是80*80的灰图或彩图,都没关系。
重要的是,用使用模型的时候,要先把预测数据转换成训练集数据一样的格式
重点:
model.save https://tensorflow.google.cn/api_docs/python/tf/keras/Model#save
keras.models.load_model https://tensorflow.google.cn/api_docs/python/tf/keras/models/load_model
至此,我们已经可以加载并使用模型了。我们可以用python封装程序成web服务api,以供调用。不过像图片分类这一类,频繁的拍照上传图片调用api也不太好。
这里,我们已经实现了在【序】里说的一个小目标:使用模型!!
在网上下载的第三方开源模型,只要知道它的用途及其输入参数(input_shape)数据格式,我们就可以用 tf.io、tf.image、tf.data.Dataset 等api接口处理数据成所需格式,然后就可以直接评测(使用)了
下一节,我们先整理一下图片分类的完整代码,然后下下节,我们再说一下怎样使用tfjs直接加载模型(不需要调python服务)
【6】TensorFlow光速入门-python模型转换为tfjs模型并使用
本文链接:https://www.cnblogs.com/tujia/p/13862360.html
完。