知识蒸馏原理及BERT蒸馏实战

------------恢复内容开始------------

------------恢复内容开始------------

首发于https://zhuanlan.zhihu.com/p/503739300

前言

本文主要介绍知识蒸馏原理,并以BERT为例,介绍两篇BERT蒸馏论文及代码,第一篇论文是在下游任务中使用BiLSTM对BERT蒸馏,第二篇是对Transformer蒸馏,即TinyBert。

知识蒸馏

https://arxiv.org/pdf/1503.02531.pdf​arxiv.org
由于大模型参数量巨大,线上部署不仅对机器资源要求比较高而且推理速度慢,因此需要对模型进行压缩加速,知识蒸馏便是模型压缩的一种形式。
知识蒸馏(Knowledge Distillation)基于“教师-学生网络”思想,将已经训练好的大模型(教师)中的知识迁移到小模型(学生)训练中。
知识蒸馏分为两步:
  1. 在数据集上训练大模型(教师)
  2. 在高温T下,对大模型进行蒸馏,将大模型学习到的知识迁移到小模型(学生)上
下面介绍知识蒸馏在分类任务中的做法:
一般分类任务处理逻辑是通过softmax层将输出层输出的logits转化成概率分类,然后计算预估概率与真实标签的交叉熵作为损失进行梯度更新。
知识蒸馏是希望让小模型能够学到大模型的输出,为什么是输出呢?
因为真实标签是one-hot形式表示,计算预估概率与真实标签的交叉熵时无法学习到其他类目的知识,通过让小模型拟合大模型的输出,比单纯拟合真实标签能学到更多的知识。
大模型的输出有两种,分别是logits和经softmax层后概率,下面将分别介绍蒸馏中这两种输出的拟合方式。
拟合softmax
softmax层后得到的是各类目的概率分布,由于使用指数函数会放大logits,使类目的概率差异变大,知识蒸馏时使用温度(T)对logits放缩,从而使softmax后的概率分布不要有太大的差异,即能学到原始类目间关系。
高温蒸馏的损失函数为 L=\alpha L_{soft} + \beta L_{hard}
L_{soft}为学生网络(Net-S)在相同高温(T)下经softmax后产出的概率分布与教师网络(Net-T)输出(soft target)的交叉熵,即 L_{soft} = -\textstyle\sum_{i=1}^n p_{i}log(q_{i}) ,其中 p_{i}是教师网络输出, p_{i} = \frac{exp(v_{i}/T)}{\textstyle\sum_{k}^n exp(v_{k}/T)} , q_{i}q_{i} 是学生网络输出, q_{i} = \frac{exp(z_{i}/T)}{\textstyle\sum_{k}^n exp(z_{k}/T)} 。
L_{hard} 学生网络(Net-S)在经softmax后产出的概率分布(不用高温)与真实标签(hard target)的交叉熵组合,即 L_{hard} = -\textstyle\sum_{i=1}^n c_{i}log(q_{i})  ,其中 c_{i}  是真实标签。
拟合logits
与拟合softmax层相比,这种方式较简单,最小化的目标函数是教师网络和学生网络输出logits的平方差,即 L_{soft} = \textstyle\sum_{i=1}^n (v_{i} - z_{i})^{2}
关于温度的理解
温度影响softmax层的输出,当T比较大时, 每个类的输出概率会比较接近,这样能学习到能过其他类目的信息。
温度高低代表对负标签的关注程度,温度越高,负标签的值相对较大,学生网络能学习到更多负标签信息。

Distilling Task-Specific Knowledge from BERT into Simple Neural Networks

背景

BERT模型在下游任务fine-tuning后,由于参数量巨大,计算比较耗时,很难真正上线使用,该论文提出使用简单神经网络(单层BiLSTM)对fine-tuned BERT进行蒸馏,蒸馏后的BiLSTM模型与ELMo效果相同,但是参数量减少100倍且推理时间减少15倍。

模型结构

以在训练集上fine-tune后的BERT模型作为teacher网络,BiLSTM作为student网络进行蒸馏训练,整体训练过程如下:
  1. 先用fine-tuning后的bert对训练数据进行预估,得到bert输出概率
  2. 然后使用BiLSTM网络对训练数据进行建模,得到BiLSTM输出概率
  3. 最后计算hard loss(BiLSTM输出概率分布与真实标签的交叉熵)和soft loss(BiLSTM与Bert输出logits的均方误差),加权作为损失
使用BiLSTM进行分类的结构如下,使用BiLSTM(b)对序列(a)进行学习,将前向(c)和后向(d)最后隐层向量拼接后连接带有relu激活函数的全连接层(efg)得到logit输出(h),再经softmax(i)得到概率分布(j)。
与原始BiLSTM相比,蒸馏bert是用fine-tuned Bert对文本(a)的输出logits与BiLSTM的学习的logits(h)做均方误差,使Bert能将知识转移给BiLSTM,也就是BiLSTM的输出与bert的输出接近,这便是teacher与student的含义。
损失函数如下: \begin{aligned}  L &= \alpha L_{CE} + (1-\alpha) L_{MSE} \\   &= - \alpha \textstyle\sum_{i=1}^n t_{i}log{y_{i}^{(S)}} + (1-\alpha) ||z^{(T)}-z^{(S)}||_{2}^{2} \end{aligned}

模型效果

蒸馏后的BiLSTM在GLUE语料上的效果均优于普通的BiLSTM,在SST-2和QQP任务上效果与ELMo类似。

实现代码

https://github.com/xiaopp123/knowledge_distillation

