【论文笔记】(知识蒸馏)Distilling the Knowledge in a Neural Network

摘要

模型平均可以提高算法的性能,但是计算量大且麻烦,难以部署给用户。《模型压缩》这篇论文中表明,知识可以从复杂的大型模型或由多个模型构成的集成模型中压缩并转移到一个小型模型中,本文基于这一观点做出了进一步研究:通过知识蒸馏(knowledge distillation)显著提高了转移后的小型模型的性能,此外还提出了一种新的集成模型,它由一个或多个完整模型再加多个specialist models(区别是:完整模型无法细粒度分类,specialist模型可以)组成。

名词解释

  • 教师模型 teacher model:一个单个的复杂大型模型 或 由多个模型组成的一个集成模型,知识从教师模型转出;
  • 学生模型 student model:一个小型的、简单的、易于部署的模型,知识转入学生模型;
  • transfer set:包含从教师模型中提取的知识,是学生模型的训练集;
  • transfer 阶段:知识从教师模型转移到学生模型的阶段,即学生模型的训练阶段;
  • soft target:教师模型输出的(每个类的)概率;
  • hard target:原始数据集自带的labels;
  • temperature:蒸馏的目标函数中的超参数,用于控制softmax函数的形状;

1 Introduction

首先训练一个复杂的大型模型/教师模型,蒸馏就是将这个大型模型学到的知识通过特殊的训练转移到小型模型/学生模型上。

作者认为,学生模型应该学习教师模型的泛化能力,而非数据拟合能力;如果教师模型的泛化能力强,学生模型经过学习训练后,就能够在测试集上表现很好。

将教师模型的泛化能力转移到学生模型的一个方法是:使用教师模型产生的(每个类的)概率作为训练学生模型的"soft targets"。在transfer阶段,可以使用相同的训练集或单独的 transfer set,transfer set可以全部由unlabeled的数据构成。当教师模型是集成模型时,可以使用各模型预测的分布的平均值作为soft targets。当soft targets的熵比hard targets高时,它们提供的信息比hard targets多,且训练梯度变化也会变小,所以小模型通常使用更少的数据进行训练,且使用更高的学习率。

对于MNIST这种简单的数据,大型模型的准确率很高,更多的知识是在非常小的软目标中。比如,一个数字2的图像被预测为: \(10^{-6}\) 的概率是数字3,\(10^{-9}\) 的概率是数字7;而对于另一个数字2的图像,预测结果可能相反。这些信息很有价值,它们定义了数据上丰富的相似性结构。但对与transfer阶段的交叉熵损失影响很小,因为概率都十分接近0,对此《模型压缩》使用 logits作为训练小型模型的目标来放大这些信息,而本文采用"蒸馏",方法是提高 softmax 的 temperature 直到大型模型生成一组合适的soft targets,然后用这些soft targets来训练小型模型。

2 Distillation

神经网络通常使用 softmax 层来生成每个类的概率,softmax 层将 logit 层的输出值 \(z_i\) 转换为概率 \(q_i\):

\[q_i = \frac{\exp(z_i/T)} {\sum_j \exp(z_j/T)} \tag{1} \]

其中 \(T\) 为 temperature,通常设置为 1,T 的值越高,生成的概率分布会越平滑(softer)。

transfer sets的两种形式:

  1. 全部由 soft targets 组成,也就是大型模型的 softmax 输出,使用与大型模型相同的 temperature 值;
  2. 由 soft targets 和 hard targets 组成,目标函数选择交叉熵。soft targets 的temperature 值与大型模型的一样,hard targets 的 temperature 值取 1。当使用 soft 和 hard targets,需要乘以 \(T^2\)

2.1 Matching logits is a special case of distillation

大型模型的logits输出为 \(v_i\),输入进softmax层生后计算的概率值为\(p_i\) (也是 soft targets);小型模型的logits值 \(z_i\),softmax层值为\(q_i\)(即公式1)且temperature的值都设置为\(T\),此时的交叉熵为(假设有N个训练数据):

\[C = -\sum_{j=1}^{N} p_j \log q_j \]

transfer 阶段,对\(z_i\)的梯度为:

\[\begin{aligned} \frac{\partial C}{\partial z_i}&=-\sum_{j=1}^{N} p_j \frac{\partial \log q_j}{\partial z_i}\\ &=-\sum_{j=1}^{N} p_j \frac{\partial \log q_j}{\partial q_j}\frac{\partial q_j}{\partial z_i}\\ &=-\sum_{j=1}^{N} p_j \frac{1}{q_j}\frac{\partial q_j}{\partial z_i} \\ \end{aligned} \]

分情况考虑第三项,\(q_j\)的分母部分可以拆分成\(c+e^{z_i/t}\),其中\(c\)\(z_i\)无关,相当于常数。那么,当\(i=j\)时,\(q_j\)可以写成\(q_j = 1 - \frac{c}{c+e^{z_i/T}}\)

