快速实现知识蒸馏算法,使用 MMRazor 就够啦!

前言 MMRazor 是 OpenMMLab 生态的面向模型压缩的开源算法库,目前主要涵盖了知识蒸馏、剪枝、NAS 三类算法。本文介绍了如何使用 MMRazor 实现知识蒸馏。

本文转载自OpenMMLab

作者 | 带来新知识的

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

计算机视觉入门1v3辅导班

代码:https://github.com/open-mmlab/mmrazor/tree/dev-1.x

我们接下来将从 MMRazor 知识蒸馏框架介绍,基于 MMRazor 的知识蒸馏实战教程两个方面展开分享。

 

1. MMRazor 知识蒸馏框架介绍

知识蒸馏(Knowledge Distillation,简记为 KD)是一种经典的模型压缩方法,核心思想是通过引导轻量化的学生模型“模仿”性能更好、结构更复杂的教师模型(或多模型的 ensemble),在不改变学生模型结构的情况下提高其性能。

2015 年 Hinton 团队提出的基于“软标签”(response-based)的知识蒸馏技术(一般将该文算法称为vanilla-KD)掀起了相关研究热潮,其后基于“特征”(feature-based)和基于“关系”(relation-based)的KD算法被陆续提出。

由于知识蒸馏的过程可以理解为一个,获取学生网络和教师网络指定蒸馏位点的输出特征并计算蒸馏 loss 的过程。因此,实现一个蒸馏算法往往会有以下几点需求:

  1. 获取学生网络和教师网络指定蒸馏位点的输出特征,例如某个 nn.Module,某个类方法或是某个函数的输入输出信息。在 MMRazor 中,我们通过 Recorder 组件实现。
  2. 在学生网络前向传播过程中,某一中间输出需要被教师网络对应输出覆盖。例如,在 LAD 中,学生网络训练过程中需要使用教师网络的 label assign 结果覆盖掉自身结果。在 MMRazor 中,我们通过 Deliver 组件实现。
  3. 一个蒸馏算法中可能会有多个蒸馏 loss 联合作用。而某个蒸馏 loss 需要的输入可能来自于不同 recorder,也可能来自于某一个 recorder 获取的多个数据中的若干个。因此,我们需要利用 loss_forward_mappings 数据结构从 Recorder 获取的输出特征中筛选得到蒸馏 loss 需要的部分,并利用 connectors 模块进行后处理(例如,对于 feature-base 的方法,当学生和教师网络输出特征维度不同时,往往会对学生网络对应特征进行后处理以保证蒸馏 loss 正确计算)。在 MMRazor 中,这一系列功能我们通过 ConfigurableDistiller 来统一管理。
  4. 进一步,为了解决特定问题或面向特定场景,知识蒸馏算法本身又可以细分为 data-free KD、online KD、self KD(可视为一种特殊的 online KD)和比较经典的 offline KD 等等。在 MMRazor中,我们通过设计不同的high level 的 Algorithm 组件以支持不同类型的蒸馏算法对应的 pipeline。

上述 4 个组件在 MMRazor 中的位置如下图所示。接下来,我们会详细介绍各个组件的设计动机和使用方法。

1.1 Recorder

如下图所示,Recorder 是一个上下文管理器,用于在模型前项传播过程中记录各种中间结果。同时,它还可以用来获取一些特定位点的数据,用于可视化分析或其他你想要的功能。为了适应更多的需求,我们在 MMRazor 中实现了多种类型的 Recorder 来获得不同类型的中间结果,它们由 RecorderManager 统一管理。

目前,我们支持了 7 类 Recorder,如下表所示:

Recorder 名称 描述
ModuleOutputsRecorder/ ModuleInputsRecorder 获取某个 torch.nn.Module 的 输出 / 输入 结果
FunctionOutputsRecorder / FunctionInputsRecorder 获取模型中使用到的某个函数的 输出 / 输入 结果
MethodOutputsRecorder / MethodInputsRecorder 获取模型中某个类的某个方法的 输出 / 输入 结果
ParameterRecorder 获取某个 torch.nn.Module 的模型参数

接下来我们以 ModuleOutputsRecorder 为例,为大家介绍下 Recorder 的使用方法。

 

ModuleOutputsRecorder

获取 nn.Module 的输入输出相对会比较容易,因为可以通过为 nn.Module 挂上 PyTorch 原生的 forward hook 来实现。由于这两种 Recorder 的使用方法非常类似,我们以 ModuleOutputsRecorder 为例来介绍它们是如何工作的:

注意:所有的 Recorder 在使用前都需要执行 initialize 方法。

 

RecorderManager

RecorderManager 同样是上下文管理器,可用于管理各种类型的 Recorder。

在 RecorderManager 的帮助下,我们可以用尽可能少的代码管理几个不同的记录器,这减少了出错的可能性。

1.2 Deliver

Deliver 工具是 MMRazor 专门为处理蒸馏算法中涉及的一些特殊情况而设计的,它在教师模型和学生模型之间转移并覆盖掉一些中间结果,如下图所示:

