关于tensorflow2.x保存模型及加载模型的方法及对比

以下方法都是个人实际中测试和使用的方法,tf2版本在2.3~2.7之间

1、model.save() and model.load()

保存模型:这个方法可以直接将训练后的权重和训练的参数保存下来,一般我个人使用的.h5为后缀把模型整个保存下来。

步骤如下:

(1)创建模型,像添加积木一样对模型添加需要的卷积,池化等操作

 (2)配置神经网络的优化器,计算梯度的方法  

 (3)保存模型

 

加载模型:这样保存下来的模型可以直接将h5文件保存下来,不需要先加载模型

 

 

2、model.save_weight() and model.load_weight()

 (1)这里采用继承Model这个类去实现神经网络(比第一种方法更加常用且受规范)

 下面的方法就是当我们保存模型的权重参数,但是没有保存模型的结构

 加载模型

需要先把模型的结构导入过来,再load模型的参数进去才能进行推理

 

3、model.checkpoint

 这个用的比较少,看这样加载模型的方式,可能跟第二种类似

 

tf1.x版本中的Checkpoint用法如下

 

保存模型的参数

注意:

由于本次测试使用的是对抗生成网络,所有两个网络,即一个判别器和一个生成器,同时对应2种优化器,Checkpoint中设置你的生成器和优化器的名字,然后在加载的时候以同样的名字为参数传入模型的结构

 载入模型

 

注意:这里的模型都指定了输入的图片的尺寸,如果想输入的图片尺寸不受限制,那么不要使用flatten拉直神经网络,可以使用全卷积来再进行softmax即可

tf.keras.layers.GlobalAveragePooling2D()

posted @ 2023-07-10 21:32  waterdoor  阅读(285)  评论(4编辑  收藏  举报