TebsorFlow低阶API(五)—— 保存和恢复

简介

tf.train.Saver 类提供了保存和恢复模型的方法。通过 tf.saved_model.simple_save 函数可以轻松地保存适合投入使用的模型。Estimator会自动保存和恢复 model_dir 中的变量。

 

保存和恢复变量

TensorFlow变量是表示由程序操作的共享持久状态的最佳方法。tf.train.Saver 构造函数会针对图中的所有变量或指定列表的变量将 save 和 restore 操作添加到图中。Saver对象提供了运行这些操作的方法,并指定写入或读取检查点文件的路径。

Saver 会恢复已经在模型中定义的所有变量。如果您在不知道如何构件图的情况下加载模型(例如,您要编写用于加载各类模型的通用程序),那么请阅读本文档后面的保存和恢复模型概述部分。

TensorFlow将变量保存在二进制检查点文件中,这类文件会将变量名称映射到张量值。

注意:TensorFlow 模型文件是代码。请注意不可信的代码。详情请参阅安全地使用 TensorFlow

保存变量

创建Saver(使用 tf.train.Saver())来管理模型中的所有变量。例如,以下代码展示了如何调用 tf.train.Saver.save 方法以将变量保存到检查点文件中:

 1 # Create some variables.
 2 v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
 3 v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
 4 
 5 inc_v1 = v1.assign(v1+1)
 6 dec_v2 = v2.assign(v2-1)
 7 
 8 # Add an op to initialize the variables.
 9 init_op = tf.global_variables_initializer()
10 
11 # Add ops to save and restore all the variables.
12 saver = tf.train.Saver()
13 
14 # Later, launch the model, initialize the variables, do some work, and save the variables to disk.
15 with tf.Session() as sess:
16   sess.run(init_op)
17   # Do some work with the model.
18   inc_v1.op.run()
19   dec_v2.op.run()
20   # Save the variables to disk.
21   save_path = saver.save(sess, "/tmp/model.ckpt")
22   print("Model saved in path: %s" % save_path)

恢复变量

tf.train.Saver 对象不仅将变量保存在检查点文件中,还将恢复变量。请注意,当您恢复变量时,您不必事先将其初始化。例如,以下代码段展示了如何调用 tf.train.Saver.restore 方法以从检查点文件中恢复变量:

 1 tf.reset_default_graph()
 2 
 3 # Create some variables.
 4 v1 = tf.get_variable("v1", shape=[3])
 5 v2 = tf.get_variable("v2", shape=[5])
 6 
 7 # Add ops to save and restore all the variables.
 8 saver = tf.train.Saver()
 9 
10 # Later, launch the model, use the saver to restore variables from disk, and
11 # do some work with the model.
12 with tf.Session() as sess:
13   # Restore variables from disk.
14   saver.restore(sess, "/tmp/model.ckpt")
15   print("Model restored.")
16   # Check the values of the variables
17   print("v1 : %s" % v1.eval())
18   print("v2 : %s" % v2.eval())

注意:并没有名为 /tmp/model.ckpt 的实体文件。它是为检查点创建的文件名的前缀。用户仅与前缀(而非检查点实体文件)互动。

选择要保存和恢复的变量

如果您没有向 tf.train.Saver()传递任何参数,则Saver会处理图中的所有变量。每个变量都保存在创建变量时所传递的名称下。

在检查点文件中明确指定变量名称的这种做法有时非常有用。例如,您可能已经使用名为“weights” 的变量训练了一个模型,而您想要将该变量的值恢复到名为“params”的变量中。

有时候,仅保存和恢复模型使用的变量子集也会很有裨益。例如,您可能已经训练了一个五层的神经网络,现在您训练一个六层的新模型,并重用该五层的现有权重。您可以使用Saver只恢复这前五层的权重。

您可以通过向 tf.train.Saver() 构造函数传递以下任一内容,轻松指定要保存或加载的名称和变量:

  • 变量列表(将以其本身的名称保存)。
  • Python字典,其中,键是要使用的名称,键值是要管理的变量。

继续前面所示的保存/恢复示例:

 1 tf.reset_default_graph()
 2 # Create some variables.
 3 v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
 4 v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
 5 
 6 # Add ops to save and restore only `v2` using the name "v2"
 7 saver = tf.train.Saver({"v2": v2})
 8 
 9 # Use the saver object normally after that.
10 with tf.Session() as sess:
11   # Initialize v1 since the saver will not.
12   v1.initializer.run()
13   saver.restore(sess, "/tmp/model.ckpt")
14 
15   print("v1 : %s" % v1.eval())
16   print("v2 : %s" % v2.eval())

