基于Tensorflow的Faster-Rcnn的断点续训
一、前言
最近在学习目标检测,到github上找了一个开源的Faster-RCNN项目(Tensorflow),项目地址是:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3
根据网上的各种教程,模型训练还算顺利,不过这个项目缺少断点续训的功能。也就是中途误操作导致训练中止,就只能从头开始训练,模型的训练还是需要比较长的时间,没有断点续训不是很方便。因此在原项目的基础上新增了断点续训功能。
二、断点续训
找到项目根目录下的train.py文件,在 last_snapshot_iter = 0 这行代码后新增以下代码块:
ckpt = tf.train.get_checkpoint_state("./default/voc_2007_trainval/default") if ckpt and ckpt.model_checkpoint_path: self.saver.restore(sess,ckpt.model_checkpoint_path) #恢复当前会话sess,将ckpt中的值赋给w和b last_checkpoint = ckpt.model_checkpoint_path #最近模型路径 ins_start = last_checkpoint.index("iter_")+5 ins_end = last_checkpoint.index(".ckpt") last_iter = last_checkpoint[ins_start:ins_end] #最近模型的迭代次数 last_snapshot_iter = int(last_iter)
加完代码之后,训练中止时,执行python train.py,即可自动检测是否断点续训。如果想重新开始训练模型,将 default/voc_2007_trainval/default 目录下的内容删除即可。