如何设计一个合理、灵活的深度学习训练框架
一、如何设计一个Trainer
1.1 深度学习的整体训练流程
基本上而言,你大概会写出如下的代码:
作者:emiya 链接:https://zhuanlan.zhihu.com/p/97326458 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 class Trainer(): def __init__(): # 定义基本参数 self.max_iter = 10000 self.save_iter = 2000 self.log_iter = 1000 # 首先是定义深度学习训练四件套 self.model = create_model() self.optimizer = create_optimizer() self.data_loader = create_dataloader() self.learning_rate_adjuster = create_lr_adjuster() # 为了保存模型,你可能还会定义 saver, 用于模型的存储 self.saver = create_saver() # 为了记录训练模型过程中的相关信息,可能你还需要定义一个 tensorboard writter self.writer = create_tensorboard_writer() def train(self): iteration = 0 for self.iter in range(0, self.max_iter): # 首先来完成训练三部曲 # step1 数据加载 data = next(self.data_loader) # step2 loss 计算 loss , acc , other_info = self.model(data) # step3 反向传播 self.optimizer.zero_grad() loss.backward() self.optimizer.step() # 然后来完成一些训练后的清理工作 iteration += 1 if iteration % self.save_iter == 0: self.saver.save(model) if iteration % self.log_iter == 0: self.writer.log('loss',loss) self.writer.log('acc',acc) self.writer.log('other_info',other_info)
以上就是一个最基本的深度学习训练流程了,你可以在各种各样的项目中看到上面这样一套代码,他们可能不叫这个名字,但是一定是做类似的事情。而今天要介绍的,就是如何如何一步一步把它拆成更细的粒度,并且进行合理的封装
1.2 流程拆分Step1:抽象训练流程
要拆分上面的代码,咱们当然先拆大头 Trainer.train 方法了。在很多项目当中,train 方法都会写的非常的冗长,因为会把所有和训练,和 记录相关的代码都放在train当中。那么在这里,我们可以把Trainer 进行如下的抽象:
作者:emiya 链接:https://zhuanlan.zhihu.com/p/97326458 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 class TrainerBase(): def __init__(self): pass def train(self): iteration = 0 self.before_train() for self.iter in range(0, max_iter): # 首先来完成训练三部曲 # step1 数据加载 self.before_step() self.run_step() self.after_step() self.after_train() class Trainer(TrainerBase): def __init__(self): # 定义基本参数 self.Epoch = 100 self.save_iter = 2000 self.log_iter = 1000 # 首先是定义深度学习训练四件套 self.model = create_model() self.optimizer = create_optimizer() self.data_loader = create_dataloader() self.learning_rate_adjuster = create_lr_adjuster() # 为了保存模型,你可能还会定义 saver, 用于模型的存储 self.saver = create_saver() # 为了记录训练模型过程中的相关信息,可能你还需要定义一个 tensorboard writter self.writer = create_tensorboard_writer() def run_step(self): data = next(self.data_loader) # step2 loss 计算 loss , acc , other_info = self.model(data) # step3 反向传播 self.optimizer.zero_grad() loss.backward() self.optimizer.step() def after_step(self): # 然后来完成一些训练后的清理工作 iteration += 1 if iteration % self.save_iter == 0: self.saver.save(model) if iteration % self.log_iter == 0: self.writer.log('loss',loss) self.writer.log('acc',acc) self.writer.log('other_info',other_info)
在上述的代码中,我们主要做了两点变动:
1、将 TrainerBase 抽象出来 ,并在TrainerBase中定义了一个通用的训练流程 2、在TrainerBase 的子类Trainer中,定义了具体的 run_step 方法 和 after_step 方法
那么,在这样的定义下,当用户希望定义自己的Trainer的时候,他就只需要继承TrainerBase ,并且实现自己的 beforre_train 、after_train 、 beforestep 、after_step、run_step 方法即可。
1.3 流程拆分Step2:使用Hook执行任务,进一步抽象代码
在上述拆分的基础之上,让我们进一步的来关注一下 after_train 中的这段代码:
class Trainer(TrainerBase): def __init__(self): .... def after_step(self): # 然后来完成一些训练后的清理工作 iteration += 1 if iteration % self.save_iter == 0: self.saver.save(model) if iteration % self.log_iter == 0: self.writer.log('loss',loss) self.writer.log('acc',acc) self.writer.log('other_info',other_info)
在after_step 中,你的很多的”组件“都会在这里完成他们的任务。他们会每隔一定的步长执行一下自己的任务。
为了进一步的对代码进行一些抽象,把每一个”任务“都定义为一个个独立的”Hook“,每一个hook 都会实现自己的 beforre_train 、after_train 、 before_step 、after_step、run_step 方法:
class HookBase: def before_train(self): """ Called before the first iteration. """ pass def after_train(self): """ Called after the last iteration. """ pass def before_step(self): """ Called before each iteration. """ pass def after_step(self): """ Called after each iteration. """ pass
随后,基于Hook,就可以对训练代码进行进一步的抽象:
作者:emiya 链接:https://zhuanlan.zhihu.com/p/97326458 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 class Saver(HookBase): def __init__(self, save_iter): self.save_iter = save_iter def after_step(self): if self.trainer.iter % self.save_iter == 0: save_model(self.trainer.model) class Writer(HookBase): def __init__(self,write_iter): self._debug_info = {} self.write_iter = write_iter self.writer = TensorboardWriter(...) def before_step(self): self._debug_info = {} def after_step(self): loss = self._debug_info['loss'] self.writer.write(loss) class Trainer(TrainerBase): def __init__(self): self.hooks : List[HookBase] = self.register_hooks() def register_hooks(self): self.hooks = [] self.hooks.append(Saver(save_iter)) self.hooks.append(Writer(write_iter)) for h in hooks: assert isinstance(h, HookBase) h.trainer = weakref.proxy(self) def before_step(self): for hook in self.hooks: hook.before_step() def run_step(self): self.iter += 1 data = next(self.data_loader) # step2 loss 计算 loss , acc , other_info = self.model(data) # step3 反向传播 self.optimizer.zero_grad() loss.backward() self.optimizer.step() def after_step(self): for hook in self.hooks: hook.after_step()
现在再让我们梳理一下代码中关于hooks的逻辑:
- 1、Trainer在初始化时注册一系列的hooks , 每个hook 可以完成一个工作
- 2、注册hooks 的时候,通过 h.trainer = weakref.proxy(self) 把自身变为 hooks的属性,使 得hook中可以通过 h.trainer.iter 获取trainer内部记录的一些训练状态相关的信息
- 3、每个hook都会有自己的一些列参数,这样,如 save_iter , write_iter 这样的信息就不是直接注册在trainer 中,而是记录在每个hook类自己的内容,保证了Trainer代码具有高度的扩展性
1.4 流程拆分Step3:细化每个类的功能
在上一步中,我们已经将所有非训练相关的任务通过Hook的方式从Trainer 的代码中分离出去了,那么此时的Trainer类大概是如下形式:
作者:emiya 链接:https://zhuanlan.zhihu.com/p/97326458 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 class Trainer(TrainerBase): def __init__(self): self.hooks : List[HookBase] = self.register_hooks() ... def register_hooks(self): self.hooks = [] self.hooks.append(Saver(save_iter)) self.hooks.append(Writer(write_iter)) for h in hooks: assert isinstance(h, HookBase) h.trainer = weakref.proxy(self) def before_step(self): for hook in self.hooks: hook.before_step() def run_step(self): self.iter += 1 data = next(self.data_loader) # step2 loss 计算 loss , acc , other_info = self.model(data) # step3 反向传播 self.optimizer.zero_grad() loss.backward() self.optimizer.step() def after_step(self): for hook in self.hooks: hook.after_step()
但是在上述写法中,注意到 run_step 方法是会随着不同的数据加载、不同的模型定义、损失定义而改变的,因此,有必要再进行一次抽象,让类的功能更加的泛化
二、Trainer设计方法
2.1 TrainerBase
作者:emiya 链接:https://zhuanlan.zhihu.com/p/97326458 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 class HookBase: def before_train(self): pass def after_train(self): pass def before_step(self): pass def after_step(self): pass class TrainerBase: def __init__(self): self._hooks = [] def register_hooks(self, hooks): hooks = [h for h in hooks if h is not None] for h in hooks: assert isinstance(h, HookBase) h.trainer = weakref.proxy(self) self._hooks.extend(hooks) def train(self, start_iter: int, max_iter: int): self.iter = self.start_iter = start_iter self.max_iter = max_iter with EventStorage(start_iter) as self.storage: try: self.before_train() for self.iter in range(start_iter, max_iter): self.before_step() self.run_step() self.after_step() finally: self.after_train() def before_train(self): for h in self._hooks: h.before_train() def after_train(self): for h in self._hooks: h.after_train() def before_step(self): for h in self._hooks: h.before_step() def after_step(self): for h in self._hooks: h.after_step() # this guarantees, that in each hook's after_step, storage.iter == trainer.iter self.storage.step() def run_step(self): raise NotImplementedError
2.2 SimpleTrainer
作者:emiya 链接:https://zhuanlan.zhihu.com/p/97326458 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 class SimpleTrainer(TrainerBase): def __init__(self, model, data_loader, optimizer): super().__init__() # 注意到为了灵活性,这里仍然没有定义 data_loader , model 和 optimizer # 仍然是采用了 加载的方式,而真正定义这些的类,会在下一节中介绍 model.train() self.model = model self.data_loader = data_loader self._data_loader_iter = iter(data_loader) self.optimizer = optimizer def run_step(self): # 通过next 方法获取数据 data = next(self._data_loader_iter) data_time = time.perf_counter() - start # 执行前向代码 loss_dict = self.model(data) losses = sum(loss for loss in loss_dict.values()) self._detect_anomaly(losses, loss_dict) metrics_dict = loss_dict metrics_dict["data_time"] = data_time self._write_metrics(metrics_dict) # 进行反向传播 self.optimizer.zero_grad() losses.backward() self.optimizer.step() def _detect_anomaly(self, losses, loss_dict): if not torch.isfinite(losses).all(): raise FloatingPointError( "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format( self.iter, loss_dict ) )
其次,在对应到具体项目,具体任务的时候,再定义一个继承SimpleTrainer的类 DefaultTrainer,来实现模型的创建、数据的加载等基本方法
2.3 DefaultTrainer
具体到目标检测任务时,真正的Trainer其实是定义在 同一目录下的DefaultTrainer作者:emiya 链接:https://zhuanlan.zhihu.com/p/97326458 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 class DefaultTrainer(SimpleTrainer): def __init__(self, cfg): model = self.build_model(cfg) optimizer = self.build_optimizer(cfg, model) data_loader = self.build_train_loader(cfg) # For training, wrap with DDP. But don't need this for inference. if comm.get_world_size() > 1: model = DistributedDataParallel( model, device_ids=[comm.get_local_rank()], broadcast_buffers=False ) super().__init__(model, data_loader, optimizer) self.start_iter = 0 self.max_iter = cfg.SOLVER.MAX_ITER self.cfg = cfg self.register_hooks(self.build_hooks()) ... def build_hooks(self): cfg = self.cfg.clone() cfg.defrost() cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN ret = [ hooks.IterationTimer(), hooks.LRScheduler(self.optimizer, self.scheduler), ... ] if comm.is_main_process(): ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) if comm.is_main_process(): # run writers in the end, so that evaluation metrics are written ret.append(hooks.PeriodicWriter(self.build_writers())) return ret def build_writers(self): return [ # It may not always print what you want to see, since it prints "common" metrics only. CommonMetricPrinter(self.max_iter), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), TensorboardXWriter(self.cfg.OUTPUT_DIR), ] def train(self): super().train(self.start_iter, self.max_iter) if hasattr(self, "_last_eval_results") and comm.is_main_process(): verify_results(self.cfg, self._last_eval_results) return self._last_eval_results
在这份代码中,最核心的部分就是build_model, build_dataloader 两个内容,而它们会在之后介绍模型的部分详细展开。至于到这里注册的各种Hooks 以及其能够实现的功能,我们也会在接下来的章节中进行详尽的介绍。
三、如何自定义自己的Trainer
首先需要说的是,为什么你需要自己的一个trainer呢?基本上而言,我自己实践下来可能是基于如下的几个需求:
2、针对自己的训练任务,有自己的评价方法,而希望边训边测试,所以需要定义自己的EvalHook,并在Trainer中调用,这时候就需要自定义 自己的 register_hook 方法
3、在数据加载时,需要对数据进行debug,因此需要自定义自己的run_step 方法
4、因为进行简单调试的时候,经常容易训练出nan,因此需要定义自己的_detect_anomaly 方法
class TextTrainer(DefaultTrainer): def build_train_loader(self, cfg): # 重写这个方法就好 text_mapper = BoundaryMapper(cfg) data_loader = build_detection_train_loader(cfg, text_mapper) return data_loader @classmethod def build_test_loader(cls, cfg, dataset_name): text_mapper = BoundaryTestMapper(cfg) test_data_loader = build_detection_test_loader(cfg, dataset_name, mapper=text_mapper) return test_data_loader