[开源框架]mmdetection3d学习(二):训练流程
模型训练流程
-
从
tools/train.py
开始:- 一通读取 cfg ,初步设置一些基本参数,log 参数;
- build 模型,build 数据集 (有多少个 workflow 就 build 多少个数据集,比如如果 train 的过程中还进行 val 则表示有 2 个 workflow) ;
- 最后调用
mmdet.apis.train_detector
,传入刚才 build 好的 model,datasets,配置参数等。
-
进入
mmdet.apis.train_detector
:- 为每一个 workflow 对应的 dataset , build data_loader ( data_loader 继承自 pytorch 自带的 DataLoader 类,这里先简单理解,其将 dataset 里面 data sample 包装成 data batch ,作为生成器的形式,每次用 for 迭代 load batch ) ;
- 判断是否是分布式训练,分布式训练则用
MMDistributedDataParallel
封装 model,单 GPU 训练则MMDataParallel
; - build optimizer;
- 重头戏: runner ,runner 可以理解为操控整个训练过程的核心。首先,先跳过中间那一堆对 runner hook 的设置,直接看到最后,调用了
runner.run()
,训练从此处开始。
-
runner 是
EpochBasedRunner
类的实例,进入EpochBasedRunner
类的定义,可以看到最主要的是 run 方法:def run(self, data_loaders, workflow, max_epochs, **kwargs): #... while self.epoch < max_epochs: for i, flow in enumerate(workflow): mode, epochs = flow if isinstance(mode, str): # self.train() if not hasattr(self, mode): raise ValueError( f'runner has no method named "{mode}" to run an ' 'epoch') epoch_runner = getattr(self, mode) else: raise TypeError( 'mode in workflow must be a str, but got {}'.format( type(mode))) for _ in range(epochs): if mode == 'train' and self.epoch >= max_epochs: break epoch_runner(data_loaders[i], **kwargs)
workflow
变量的注释:workflow (list[tuple]): A list of (phase, epochs) to specify the running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,最后 4 行是重点,根据每个 workflow 的 mode 和 epochs 调用 epochs 次相应的函数,比如:
for _ in range(epochs): if mode == 'train' and self.epoch >= max_epochs: break #when mode == 'train' # `epoch_runner(data_loaders[i], **kwargs)` == self.train(data_loaders[i], **kwargs)
一个 epoch 相当于遍历一遍数据集的所有数据。
接下来看看 EpochBasedRunner.train() :
- 设置基本参数
- 在一些关键节点的前后调用了 hook :
before_train_epoch
,before_train_iter
,after_train_iter
,after_train_epoch
。执行反向传播是在after_train_iter
处。 (先不纠结 hook 是个啥) - data_loader 为生成器,用 for 迭代取出 1 个 batch 的数据,进入逐个 iter 的训练:
- 如果有为该 Runner 指定 batch processor,则调用。
- 否则,直接调用模型的 train_step,传入训练数据。
hook
hook 的作用是对一些中间结果做相应的操作,比如打印 log ,比如在 training 过程中的 evaluation 等等。
下面解析一下配置文件中出现的 TensorboardLoggerHook
先从 EpochBasedRunner 如何使用 hook 看起:
-
EpochBasedRunner.register_hook()
- 注册 hook 到 runner,根据 hook cfg build 相应的 hook 实例,放到 runner 的 hook 队列中。hook 队列是一个优先级队列,优先级可以在传入 hook 的时候指定。
-
EpochBasedRunner.call_hook(fn_name)
- 使用 hook ,根据需要调用的函数名
fn_name
,调用每个 hook 里的同名函数,因为 runner 缓存着中间结果,需要将 runner 作为参数传进去。
- 使用 hook ,根据需要调用的函数名
TensorboardLoggerHook
-
该类主要的作用是将每次 iter 或 epoch 完记录训练结果到 tensorboard (即写到 summary 文件里)
-
TensorboardLoggerHook.after_train_iter(runner)
该函数做了什么?判断是否达到 interval,比如在配置文件中指定了每 50 个 iter 才 log 训练结果,如果达到 50 个 iter,则对 50 个 iter 的结果求平均,再调用自己的 log 函数。 50 个 iter 的结果存放在 runner.log_buffer 里。
-
TensorboardLoggerHook.log(runner)
将 runner.log_buffer 里的结果值,通过 summary_writer 写到 summary 文件。
除了 Logger 这种形式的 hook 之外,还有其他一些功能也以 hook 的形式实现,比如 optimizer 对应的 OptimizerHook
,或者 training 过程中的 eval 也是通过 EvaluationHook
调用。