注意:

  • 如果要保存和恢复模型变量的不同子集,您可以根据需要创建任意数量的 Saver对象。同一个变量可以列在多个Saver对象中,变量的值只有在Saver.restore()方法运行时才会更改。
  • 如果您在会话开始时仅恢复一部分模型变量,则必须为其它变量运行初始化操作。
  • 要检查某个检查点中的变量,您可以使用 inspect_checkpoint 库,尤其是 print_tensors_in_checkpoint_file 函数。默认情况下,Saver会针对每个变量使用 tf.Variable.name 属性的值。但是,当您创建Saver对象时,您可以选择为检查点中的变量选择名称。

检查某个检查点中的变量

我们可以使用 inspect_checkpoint 库快速检查某个检查点中的变量。

继续前面所示的保存/恢复示例:

 1 # import the inspect_checkpoint library
 2 from tensorflow.python.tools import inspect_checkpoint as chkp
 3 
 4 # print all tensors in checkpoint file
 5 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)
 6 
 7 # tensor_name:  v1
 8 # [ 1.  1.  1.]
 9 # tensor_name:  v2
10 # [-1. -1. -1. -1. -1.]
11 
12 # print only tensor v1 in checkpoint file
13 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False)
14 
15 # tensor_name:  v1
16 # [ 1.  1.  1.]
17 
18 # print only tensor v2 in checkpoint file
19 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensors=False)
20 
21 # tensor_name:  v2
22 # [-1. -1. -1. -1. -1.]

 

保存和恢复模型

使用 SavedModel 保存和加载模型-变量、图和图的元数据。SavedModel是一种独立于语言且可恢复的神秘序列化格式,使较高级别的系统和工具可以创建、使用和转换TensorFlow模型。TensorFlow提供了多种与 SaverModel交互的方式,包括 tf.saved_model  API、tf.estimator.Estimator和命令行界面。

 

构建和加载SavedModel

简单保存

创建SavedModel 的最简单的方法使使用 tf.saved_model.simple_save 函数:

1 simple_save(session,
2             export_dir,
3             inputs={"x": x, "y": y},
4             outputs={"z": z})

这样可以配置 SavedModel,使其能够通过 TensorFlow  Serving进行加载,并支持Predict  API。要访问classify API、regress API或者multi-inference API,请使用手动SavedModel builder API或 tf.estimator.Estimator。

手动构建按SavedModel

如果您的用例不在 tf.saved_model.simple_save涵盖范围内,请手动 builder API 创建SaverModel。

tf.saved_model.builder.SavedModelBuilder 类提供了保存多个 MetaGraphDef 的功能。MetaGraph是一种数据流图,并包含相关变量、资源和签名。MetaGraphDef是MetaGraph的协议缓冲区表示法。签名是一组与图有关的输入和输出。

如果需要将资源保存并写入或复制到磁盘,则可以在首次添加 MetaGraphDef时提供这些资源。如果多个 MetaGraphDef 与同名资源相关联,则只保留首个版本。

必须使用用户指定的标签对每个添加到 SavedModel 的 MetaGraphDef进行标注。这些标签提供了一种方法来识别要加载和恢复的特定MetaGraphDef,以及共享的变量和资源子集。这些标签一般会标注MetaGraphDef的功能(例如服务或训练),有时也会标注特定的硬件方面的信息(如GPU)。

例如,以下代码展示了使用MeatGraphDef构建SavedModel的典型方法:

 1 export_dir = ...
 2 ...
 3 builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
 4 with tf.Session(graph=tf.Graph()) as sess:
 5   ...
 6   builder.add_meta_graph_and_variables(sess,
 7                                        [tag_constants.TRAINING],
 8                                        signature_def_map=foo_signatures,
 9                                        assets_collection=foo_assets,
10                                        strip_default_attrs=True)
11 ...
12 # Add a second MetaGraphDef for inference.
13 with tf.Session(graph=tf.Graph()) as sess:
14   ...
15   builder.add_meta_graph([tag_constants.SERVING], strip_default_attrs=True)
16 ...
17 builder.save()

通过 strip_default_attrs=True确保前向兼容性

只有在操作集合没有变化的情况下,遵循以下指南才能带来向前兼容性。

SavedModelBuilder类允许用户控制在将元图添加到SaverModel软件时,是否必须从NodeDefs剥离默认属性。SavedModelBuilder.add_meta_graph_add_variable和SavedModelBuilder.add_meta_graph 方法都接受控制此行为的布尔标记strip_default_attrs。

如果strip_default_attrs=False,则导出的tf.MetaGraphDef 将在其所有的 tf.NodeDef实例中具有设为默认值的属性。这样会破坏前向兼容性并出现一系列事件,详情请参阅兼容性指南

加载Python版SavedModel

Python版的SavedModel加载器为SavedModel提供了加载和恢复功能。load指令需要以下信息:

  • 要在其中恢复图定义和变量的会话。
  • 用于标识要加载的MetaGraphDef的标签。
  • SavedModel的位置(目录)

加载后,作为特定MetaModelDef的一部分提供的变量、资源和签名子集将恢复到提供的会话中。

1 export_dir = ...
2 ...
3 with tf.Session(graph=tf.Graph()) as sess:
4   tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)
5   ...

加载C++版SavedModel

C++版SavedModel加载器提供了一个可从某个路径加载SavedModel的API(同时允许SessionOptions和RunOptions)。您必须指定要加载的与图相关联的标签。SavedModel加载后的版本称为SavedModelBundle,其中包含MetaGraphDef和加载时所在的会话。

const string export_dir = ...
SavedModelBundle bundle;
...
LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},&bundle);

 

在TensorFlow Serving中加载和提供SavedMoel

您可以使用TensorFlow Serving Model Sever二进制文件轻松加载和提供SavedModel。请参阅此处说明,了解如何安装服务器,或根据需要创建服务器。

一旦您的Model  Sever就绪,请运行以下内容:

1 tensorflow_model_server --port=port-numbers --model_name=your-model-name --model_base_path=your_model_base_path

将port和model_name标记设为您所选的值。model_base_path标记应为基本目录,每个版本的模型都放置于以数字命名的子目录中给。如果您的模型只有一个版本,只需如下所示的将其放在子目录中即可:*将模型放入/tmp/model/0001*将model_base_path设为/tmp/model

 

将模型的不同版本存储在共用基本目录的子目录中(以数字命名)。例如,假设基本目录是/tmp/model。如果您的模型只有一个版本,请将其存储在/tmp/model/0001中。如果您的模型有两个版本,请将第二个版本存储在/tmp/model/0002中,以此类推。将 --model-bash_path标记设为基本目录(在本例中为/tmp/model)。TensorFlow Model Sever将在该基本目录的最大编号的子目录中提供模型。

 标准常量

SavedModel为各种用例搭建和加载TensorFlow图提供了灵活性。对于常见的用例,SavedModel的API在Python和C++中提供了一组易于重复使用且在各种工具中共享的常量。

标准MetaGraphDef标签

您可以使用标签组唯一标识保存在SavedModel中的MetaGraphDef。常用标签的子集如下:

标准SignatureDef常量

SignatureDef是一个协议缓冲区,用于定义图中所支持的计算的签名。常用的输入键、输出键和方法名称定义如下:

 

搭配Estimator使用SavedModel

使用CLI检查并执行SavedModel

SavedModel目录的结构

当您以SavedModel格式保存模型时,TensorFlow会自动创建一个由以下子目录和文件组成的SavedModel目录:

1 assets/
2 assets.extra/
3 variables/
4     variables.data-?????-of-?????
5     variables.index
6 saved_model.pb|saved_model.pbtxt

其中:

  • assets 是包含辅助(外部)文件(如词汇表)的子文件夹。资源被复制到SavedModel的位置,并且可以在加载特定的MetaGraphDef时被读取。
  • assets.extra 是一个子文件夹,其中较高级别的库和用户可以添加自己的资源,这些资源与模型共存,但不会被图加载。此子文件夹不由SavedModel库管理。
  • variables 是包含 tf.train.Saver的输出的子文件夹。
  • saved_model.pbsaved_model.pbtxt 是SavedModel协议缓冲区。它作为MetaGraphDef协议缓冲区的图定义。

单个SavedModel可以表示多个图。在这种情况下,SavedModel中所有图共享一组检查点(变量)和资源。例如,下图显示了一个包含三个MetaGraphDef的SavedModel,它们都共享共享同一组检查点和资源:

每组图都与一组特定的标记相关联,可在加载或恢复期间方便您识别。

 

 

参考链接:https://tensorflow.google.cn/guide/saved_model#save_and_restore_variables

 

posted @ 2019-02-11 22:03  Rogn  阅读(1069)  评论(0编辑  收藏  举报