tensorflow: troubleshooting a global step problem

  While debugging some distributed TensorFlow training code, I ran into a strange error: Global step should be created to use StopAtStepHook.

 

  The error was raised by the following code:

  

import tensorflow as tf

# Stop once global_step reaches FLAGS.total_steps; save a checkpoint every 1000 steps.
stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.total_steps)
checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=train_dir, save_steps=1000, saver=saver)
hooks = [stop_hook, checkpoint_hook]
with tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief, checkpoint_dir=None, hooks=hooks) as sess:
    while not sess.should_stop():
        # Fetch into acc_value so the accuracy tensor is not overwritten by its own value.
        _, step, acc_value = sess.run([optimizer, global_step, accuracy], feed_dict=feed)

  The error message was: RuntimeError: Global step should be created to use StopAtStepHook.

  At first glance this looks like global_step was never defined. But the code did define it, so I went to read the source:

  

class StopAtStepHook(session_run_hook.SessionRunHook):
    def begin(self):
        self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use StopAtStepHook.")

  The error comes from the begin method of the StopAtStepHook class. This method is called before the session is used, and its job is to obtain global_step.

 

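  Next I looked at training_util._get_or_create_global_step_read(). I won't go through the code in detail; the conclusion is that it finds global_step by looking for a variable named "global_step", and the global_step I had defined myself had never been given a name...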
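  The lookup can be reproduced in a few lines. This is only a minimal sketch, not my original training script: it assumes a TensorFlow build that exposes the tf.compat.v1 API (the tf1 alias is my own shorthand) and uses the public tf.train.get_global_step, which, as far as I can tell from the source, consults the GLOBAL_STEP graph collection first and then falls back to the tensor name "global_step:0".

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: this TensorFlow build ships the compat.v1 API

graph = tf.Graph()
with graph.as_default():
    # A hand-made counter like the one in my script: no name given, and not
    # registered in the GLOBAL_STEP collection.
    my_step = tf1.Variable(0, trainable=False, dtype=tf.int64)
    found_before = tf1.train.get_global_step(graph)

    # The default helper creates a variable named "global_step" and registers
    # it in the GLOBAL_STEP collection, so the lookup now succeeds.
    default_step = tf1.train.get_or_create_global_step(graph)
    found_after = tf1.train.get_global_step(graph)

print("hand-made variable name:", my_step.name)
print("lookup before the helper:", found_before)
print("lookup after the helper:", found_after.name)
```

  The hand-made variable gets an auto-generated name, so the lookup returns None for it, which is exactly the condition that makes StopAtStepHook.begin raise.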

 

  Lesson learned: for a variable like global_step, it is best to create it the default way: tf.train.get_or_create_global_step()
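  Here is a rough, single-process sketch of that advice, under a few assumptions: it goes through tf.compat.v1 (aliased as tf1), replaces my distributed setup with a local session, and uses a toy variable w and loss that are purely illustrative. The only point is that StopAtStepHook works once global_step comes from the default helper.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: this TensorFlow build ships the compat.v1 API

graph = tf.Graph()
with graph.as_default():
    # Create the global step the default way so the hooks can find it.
    global_step = tf1.train.get_or_create_global_step()

    # Toy model (hypothetical): pull a single weight towards 1.0.
    w = tf1.get_variable("w", shape=[], initializer=tf1.zeros_initializer())
    loss = tf.square(w - 1.0)
    # Passing global_step makes every training step increment it by one.
    train_op = tf1.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=global_step)

    stop_hook = tf1.train.StopAtStepHook(last_step=5)
    iterations = 0
    with tf1.train.MonitoredTrainingSession(hooks=[stop_hook]) as sess:
        while not sess.should_stop():
            _, step = sess.run([train_op, global_step])
            iterations += 1

print("training iterations run:", iterations)
```

  The value fetched into step may lag the counter by one, because the read and the increment happen in the same run call; the hook does its own read of the global step before deciding to stop.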

  

posted @ 2018-11-28 16:57  爱斯特拉冈