知识蒸馏原理及BERT蒸馏实战
------------恢复内容开始------------
------------恢复内容开始------------
前言
本文主要介绍知识蒸馏原理,并以BERT为例,介绍两篇BERT蒸馏论文及代码,第一篇论文是在下游任务中使用BiLSTM对BERT蒸馏,第二篇是对Transformer蒸馏,即TinyBert。
知识蒸馏
https://arxiv.org/pdf/1503.02531.pdf
由于大模型参数量巨大,线上部署不仅对机器资源要求比较高而且推理速度慢,因此需要对模型进行压缩加速,知识蒸馏便是模型压缩的一种形式。
知识蒸馏(Knowledge Distillation)基于“教师-学生网络”思想,将已经训练好的大模型(教师)中的知识迁移到小模型(学生)训练中。
知识蒸馏分为两步:
-
在数据集上训练大模型(教师)
-
在高温T下,对大模型进行蒸馏,将大模型学习到的知识迁移到小模型(学生)上
一般分类任务处理逻辑是通过softmax层将输出层输出的logits转化成概率分类,然后计算预估概率与真实标签的交叉熵作为损失进行梯度更新。
知识蒸馏是希望让小模型能够学到大模型的输出,为什么是输出呢?
因为真实标签是one-hot形式表示,计算预估概率与真实标签的交叉熵时无法学习到其他类目的知识,通过让小模型拟合大模型的输出,比单纯拟合真实标签能学到更多的知识。
大模型的输出有两种,分别是logits和经softmax层后概率,下面将分别介绍蒸馏中这两种输出的拟合方式。
拟合softmax
softmax层后得到的是各类目的概率分布,由于使用指数函数会放大logits,使类目的概率差异变大,知识蒸馏时使用温度(T)对logits放缩,从而使softmax后的概率分布不要有太大的差异,即能学到原始类目间关系。
为学生网络(Net-S)在相同高温(T)下经softmax后产出的概率分布与教师网络(Net-T)输出(soft target)的交叉熵,即 ,其中 是教师网络输出, , q_{i} 是学生网络输出, 。
学生网络(Net-S)在经softmax后产出的概率分布(不用高温)与真实标签(hard target)的交叉熵组合,即 ,其中 是真实标签。
拟合logits
与拟合softmax层相比,这种方式较简单,最小化的目标函数是教师网络和学生网络输出logits的平方差,即
关于温度的理解
温度影响softmax层的输出,当T比较大时, 每个类的输出概率会比较接近,这样能学习到能过其他类目的信息。
温度高低代表对负标签的关注程度,温度越高,负标签的值相对较大,学生网络能学习到更多负标签信息。
Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
背景
BERT模型在下游任务fine-tuning后,由于参数量巨大,计算比较耗时,很难真正上线使用,该论文提出使用简单神经网络(单层BiLSTM)对fine-tuned BERT进行蒸馏,蒸馏后的BiLSTM模型与ELMo效果相同,但是参数量减少100倍且推理时间减少15倍。
模型结构
以在训练集上fine-tune后的BERT模型作为teacher网络,BiLSTM作为student网络进行蒸馏训练,整体训练过程如下:
-
先用fine-tuning后的bert对训练数据进行预估,得到bert输出概率
-
然后使用BiLSTM网络对训练数据进行建模,得到BiLSTM输出概率
-
最后计算hard loss(BiLSTM输出概率分布与真实标签的交叉熵)和soft loss(BiLSTM与Bert输出logits的均方误差),加权作为损失
使用BiLSTM进行分类的结构如下,使用BiLSTM(b)对序列(a)进行学习,将前向(c)和后向(d)最后隐层向量拼接后连接带有relu激活函数的全连接层(efg)得到logit输出(h),再经softmax(i)得到概率分布(j)。
损失函数如下:
模型效果
蒸馏后的BiLSTM在GLUE语料上的效果均优于普通的BiLSTM,在SST-2和QQP任务上效果与ELMo类似。
https://github.com/xiaopp123/knowledge_distillation
TinyBERT: Distilling BERT for Natural Language Understanding
背景
为提高bert的推理和计算性能,论文提出使用Transformer蒸馏方式将Bert蒸馏至TinyBert,另外,论文还提出两阶段的学习框架,即预训练阶段和fine-tuning阶段都对Bert蒸馏。蒸馏后的TinyBert在GLUE任务集上能达到原始Bert的96.8%,模型大小比原来减少到7.5倍,推理性能提高到9.4倍。
模型结构
Transformer包含两部分:MHA(多头注意力层)和FFN(前馈神经网络)。如图所示,Transformer蒸馏方式正是基于MHA和FFN隐藏状态进行蒸馏的。
论文计算教师网络和学生网络输出logits的交叉熵作为输出层蒸馏损失函数
综上,模型的蒸馏函数为:
TinyBert学习过程分为两步:General Distillation和Task-specific Distillation。
Generation Distillation是指预训练阶段蒸馏,这部分使用的是通用数据集故称为General Distillation。
预训练阶段训练的TinyBert由于参数较少,与原始Bert相比在下游任务中的效果必然有损,因此论文提出针对下游任务的Task-specific Distillation,该过程以原始Bert作为教师模型,TinyBert作为学生模型在特定数据集上进行蒸馏学习。
实现代码
在下文fine-tuning任务,分两步进行训练,第一步是蒸馏Transformer,第二步是蒸馏下游任务输出层
Transormer蒸馏
# Transformer蒸馏 # 教师网络层数大于学生网络 teacher_layer_num = len(teacher_atts) student_layer_num = len(student_atts) assert teacher_layer_num % student_layer_num == 0 layers_per_block = int(teacher_layer_num / student_layer_num) # attention层蒸馏 # 学生网络第i层学习教师网络第i * layers_per_block + layers_per_block - 1层 # 若学生网络是3,教师网络为12,则第0层学习第3层,第1层学第7层 new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1] for i in range(student_layer_num)] for student_att, teacher_att in zip(student_atts, new_teacher_atts): student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device), student_att) teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device), teacher_att) # attention层蒸馏损失为均方误差 tmp_loss = loss_mse(student_att, teacher_att) att_loss += tmp_loss # 前馈神经网络层和Embedding层蒸馏 # 学生第0层学习教师第0层,第0层是embedding层输出 # 第i层学习第layers_per_block * i层 new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)] new_student_reps = student_reps # 前馈神经网络层和Embedding层蒸馏均方误差 for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps): tmp_loss = loss_mse(student_rep, teacher_rep) rep_loss += tmp_loss loss = rep_loss + att_loss
输出层蒸馏
# 输出层蒸馏 # 分类任务是教师网络和学生网络输出logits交叉熵 if output_mode == "classification": cls_loss = soft_cross_entropy(student_logits / args.temperature, teacher_logits / args.temperature) elif output_mode == "regression": loss_mse = MSELoss() cls_loss = loss_mse(student_logits.view(-1), label_ids.view(-1)) loss = cls_loss
这里重点讲一下如何针对具体下游任务进行fine-tuning:
数据准备
这里可以是自己的数据集,也可以是GLUE任务。
预训练模型
需要下载Bert预训练模型和TinyBert预训练模型。
Bert预训练模型在HuggingFace官网“Model”模块输入bert,找到适合自己的bert预训练模型,在“Files and versions”选择自己需要模型和文件下载,目前好像只能一个一个文件下载。
Transformer蒸馏
python task_distill.py --teacher_model ${FT_BERT_BASE_DIR}$ \ --student_model ${GENERAL_TINYBERT_DIR}$ \ --data_dir ${TASK_DIR}$ \ --task_name ${TASK_NAME}$ \ --output_dir ${TMP_TINYBERT_DIR}$ \ --max_seq_length 128 \ --train_batch_size 32 \ --num_train_epochs 10 \ --aug_train \ --do_lower_case
输出层蒸馏
python task_distill.py --pred_distill \ --teacher_model ${FT_BERT_BASE_DIR}$ \ --student_model ${TMP_TINYBERT_DIR}$ \ --data_dir ${TASK_DIR}$ \ --task_name ${TASK_NAME}$ \ --output_dir ${TINYBERT_DIR}$ \ --aug_train \ --do_lower_case \ --learning_rate 3e-5 \ --num_train_epochs 3 \ --eval_step 100 \ --max_seq_length 128 \ --train_batch_size 32
参考
-
https://github.com/airaria/TextBrewer
-
https://github.com/qiangsiwei/bert_distill
-
https://towardsdatascience.com/simple-tutorial-for-distilling-bert-99883894e90a