从零开始知识蒸馏(自用版)

一、什么是知识蒸馏,为什么要使用知识蒸馏?

知识蒸馏是一种将预先训练的教师网络的知识转移到学生网络的方法,使小型网络可以在网络部署阶段取代大型教师网络。知识蒸馏的概念最初是由Hinton等人提出的,已广泛应用于各个领域和任务。知识蒸馏的基本原理是通过训练一个更小、更轻量级的模型来学习更大、更复杂的模型的知识。通常,复杂模型被称为“教师模型”,而简化模型被称为“学生模型”。教师模型可以是深度神经网络或其他复杂模型,而学生模型通常是较浅或较窄的层神经网络。通过将教师模型的输出及其对应的标签作为学生模型的训练目标,学生模型可以从教师模型中获得更多的知识,并在学习过程中逐渐接近或超过教师模型的性能。

二、知识是什么?

首先,区分硬标签和软标签,硬标签就是对分类结果,1就是1,0就是0,一只猫判断它是猫的概率是1,是狗的概率是0,软标签就是用概率给它一个不那么确定的标签,一只猫判断它是猫的概率是0.8,是狗的概率是0.2。

硬标签是我们数据集中通常已知的,一个模型经过训练后它输出的往往是软标签,软标签比硬标签具有更多的知识,比如图片猫的概率是0.8,狗的概率是0.2,说明猫和狗在一定程度上有相似性,而和苹果的相似性为0,这给了我们类别之间更多的关联和信息。

因此,小模型除了利用已知的硬标签,还可以从大模型给的预测软标签中学习更多的“知识”。

三、如何蒸馏知识

Ne-S既要学习真实标签,也就是硬标签,还要学习Teacher给的软标签,那么损失函数就定义为:
L=CE(y,p)+αCE(q,p),y是真实标签,p是Student的预测,q是Teacher的预测。

此外,由于softmax通常把不同类的预测概率区分的很大,比如猫的是0.999,狗是0.001,苹果是0,这样狗和苹果和猫的相似度几乎都一样为0了,为了避免这种情况,加入温度Temperature,让每个类的预测差距不那么大。

 

这样更有利于Student学习到知识。

原来的softmax函数是T = 1的特例。 T越高,softmax的output probability distribution越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。

温度代表了什么,如何选取合适的温度?

 

温度的高低改变的是Net-S训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S会相对多地关注到负标签。

 

实际上,负标签中包含一定的信息,尤其是那些值显著高于平均值的负标签。但由于Net-T的训练过程决定了负标签部分比较noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较empirical,本质上就是在下面两件事之中取舍:

 

  1. 从有部分信息量的负标签中学习 --> 温度要高一些
  2. 防止受负标签中噪声的影响 -->温度要低一些

 

总的来说,T的选择和Net-S的大小有关,Net-S参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能capture all knowledge,所以可以适当忽略掉一些负标签的信息)

 四、知识蒸馏方法

知识蒸馏是对模型的能力进行迁移,根据迁移的方法不同可以简单分为基于目标蒸馏(也称为Soft-target蒸馏或Logits方法蒸馏)和基于特征蒸馏的算法两个大的方向。

 

目标蒸馏-Logits方法

 

 

分类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。在知识蒸馏时,由于我们已经有了一个泛化能力较强的Teacher模型,我们在利用Teacher模型来蒸馏训练Student模型时,可以直接让Student模型去学习Teacher模型的泛化能力。一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“Soft-target” 。

(1)在使用Soft-target训练时,Student模型可以很快学习到Teacher模型的推理过程。

(2)传统的Hard-target的训练方式,所有的负标签都会被平等对待。Soft-target给Student模型带来的信息量要大于Hard-target,并且Soft-target分布的熵相对高时,其Soft-target蕴含的知识就更丰富。

(3)使用Soft-target训练时,梯度的方差会更小,训练时可以使用更大的学习率,所需要的样本也更少。

特征蒸馏方法

它不像Logits方法那样,Student只学习Teacher的Logits这种结果知识,而是学习Teacher网络结构中的中间层特征。它强迫Student某些中间层的网络响应,要去逼近Teacher对应的中间层的网络响应。这种情况下,Teacher中间特征层的响应,就是传递给Student的知识,本质是Teacher将特征级知识迁移给Student。

五、蒸馏损失计算过程

上部分教师网络,它进行预测的时候, softmax要进行升温,升温后的预测结果我们称为软标签(soft label)
学生网络一个分支softmax的时候也进行升温,在预测的时候得到软预测(soft predictions),然后对soft label和soft predictions 计算损失函数,称为distillation loss ,让学生网络的预测结果接近教师网络;
学生网络的另一个分支,在softmax的时候不进行升温T =1,此时预测的结果叫做hard prediction 。然后和hard label也就是 ground truth直接计算损失,称为student loss 。
总的损失结合了distilation loss和student loss ,并通过系数a加权,来平衡这两种Loss ,比如与教师网络通过MSE损失,学生网络与ground truth通过cross entropy损失, Loss的公式可表示如下:

 

posted @ 2024-01-24 14:00  ninisong  阅读(642)  评论(0)    收藏  举报