\[\begin{aligned} \frac{\partial q_j}{\partial z_i} &= (-c)(-1)\frac{\frac{1}{T}e^{z_i/T}}{(c+e^{z_i/T})^2}\\ &=\frac{1}{T}\frac{e^{z_i/T}}{c+e^{z_i/T}}\frac{c}{c+e^{z_i/T}}\\ &=\frac{1}{T}q_i(1-q_i)\\ &=\frac{1}{T}q_j(1-q_j) \end{aligned} \]

\(i\neq j\)时,\(q_j\)可以写成\(q_j = \frac{e^{z_j/T}}{c+e^{z_i/T}}\)

\[\begin{aligned} \frac{\partial q_j}{\partial z_i} &= (-1)e^{z_j/T}\frac{\frac{1}{T}e^{z_i/T}}{(c+e^{z_i/T})^2}\\ &=-\frac{1}{T}\frac{e^{z_j/T}}{c+e^{z_i/T}}\frac{e^{z_i/T}}{c+e^{z_i/T}}\\ &=-\frac{1}{T}q_j q_i \end{aligned} \]

整理,可得:

\[\begin{aligned} \frac{\partial C}{\partial z_i} &= -\frac{1}{T}\left ( p_j(1-q_j)-\sum_{j=1,j\neq i}^{N}p_j q_i \right )\\ &=\frac{1}{T}\left ( p_j q_j +\sum_{j=1,j\neq i}^{N}p_j q_i -p_j \right )\\ &=\frac{1}{T}\left ( \sum_{j=1}^{N}p_j q_i -p_j \right )\\ &=\frac{1}{T}\left ( q_i - p_j \right )\\ &=\frac{1}{T}\left ( \frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right ) \end{aligned} \tag{2} \]

如果 temperature 的值高于 logits,即 \(T \gg z_i,v_i\),则可以近似:

\[\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left (\frac{1+z_i/T}{N+\sum_j z_j/T} - \frac{1+v_i/T}{N+\sum_j v_j/T}\right ) \tag{3} \]

假设 logits 的均值为0,即 \(\sum_j z_j = \sum_j v_j = 0\),则有:

\[\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2}(z_i-v_i) \tag{4} \]

因此,在temperature很高的极限下和每个transfer case都为零均值时,蒸馏等同于最小化 \(1/2(z_i - v_i)^2\),即MSE。

当 temperature较低时,蒸馏不怎么关注极小的负的logits,这既有优点又有缺点:好处是这些logits可能非常嘈杂,坏处是logits可能会传递一些有用的信息。所以temperature往往取中间值时效果好。

3&4 实验

作者在图像识别和语音识别两个领域进行实验,不过多描述。

5 Training ensembles of specialists on very big datasets

在这节中,针对JFT数据集,训练了多个能够细粒度分类的specialist models 和通用模型,将这些模型组合成集成模型。其中,specialist models很容易过拟合,作者还给出了如何防止过拟合的方法。

5.1 The JFT dataset

JFT 是一个谷歌数据集,包含 1 亿张带有 15,000 种标签的图像。

5.2 Specialist Models

各个specialist models 会在各个类的集合上进行训练,比如全是蘑菇但不同种类的蘑菇,将它们不关心的类整合为一个dustbin class,这样它们会给出很小的softmax值。

为了减少过拟合并分担通用模型的工作,每个specialist model都使用通用模型的权重进行初始化。然后通过训练specialist models来稍微修改这些权重,其中一半样本来自其特殊子集,一半来自训练集的其余部分随机抽样。训练后,可以通过将dustbin class的 logit 增加 log(specialist class 被采样的比例) 来纠正有偏差的训练集。

5.3 Assigning classes to specialists

为了为specialists派生类别的分组,作者重点关注于常混淆的类别。作者将聚类算法应用于通用模型预测的协方差矩阵,经常被一起预测的一组类 \(S^m\) 将用作一个specialist model 的target \(m\)。表 2 采用的是on-line version of the K-means。

表1. 由协方差矩阵聚类算法计算的聚类类

5.4 Performing inference with ensembles of specialists

首先检查这个新的集成模型效果如何,对于一个给定的图像\(x\),分两步进行分类:
第 1 步:对于每个测试样本,根据通用模型找到 \(n\) 个最可能的类,称这组类为 \(k\),在实验中作者使用 \(n = 1\)
第 2 步:然后取所有的满足以下的specialists \(m\):其可混淆类的特殊子集 \(S^m\)\(k\) 有一个非空交集,并将其称为specialists 的active set \(A_k\)(该集合可能为空)。 然后找所有类的完整概率分布 \(\mathbf{q}\)\(\mathbf{q}\)最小化以下公式:

\[KL(\mathbf{p}^g,\mathbf{q})+\sum_{m\in A_k} KL(\mathbf{p}^m,\mathbf{q}) \]

\(\mathbf{p}^m,\mathbf{p}^g\) 表示specialists或完整模型的概率分布,\(\mathbf{p}^m\)\(m\) 的所有specialist类别加上单个dustbin class的分布。

6 Soft Targets as Regularizers

作者认为使用soft targets能够防止specialists过拟合,起到了正则化的作用。

posted @ 2022-06-22 13:05  李斯赛特  阅读(1199)  评论(0编辑  收藏  举报