Adding early_stop to BERT (Chinese)

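Step 1: Create a hook that stops training once `eval_loss` has not decreased for `FLAGS.max_steps_without_decrease` global steps. `max_steps_without_decrease` is not one of the flags that BERT's `run_classifier.py` defines, so add it next to the others with `flags.DEFINE_integer`.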

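```python
import tensorflow as tf  # TF 1.x; tf.contrib no longer exists in TF 2.x

early_stopping_hook = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator=estimator,
    metric_name='eval_loss',  # key returned by BERT's metric_fn
    max_steps_without_decrease=FLAGS.max_steps_without_decrease,
    eval_dir=None,            # None -> use estimator.eval_dir()
    min_steps=0,              # stopping may be triggered from step 0 on
    run_every_secs=None,      # check by step count, not by wall time
    run_every_steps=FLAGS.save_checkpoints_steps)
```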
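
The decision the hook makes at each check can be sketched in plain Python. This is a simplified re-implementation for illustration only, not TensorFlow's source: it assumes the evaluation history is already available as a dict mapping global step to `eval_loss`, and `should_stop_early` is a hypothetical helper name.

```python
from typing import Dict


def should_stop_early(eval_history: Dict[int, float],
                      max_steps_without_decrease: int,
                      min_steps: int = 0) -> bool:
    """Hypothetical helper: True if eval_loss has not decreased for at
    least max_steps_without_decrease global steps.

    eval_history maps global step -> eval_loss (as read from the eval dir).
    """
    best_loss = None
    best_step = None
    # Walk the evaluations in order of global step.
    for step in sorted(eval_history):
        if step < min_steps:
            continue  # evaluations before min_steps are ignored
        loss = eval_history[step]
        if best_loss is None or loss < best_loss:
            best_loss, best_step = loss, step  # new best evaluation
        if step - best_step >= max_steps_without_decrease:
            return True  # no decrease for long enough
    return False


history = {1000: 0.50, 2000: 0.42, 3000: 0.45, 4000: 0.44}
print(should_stop_early(history, max_steps_without_decrease=2000))  # True
print(should_stop_early(history, max_steps_without_decrease=3000))  # False
```

With `max_steps_without_decrease=2000`, the best loss (0.42 at step 2000) is 2000 steps old by step 4000, so training would stop; with 3000 it would continue.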

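Step 2: Pass the hook to `estimator.train`. Note that the hook does not run evaluation itself: it reads `eval_loss` from the evaluation summaries in the estimator's eval directory. If evaluation only happens after `estimator.train` returns (as in BERT's default train-then-eval flow), the hook finds no `eval_loss` during training and never stops it; evaluation has to run alongside training, for example with `tf.estimator.train_and_evaluate`.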

```python
estimator.train(input_fn=train_input_fn,
                max_steps=num_train_steps,
                hooks=[early_stopping_hook])
```
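
How the checkpoint interval, evaluation and the stop check interact during training can be simulated without TensorFlow. This is a toy model, not the Estimator internals: it assumes a checkpoint is saved and evaluated every `save_checkpoints_steps` steps and that the check runs at that same interval (mirroring `run_every_steps=FLAGS.save_checkpoints_steps` above). `train_with_early_stopping` and `eval_loss_at` are hypothetical names; `eval_loss_at` stands in for evaluating a checkpoint.

```python
from typing import Callable


def train_with_early_stopping(num_train_steps: int,
                              save_checkpoints_steps: int,
                              max_steps_without_decrease: int,
                              eval_loss_at: Callable[[int], float]) -> int:
    """Toy training loop; returns the global step at which training ended."""
    best_loss = float("inf")
    best_step = 0
    step = 0
    while step < num_train_steps:
        # "Train" up to the next checkpoint (or the end of training).
        step = min(step + save_checkpoints_steps, num_train_steps)
        # Evaluate the new checkpoint (hypothetical stand-in).
        loss = eval_loss_at(step)
        if loss < best_loss:
            best_loss, best_step = loss, step
        # Early-stopping check at the same interval as checkpoints.
        if step - best_step >= max_steps_without_decrease:
            return step  # the hook would request a stop here
    return step


losses = {1000: 0.50, 2000: 0.42, 3000: 0.45, 4000: 0.44, 5000: 0.43}
stopped_at = train_with_early_stopping(
    num_train_steps=10000,
    save_checkpoints_steps=1000,
    max_steps_without_decrease=2000,
    eval_loss_at=lambda s: losses.get(s, 0.43))
print(stopped_at)  # 4000
```

Training that was scheduled for 10000 steps ends at step 4000, because the loss never improved on its step-2000 value over the following 2000 steps.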

posted @ 2019-01-15 16:59  cup_leo