[开源框架]mmdetection3d学习(二):训练流程

模型训练流程

  1. tools/train.py 开始:

    • 一通读取 cfg ,初步设置一些基本参数,log 参数;
    • build 模型,build 数据集 (有多少个 workflow 就 build 多少个数据集,比如如果 train 的过程中还进行 val 则表示有 2 个 workflow) ;
    • 最后调用 mmdet.apis.train_detector ,传入刚才 build 好的 model,datasets,配置参数等。
  2. 进入 mmdet.apis.train_detector

    • 为每一个 workflow 对应的 dataset , build data_loader ( data_loader 继承自 pytorch 自带的 DataLoader 类,这里先简单理解,其将 dataset 里面 data sample 包装成 data batch ,作为生成器的形式,每次用 for 迭代 load batch ) ;
    • 判断是否是分布式训练,分布式训练则用 MMDistributedDataParallel 封装 model,单 GPU 训练则 MMDataParallel
    • build optimizer;
    • 重头戏: runner ,runner 可以理解为操控整个训练过程的核心。首先,先跳过中间那一堆对 runner hook 的设置,直接看到最后,调用了 runner.run() ,训练从此处开始。
  3. runner 是 EpochBasedRunner 类的实例,进入 EpochBasedRunner 类的定义,可以看到最主要的是 run 方法:

     def run(self, data_loaders, workflow, max_epochs, **kwargs):
            
        #...
        while self.epoch < max_epochs:
                for i, flow in enumerate(workflow):
                    mode, epochs = flow
                    if isinstance(mode, str):  # self.train()
                        if not hasattr(self, mode):
                            raise ValueError(
                                f'runner has no method named "{mode}" to run an '
                                'epoch')
                        epoch_runner = getattr(self, mode)
                    else:
                        raise TypeError(
                            'mode in workflow must be a str, but got {}'.format(
                                type(mode)))
    
                    for _ in range(epochs):
                        if mode == 'train' and self.epoch >= max_epochs:
                            break
                        epoch_runner(data_loaders[i], **kwargs)
    

    workflow 变量的注释:

    workflow (list[tuple]): A list of (phase, epochs) to specify the running order and epochs. E.g, [('train', 2), ('val', 1)] means
    running 2 epochs for training and 1 epoch for validation,

    最后 4 行是重点,根据每个 workflow 的 mode 和 epochs 调用 epochs 次相应的函数,比如:

                    for _ in range(epochs):
                        if mode == 'train' and self.epoch >= max_epochs:
                            break
                        #when mode == 'train'
                        # `epoch_runner(data_loaders[i], **kwargs)` ==
                        self.train(data_loaders[i], **kwargs)
    

    一个 epoch 相当于遍历一遍数据集的所有数据。

    接下来看看 EpochBasedRunner.train() :

    • 设置基本参数
    • 在一些关键节点的前后调用了 hook : before_train_epoch, before_train_iter, after_train_iter, after_train_epoch 。执行反向传播是在 after_train_iter 处。 (先不纠结 hook 是个啥)
    • data_loader 为生成器,用 for 迭代取出 1 个 batch 的数据,进入逐个 iter 的训练:
      • 如果有为该 Runner 指定 batch processor,则调用。
      • 否则,直接调用模型的 train_step,传入训练数据。

hook

hook 的作用是对一些中间结果做相应的操作,比如打印 log ,比如在 training 过程中的 evaluation 等等。

下面解析一下配置文件中出现的 TensorboardLoggerHook

先从 EpochBasedRunner 如何使用 hook 看起:

  • EpochBasedRunner.register_hook()

    • 注册 hook 到 runner,根据 hook cfg build 相应的 hook 实例,放到 runner 的 hook 队列中。hook 队列是一个优先级队列,优先级可以在传入 hook 的时候指定。
  • EpochBasedRunner.call_hook(fn_name)

    • 使用 hook ,根据需要调用的函数名 fn_name ,调用每个 hook 里的同名函数,因为 runner 缓存着中间结果,需要将 runner 作为参数传进去。

TensorboardLoggerHook

  • 该类主要的作用是将每次 iter 或 epoch 完记录训练结果到 tensorboard (即写到 summary 文件里)

  • TensorboardLoggerHook.after_train_iter(runner)

    该函数做了什么?判断是否达到 interval,比如在配置文件中指定了每 50 个 iter 才 log 训练结果,如果达到 50 个 iter,则对 50 个 iter 的结果求平均,再调用自己的 log 函数。 50 个 iter 的结果存放在 runner.log_buffer 里。

  • TensorboardLoggerHook.log(runner)

    将 runner.log_buffer 里的结果值,通过 summary_writer 写到 summary 文件。

除了 Logger 这种形式的 hook 之外,还有其他一些功能也以 hook 的形式实现,比如 optimizer 对应的 OptimizerHook,或者 training 过程中的 eval 也是通过 EvaluationHook 调用。

posted @ 2020-08-26 12:36  lunaY  阅读(3493)  评论(0编辑  收藏  举报