关于tensorflow2.x保存模型及加载模型的方法及对比
以下方法都是个人实际中测试和使用的方法,tf2版本在2.3~2.7之间
1、model.save() and model.load()
保存模型:这个方法可以直接将训练后的权重和训练的参数保存下来,一般我个人使用的.h5为后缀把模型整个保存下来。
步骤如下:
(1)创建模型,像添加积木一样对模型添加需要的卷积,池化等操作
(2)配置神经网络的优化器,计算梯度的方法
(3)保存模型
加载模型:这样保存下来的模型可以直接将h5文件保存下来,不需要先加载模型
2、model.save_weight() and model.load_weight()
(1)这里采用继承Model这个类去实现神经网络(比第一种方法更加常用且受规范)
下面的方法就是当我们保存模型的权重参数,但是没有保存模型的结构
加载模型
需要先把模型的结构导入过来,再load模型的参数进去才能进行推理
3、model.checkpoint
这个用的比较少,看这样加载模型的方式,可能跟第二种类似
tf1.x版本中的Checkpoint用法如下
保存模型的参数
注意:
由于本次测试使用的是对抗生成网络,所有两个网络,即一个判别器和一个生成器,同时对应2种优化器,Checkpoint中设置你的生成器和优化器的名字,然后在加载的时候以同样的名字为参数传入模型的结构
载入模型
注意:这里的模型都指定了输入的图片的尺寸,如果想输入的图片尺寸不受限制,那么不要使用flatten拉直神经网络,可以使用全卷积来再进行softmax即可
tf.keras.layers.GlobalAveragePooling2D()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律