训练高级会话函数
训练以及高级会话函数
主训练逻辑
我们将在cifar_train.py文件实现主要训练逻辑。在这里我们将使用一个新的会话函数,叫tf.train.MonitoredTrainingSession
优点: 1、它自动的建立events文件、checkpoint文件,以记录重要的信息。 2、可以定义钩子函数,可以自定义每批次的训练信息,训练的限制等等
注意:在这个里面我们需要添加一个全局步数,这个步数是每批次训练的时候进行+1计数,内部使用。
代码如下:
import tensorflow as tf import cifar_model import time from datetime import datetime def train(): # 在图中进行训练 with tf.Graph().as_default(): # 定义全局步数,必须得使用这个,否则会出现StopCounterHook错误 global_step = tf.contrib.framework.get_or_create_global_step() # 获取数据 image, label, label_1 = cifar_model.input() # 通过模型进行类别预测 y_logit = cifar_model.inference(image) # 计算损失 loss = cifar_model.total_loss(label, y_logit) # 进行优化器减少损失 train_op, accuracy = cifar_model.train(loss, label, y_logit, global_step) # 通过钩子定义模型输出 class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss, float(accuracy.eval())) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % 10 == 0: current_time = time.time() duration = current_time - self._start_time self._start_time = current_time loss_value = run_values.results examples_per_sec = 10 * 10 / duration sec_per_batch = float(duration / 10) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print(format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) with tf.train.MonitoredTrainingSession( checkpoint_dir="./cifartrain/train", hooks=[tf.train.StopAtStepHook(last_step=500),# 定义执行的训练轮数也就是max_step,超过了就会报错 tf.train.NanTensorHook(loss), _LoggerHook()], config=tf.ConfigProto( log_device_placement=False)) as mon_sess: while not mon_sess.should_stop(): mon_sess.run(train_op) def main(argv): train() if __name__ == "__main__": tf.app.run()