Deliver 并不像 Recorder 那样,基本在每个蒸馏算法中都会使用,但对于一些算法,它是必不可少的。例如,在 LAD 中,学生网络需要直接获取教师网络的 label assignment 信息,我们可以如下配置 Deliver:

Deliver 的可配置性,让我们可以不需要对源代码进行 hardcode 修改。

 

1.3 ConfigurableDistiller

ConfigurableDistiller 是一个功能强大的工具,可以在不修改教师或学生模型代码的情况下实现大多数蒸馏算法。它可以通过 Recorder 以一种 hack 的方式获得模型的各种中间结果。同样,它可以使用 Delivery 以 hack 的方式使用老师的中间结果来覆盖学生的中间结果。

这里的 student_recorders、teacher_recorders、distill_deliveries 上文刚刚介绍。distill_losses 是蒸馏时用到的蒸馏损失函数,可以是一个或多个。

这里引入了两个新概念:connectors 和 loss_forward_mappings。

 

Connectors

知识蒸馏算法往往分为 reponse-based、feature-based 和 relation-based 三类。其中,feature-based 方法以教师模型特征提取器产生的中间层特征为学习对象,最简单的 L2 损失如下所示:

实现特征对齐功能的模块是 feature-based KD 算法的核心模块(MMRazor 中称之为 connector),也是很多算法的重点研究对象。如针对教师 connector 进行预训练的 Factor Transfer 算法;以二值化形式筛选教师和学生原始特征的 AB 算法;将特征值转换为注意力值的 AT 算法等。OFD 对各相关算法进行总结,研究了特征位置、connector 的构成、损失函数等因素对蒸馏性能、信息损失的影响,汇总表如下所示:

上面提到的 FitNets、Factor Transfer、AB、AT Loss(AT 算法与蒸馏最相关的损失计算部分)、OFD 等算法均被集成到了 MMRazor 算法库中,且核心模块 connector 被单独抽象出来作为可配置组件,非常便于大家进行“算法魔改”(如为 FitNets 算法配置上 Factor Transfer 的 connector 并计算 AT Loss)。

 

loss_forward_mappings

通过 Recorder 组件获得模型中间结果后,可以通过配置 loss_forward_mappings 来进一步指定不同的蒸馏 loss 的输入参数是什么。

下面代码表示我们只使用一个蒸馏 loss,我们把它称为 loss_neck,它实际上是一个 L2Loss,loss weight=5。

我们分别设置 student_recorders 和 teacher_recorder 希望获取学生和教师网络的 'neck.gap' 这一 module 的输出,输出特征命名为 feat。那么在配置 loss_forward_mappings 时可以看到 L2Loss 的 forward 方法有两个输入参数,分别为 s_feature 和 t_feature。而 s_feature 和 t_feature 又分别来自于名字为 feat 的学生和教师网络中间特征。

1.4 Algorithm

最后,Distill Algorithm 负责控制蒸馏的一整个 pipeline。MMRazor 实现了多类知识蒸馏算法,我们以最经典的,使用单个教师网络蒸馏单个学生网络的 SingleTeacherDistill 算法为例。下面代码展示了其在初始化时需要传入的参数:

2. 基于 MMRazor 的知识蒸馏实战教程

我们首先以 KD 算法使用 ResNet34 蒸馏 ResNet18 为例,介绍如何使用 MMRazor 实现基础蒸馏算法。

KD

步骤一:设计 Distiller 相关配置文件

传统 KD 算法的知识提取和 loss 计算过程非常简洁,只需获取学生网络和教师网络的输出 logits 并计算 KL 散度即可。Distiller 组件的配置文件主要分为四个部分,如下方代码所示。

  1. student_recorders 和 teacher_recorders 表示我们需要分别记录学生和教师网络中某一个 nn.Module 的输出,这个 nn.Module 的 module name 是 head.fc,记录的数据我们将其命名为 fc。
  2. 在整个蒸馏过程中,我们只使用了一种损失函数,因此 distill_losses 中只包含一组 key 和 value。key 是 loss_kl,表示将这个蒸馏 loss 命名为 loss_kl,value 则是蒸馏 loss 对应的配置文件,表示我们用的蒸馏 loss 是 KLDivergence,超参数 temperature 温度为 1,loss weigh 为 5。
  3. loss_forward_mappings 指定每个 loss module 的输入数据是什么。在本示例中,loss module 有两个输入:preds_S 和 preds_T,分别表示学生和教师网络输出 logits,它们需要跟 KLDivergence loss module 中 forward 方法传入参数保持一致。另外,我们通过 from_student 和 recorder 两个字段判断从student_recorders 还是 teacher_recorders 中读取哪个值。dict(from_student=True, recorder='fc') 表示读取 student_recorders 中名字为 fc 的数据。

 

步骤二:设计 Algorithm 相关配置文件

