tensorflow2.0 - 保存模型(含自定义模型的保存)
tensorflow2.0保存模型的方式有很多,这里只介绍两种。
一、 使用官方模型
这种情况可以直接保存整个模型,如下所示,可以将模型保存为HDF5文件
# 创建模型实例
model = create_model()
# 保存模型到HDF5文件
model.save('my_model.h5')
# 读取模型
model = keras.models.load_model('my_model.h5')
二、自定义模型
如果是自定义模型使用上述方法保存会报错且保存失败,报错为:
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn’t safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format=“tf”) or using
save_weights
.
这种情况可以保存weight。
# 创建模型
model = create_model()
# 保存权重
model.save_weights('model_weight')
# 创建新模型读取权重
newModel = create_model()
# 读取权重到新模型
newModel.load_weights('model_weight')
参考文献: