Tensorflow基础教程9:常用模块 tf.train.Checkpoint 之变量的保存与恢复

  目录

  tf.train.Checkpoint

  保存参数

  载入之前保存的参数

  保存变量+恢复变量

  `tf.train.Checkpoint` VS `tf.train.Saver`

  实例

  使用 `tf.train.CheckpointManager` 删除旧的 Checkpoint 以及自定义文件编号

  Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),请参考 “部署” 章节中的 SavedModel 。

  tf.train.Checkpoint

  很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用 Python 的序列化模块 pickle 存储 model.variables。但不幸的是,TensorFlow 的变量类型 ResourceVariable 并不能被序列化。

  好在 TensorFlow 提供了 tf.train.Checkpoint 这一强大的变量保存与恢复类,可以使用其 save() 和 restore() 方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizer 、 tf.Variable 、 tf.keras.Layer 或者 tf.keras.Model 实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:

  checkpoint = tf.train.Checkpoint(model=model)

  这里 tf.train.Checkpoint() 接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model 的模型实例 model 和一个继承 tf.train.Optimizer 的优化器 optimizer ,我们可以这样写:

  checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)

  这里 myAwesomeModel 是我们为待保存的模型 model 所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。

  保存参数

  接下来,当模型训练完成需要保存的时候,使用:

  checkpoint.save(save_path_with_prefix)

  就可以。 save_path_with_prefix 是保存文件的目录 + 前缀。

  例如,在源代码目录建立一个名为 save 的文件夹并调用一次 checkpoint.save('./save/model.ckpt') ,我们就可以在 save 目录下发现名为 checkpoint 、 model.ckpt-1.index 、 model.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save() 方法可以运行多次,每运行一次都会得到一个 .index 文件和 .data 文件,序号依次累加。

  载入之前保存的参数

  当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:

  model_to_be_restored = MyModel() # 待恢复参数的同一模型

  checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # 键名保持为“myAwesomeModel”

  checkpoint.restore(save_path_with_prefix_and_index)

  即可恢复模型变量。 save_path_with_prefix_and_index 是之前保存的文件的目录 + 前缀 + 编号。

  例如,调用 checkpoint.restore('./save/model.ckpt-1') 就可以载入前缀为 model.ckpt ,序号为 1 的文件来恢复模型。

  当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) 这个辅助函数返回目录下最近一次 checkpoint 的文件名。

  例如如果 save 目录下有 model.ckpt-1.index 到 model.ckpt-10.index 的 10 个保存文件, tf.train.latest_checkpoint('./save') 即返回 ./save/model.ckpt-10 。

  保存变量+恢复变量

  总体而言,恢复与保存变量的典型代码框架如下:

  # train.py 模型训练阶段

  model = MyModel()

  # 实例化Checkpoint,指定保存对象为model(如果需要保存Optimizer的参数也可加入)

  checkpoint = tf.train.Checkpoint(myModel=model)

  # ...(模型训练代码)

  # 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)

  checkpoint.save('./save/model.ckpt')

  # test.py 模型使用阶段

  model = MyModel()

  checkpoint = tf.train.Checkpoint(myModel=model) # 实例化Checkpoint,指定恢复对象为model

  checkpoint.restore(tf.train.latest_checkpoint('./save')) # 从文件恢复模型参数

  # 模型使用代码

  tf.train.Checkpoint VS tf.train.Saver

  tf.train.Checkpoint 与以前版本常用的 tf.train.Saver 相比,强大之处在于其支持在即时执行模式下 “延迟” 恢复变量。

  具体而言,当调用了 checkpoint.restore() ,但模型中的变量还没有被建立的时候,Checkpoint 可以等到变量被建立的时候再进行数值的恢复。即时执行模式下,模型中各个层的初始化和变量的建立是在模型第一次被调用的时候才进行的(好处在于可以根据输入的张量形状而自动确定变量形状,无需手动指定)。这意味着当模型刚刚被实例化的时候,其实里面还一个变量都没有,这时候使用以往的方式去恢复变量数值是一定会报错的。比如,你可以试试在 train.py 调用 tf.keras.Model 的 save_weight() 方法保存 model 的参数,并在 test.py 中实例化 model 后立即调用 load_weight() 方法,就会出错,只有当调用了一遍 model 之后再运行 load_weight() 方法才能得到正确的结果。可见, tf.train.Checkpoint 在这种情况下可以给我们带来相当大的便利。另外, tf.train.Checkpoint 同时也支持图执行模式。

  实例

  最后提供一个实例,以前章的 多层感知机模型 为例展示模型变量的保存和载入:

  import tensorflow as tf

  import numpy as np

  import argparse

  from zh.model.mnist.mlp import MLP

  from zh.model.utils import MNISTLoader

  parser = argparse.ArgumentParser(description='Process some integers.')

  parser.add_argument('--mode', default='train', help='train or test')

  parser.add_argument('--num_epochs', default=1)

  parser.add_argument('--batch_size', default=50)

  parser.add_argument('--learning_rate', default=0.001)

  args = parser.parse_args()

  data_loader = MNISTLoader()

  def train():

  model = MLP()

  optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)

  num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)

  checkpoint = tf.train.Checkpoint(myAwesomeModel=model) # 实例化Checkpoint,设置保存对象为model

  for batch_index in range(1, num_batches+1):

  X, y = data_loader.get_batch(args.batch_size)

  with tf.GradientTape() as tape:

  y_pred = model(X)

  loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)

  loss = tf.reduce_mean(loss)

  print("batch %d: loss %f" % (batch_index, loss.numpy()))

  grads = tape.gradient(loss, model.variables)

  optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

  if batch_index % 100 == 0: # 每隔100个Batch保存一次

  path = checkpoint.save('./save/model.ckpt') # 保存模型参数到文件

  print("model saved to %s" % path)

  def test():

  model_to_be_restored = MLP()

  # 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restored

  checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)

  checkpoint.restore(tf.train.latest_checkpoint('./save')) # 从文件恢复模型参数

  y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)

  print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))

  if __name__ == '__main__':

  if args.mode == 'train':

  train()

  if args.mode == 'test':

  test()

  在代码目录下建立 save 文件夹并运行代码进行训练后,save 文件夹内将会存放每隔 100 个 batch 保存一次的模型变量数据。在命令行参数中加入 --mode=test 并再次运行代码,将直接使用最后一次保存的变量值恢复模型并在测试集上测试模型性能,可以直接获得 95% 左右的准确率。

  使用 tf.train.CheckpointManager 删除旧的 Checkpoint 以及自定义文件编号

  在模型的训练过程中,我们往往每隔一定步数保存一个 Checkpoint 并进行编号。不过很多时候我们会有这样的需求:

  在长时间的训练后,程序会保存大量的 Checkpoint,但我们只想保留最后的几个 Checkpoint;

  Checkpoint 默认从 1 开始编号,每次累加 1,但我们可能希望使用别的编号方式(例如使用当前 Batch 的编号作为文件编号)。

  这时,我们可以使用 TensorFlow 的 tf.train.CheckpointManager 来实现以上需求。具体而言,在定义 Checkpoint 后接着定义一个 CheckpointManager:

  checkpoint = tf.train.Checkpoint(model=model)

  manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)

  此处, directory 参数为文件保存的路径, checkpoint_name 为文件名前缀(不提供则默认为 ckpt ), max_to_keep 为保留的 Checkpoint 数目。

  在需要保存模型的时候,我们直接使用 manager.save() 即可。如果我们希望自行指定保存的 Checkpoint 的编号,则可以在保存时加入 checkpoint_number 参数。例如 manager.save(checkpoint_number=100) 。

  以下提供一个实例,展示使用 CheckpointManager 限制仅保留最后三个 Checkpoint 文件,并使用 batch 的编号作为 Checkpoint 的文件编号。

  import tensorflow as tf大连做人流哪家好 http://mobile.dlrlyy.com/

  import numpy as np

  import argparse

  from zh.model.mnist.mlp import MLP

  from zh.model.utils import MNISTLoader

  parser = argparse.ArgumentParser(description='Process some integers.')

  parser.add_argument('--mode', default='train', help='train or test')

  parser.add_argument('--num_epochs', default=1)

  parser.add_argument('--batch_size', default=50)

  parser.add_argument('--learning_rate', default=0.001)

  args = parser.parse_args()

  data_loader = MNISTLoader()

  def train():

  model = MLP()

  optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)

  num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)

  checkpoint = tf.train.Checkpoint(myAwesomeModel=model)

  # 使用tf.train.CheckpointManager管理Checkpoint

  manager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)

  for batch_index in range(1, num_batches):

  X, y = data_loader.get_batch(args.batch_size)

  with tf.GradientTape() as tape:

  y_pred = model(X)

  loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)

  loss = tf.reduce_mean(loss)

  print("batch %d: loss %f" % (batch_index, loss.numpy()))

  grads = tape.gradient(loss, model.variables)

  optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

  if batch_index % 100 == 0:

  # 使用CheckpointManager保存模型参数到文件并自定义编号

  path = manager.save(checkpoint_number=batch_index)

  print("model saved to %s" % path)

  def test():

  model_to_be_restored = MLP()

  checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)

  checkpoint.restore(tf.train.latest_checkpoint('./save'))

  y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)

  print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))

  if __name__ == '__main__':

  if args.mode == 'train':

  train()

  if args.mode == 'test':

  test()

posted @ 2021-03-05 16:59  tiana_Z  阅读(783)  评论(0编辑  收藏  举报