TinyBERT: Distilling BERT for Natural Language Understanding

背景

为提高bert的推理和计算性能,论文提出使用Transformer蒸馏方式将Bert蒸馏至TinyBert,另外,论文还提出两阶段的学习框架,即预训练阶段和fine-tuning阶段都对Bert蒸馏。蒸馏后的TinyBert在GLUE任务集上能达到原始Bert的96.8%,模型大小比原来减少到7.5倍,推理性能提高到9.4倍。

模型结构

Transformer包含两部分:MHA(多头注意力层)和FFN(前馈神经网络)。如图所示,Transformer蒸馏方式正是基于MHA和FFN隐藏状态进行蒸馏的。
attention层蒸馏的损失函数如下,其中 h表示head number,A_{i}表示第 i i 个head的注意力矩阵,维度为 l*ll表示序列长度,这个损失函数的作用是使学生模型能学到教师模型中的注意力矩阵。
FNN隐藏蒸馏的损失函数如下,其中H^{T}表示教师模型FNN网络的输出,维度为 l*dH^{S}表示学生模型FNN网络的输出,维度为 l*d',一般情况下, d>d' ,W_{h}  表示映射矩阵,维度为 d'*d ,即将学生网络输出映射到教师网络输出向量维度上。
与FNN层蒸馏方式相同,论文对embedding层也进行蒸馏,损失函数如下所示:
输出层蒸馏
论文计算教师网络和学生网络输出logits的交叉熵作为输出层蒸馏损失函数
输出层蒸馏损失函数
综上,模型的蒸馏函数为:

训练过程

TinyBert学习过程分为两步:General Distillation和Task-specific Distillation。
Generation Distillation是指预训练阶段蒸馏,这部分使用的是通用数据集故称为General Distillation。
预训练阶段训练的TinyBert由于参数较少,与原始Bert相比在下游任务中的效果必然有损,因此论文提出针对下游任务的Task-specific Distillation,该过程以原始Bert作为教师模型,TinyBert作为学生模型在特定数据集上进行蒸馏学习。

实现代码

在下文fine-tuning任务,分两步进行训练,第一步是蒸馏Transformer,第二步是蒸馏下游任务输出层

Transormer蒸馏

# Transformer蒸馏
# 教师网络层数大于学生网络
teacher_layer_num = len(teacher_atts)
student_layer_num = len(student_atts)
assert teacher_layer_num % student_layer_num == 0
layers_per_block = int(teacher_layer_num / student_layer_num)
# attention层蒸馏
# 学生网络第i层学习教师网络第i * layers_per_block + layers_per_block - 1层
# 若学生网络是3,教师网络为12,则第0层学习第3层,第1层学第7层
new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                    for i in range(student_layer_num)]

for student_att, teacher_att in zip(student_atts, new_teacher_atts):
    student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device),
                              student_att)
    teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device),
                              teacher_att)
    # attention层蒸馏损失为均方误差
    tmp_loss = loss_mse(student_att, teacher_att)
    att_loss += tmp_loss
# 前馈神经网络层和Embedding层蒸馏
# 学生第0层学习教师第0层,第0层是embedding层输出
# 第i层学习第layers_per_block * i层
new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)]
new_student_reps = student_reps
# 前馈神经网络层和Embedding层蒸馏均方误差
for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
    tmp_loss = loss_mse(student_rep, teacher_rep)
    rep_loss += tmp_loss

loss = rep_loss + att_loss

输出层蒸馏

# 输出层蒸馏
# 分类任务是教师网络和学生网络输出logits交叉熵
if output_mode == "classification":
    cls_loss = soft_cross_entropy(student_logits / args.temperature,
                                  teacher_logits / args.temperature)
elif output_mode == "regression":
    loss_mse = MSELoss()
    cls_loss = loss_mse(student_logits.view(-1), label_ids.view(-1))

loss = cls_loss

这里重点讲一下如何针对具体下游任务进行fine-tuning:

数据准备
这里可以是自己的数据集,也可以是GLUE任务。
预训练模型
需要下载Bert预训练模型和TinyBert预训练模型。
Bert预训练模型在HuggingFace官网“Model”模块输入bert,找到适合自己的bert预训练模型,在“Files and versions”选择自己需要模型和文件下载,目前好像只能一个一个文件下载。
TinyBert预训练模型:huawei-noah (HUAWEI Noah's Ark Lab)
Transformer蒸馏
python task_distill.py --teacher_model ${FT_BERT_BASE_DIR}$ \
                       --student_model ${GENERAL_TINYBERT_DIR}$ \
                       --data_dir ${TASK_DIR}$ \
                       --task_name ${TASK_NAME}$ \ 
                       --output_dir ${TMP_TINYBERT_DIR}$ \
                       --max_seq_length 128 \
                       --train_batch_size 32 \
                       --num_train_epochs 10 \
                       --aug_train \
                       --do_lower_case  
输出层蒸馏
python task_distill.py --pred_distill  \
                       --teacher_model ${FT_BERT_BASE_DIR}$ \
                       --student_model ${TMP_TINYBERT_DIR}$ \
                       --data_dir ${TASK_DIR}$ \
                       --task_name ${TASK_NAME}$ \
                       --output_dir ${TINYBERT_DIR}$ \
                       --aug_train  \  
                       --do_lower_case \
                       --learning_rate 3e-5  \
                       --num_train_epochs  3  \
                       --eval_step 100 \
                       --max_seq_length 128 \
                       --train_batch_size 32 

参考

 

posted @ 2022-04-23 21:44  zhaop  阅读(1261)  评论(0编辑  收藏  举报