算法层面需要进行以下配置。architecture 和 teacher 指定了学生/教师网络的网络配置,teacher_ckpt 指定教师网络的预训练权重的路径,如果使用 MMClassification 提供的预训练参数,也可直接在 teacher 对应的字典中的 pretrained 的值设为 True。

其余配置只需按照默认设置,即:teacher_trainable=False 表示蒸馏过程中教师网络是不可学习的;teacher_norm_eval=True 表示蒸馏过程中教师网络的 norm module 全程都处在 eval 模式下;calculate_student_loss=True 表示学生网络除了受教师网络监督外,还受到 ground truth 的监督。

步骤三:设计其他配置文件

最后,我们还需要定义数据集、优化器等其他配置文件,这部分直接引入 MMClassification 定义好的配置文件即可,代码如下,其中 val_cfg 定义了蒸馏过程中 evaluation 的 pipeline:

OFD

 

接下来我们以 OFD 算法为例,介绍如何使用 MMRazor 实现稍微复杂一些的算法。

步骤一:设计 Distiller 相关配置文件

上方代码是 OFD 算法 distiller 部分对应的配置文件。student_recorders 和 teacher_recorders 与上述 KD 算法类似,分别获取学生和教师网络三个 nn.module 的输出并命名为 bb_1,bb_2,bb_3。distill_losses定义了三个蒸馏 loss,loss weight 分别为 0.25,0.5 和 1.0。与传统 KD 算法不同,connectors定义了 OFD 中用到的 connector 结构 —— convbn layer。前文介绍过,MMRazor 使用 connectors 模块实现特征对齐功能,即:在获取 loss module 所需输入数据后,输入 loss module 前,数据需要经过 connector 处理。

最后 loss_forward_mappings 也会比 KD 中复杂些,我们以下方代码为例。代码中 distill_losses定义了 名为 loss_1 的蒸馏损失函数对应的是 loss_weight=0.25 的 OFDLoss。其中 OFDLoss loss module 的输入参数有两个,分别是 s_feature 和 t_feature。s_feature 来自名为 bb_1 的 student_recorders ,需要经过名为 loss_1_sfeat 的 connector 进行特征后处理。从下方 connectors 部分的定义可知,名为loss_1_sfeat 的 connector 是一个 ConvModule,输入输出通道数分别为 32 和 64。同理可得 t_feature。

步骤二 / 三:设计 Algorithm 相关配置文件 / 设计其他配置文件

步骤二、三与传统 KD 算法类似。至此,我们通过修改配置文件实现了 OFD 算法。

 

总结

本文介绍了 MMRazor 对知识蒸馏算法的设计框架,并列举了两个简单例子来介绍如何使用 MMRazor 开发知识蒸馏算法。希望本文可以帮助到大家,也期待大家的使用和批评指正。

我们非常欢迎大家:

  1. 提出使用过程中遇到的问题,包括但不限于 bug、框架设计优化建议、希望后续 MMRazor 新增某些功能、算法等;
  2. 在 MMRazor 中复现某个算法或某类算法 pipeline(优秀的复现会掉落实习机会哦,爆率真的很高);
  3. 帮 MMRazor 进行宣传,增加使用者的数量(优秀的宣传大使同样会掉落实习机会哦)等。

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

计算机视觉入门1v3辅导班

【技术文档】《从零搭建pytorch模型教程》122页PDF下载

QQ交流群:444129970。群内有大佬负责解答大家的日常学习、科研、代码问题。

其它文章

 

原来Transformer就是一种图神经网络,这个概念你清楚吗?

 

3D目标检测中点云的稀疏性问题及解决方案

 

一文深度剖析扩散模型究竟学到了什么?

 

OpenMMLab教程【零】OpenMMLab介绍与安装

 

代码实战:YOLOv5实现钢材表面缺陷检测

 

TensorRT教程(六)使用Python和C++部署YOLOv5的TensorRT模型

 

知识蒸馏的迁移学习应用

 

超全汇总 | 计算机视觉/自动驾驶/深度学习资料合集!

 

高精度语义地图构建的一点思考

 

一文看尽深度学习中的各种注意力机制

 

点云分割训练哪家强?监督,弱监督,无监督还是半监督?

 

姿态估计端到端新方案 | DirectMHP:用于全范围角度2D多人头部姿势估计

 

深度理解变分自编码器(VAE) | 从入门到精通

 

用于超大图像的训练策略:Patch Gradient Descent

 

CV小知识讨论与分析(5)到底什么是Latent Space?

 

【免费送书活动】关于语义分割的亿点思考

 

新方案:从错误中学习,点云分割中的自我规范化层次语义表示

 

经典文章:Transformer是如何进军点云学习领域的?

 

CVPR 2023 Workshop | 首个大规模视频全景分割比赛

 

如何更好地应对下游小样本图像数据?不平衡数据集的建模的技巧和策

 

Transformer交流群

 

 

posted @ 2023-02-28 20:10  CV技术指南(公众号)  阅读(495)  评论(0编辑  收藏  举报