LLM大模型:推理优化-知识蒸馏
1、有些模型比较大,推理时的效果还不错,但非常耗费计算资源;并且产生token的速度也很慢,大概1秒1个token(我的RAG在最后一步使用的secGPT-13B大概就是这个速度),一个问题回答完毕要耗费分钟级别的时间,用户直接抓狂,继续提升推理的速度!
大模型本质是大量的矩阵运算,想要提高效率,就要想办法提升矩阵运算的效率,大致的思路如下:
- 知识蒸馏distillation:大模型去掉“水分”,保留“精华”后得到小模型
- 模型剪枝:矩阵中某些元素毫无卵用,留着纯属“占着茅坑不拉屎”
- 模型量化:FP32、FP16用INT8、INT4替代,减少存储和计算
-
参数共享:部分层级之间共享参数,减少存储空间,提升计算效率
-
低秩分解:原理类同Lora,把大矩阵分解成low -rank 小矩阵,减少存储空间,提升计算效率
- 参数搜索:使用算法或启发式方法来确定最佳的参数配置
这么多方法,相比之下知识蒸馏是比较流行的,效果也是比较好的,这里尝试一下对secGPT-13B做做知识蒸馏(他家已经有secGPT-mini了,具体怎么的来的还不清楚);
2、不论是现在的LLM,还是传统的机器学习,最终的目的都是提升泛化性能,提高鲁棒性,让模型经过训练后,在新的数据上也能有很好的表现。同理,知识蒸馏的最终目的也是让student在新数据集上的表现接近teacher,该怎么去模仿学习teacher了?
所有的神经网络简化图如上,让student的输出逼近teacher,有这么三种方式:
- 直接让student的output接近teacher,其他的不care(只看最终的结果,不管中间过程),这就是所谓的response-based knowledge
- 为了更好地让student逼近teacher,只看结果可能还不够,还要严控过程,让hidden layer的效果也逼近teacher,这就是feather-based knowledge
- 再进一步,融合了前面两项,再加上input layer,对整个全流程(input->hidden-output)做系统性地模仿学习,捕捉样本之间的关系和teacher模型地全局结构信息,叫relation-based knowledge;
第一种response-based的方式最简单,不需要考虑student和teacher的网络结构是否相同,只看结果两个模型的输出loss,根据loss反向调整sudent的参数即可,所以完全可以使用现成的模型作为student继续fine-tune!我搜寻了一遍小模型,知名度高、对中文支持又比较好的:gpt2-chinese-cluecorpussmall(1.2亿参数)、gpt2-distil-chinese-cluecorpussmall、t5-base-chinese-cluecorpussmall(2.4亿参数),这里最终选择gpt2-distil-chinese-cluecorpussmall(从名称看,可能已经蒸馏过了,应该验证了蒸馏效果还行)作为student,和secGPT-13B组成cp做distillation!确定好模型后,接下来就是怎么实操落地啦!换句话说,怎么让teacher把所有的knowleage都准确无误、毫无保留地传授给student了?
3、大家回忆一下自己小时候上学的场景:坐在教室里,有各种教材课程,然后听老师讲课。课后自己做作业,老师批改作业,做错的题还要重新做,直到做对为止,整个流程经年累月后自己可以从老师那学到大量的knowledge,这一整个流程在LLM的knowledge distil中能不能被借鉴了?同样的训练数据,分别经过student和teacher模型做前向传播计算,然后对比双方的输出,差异作为loss,student根据这个loss调整自己的参数,直到loss变小为止,原理是不是很简单?接着的问题又来了:
- teacher和sutdent之间怎么计算loss?换句话说loss函数怎么设计?
- 既然都有训练数据了,为什么不直接用这些数据fine-tune student模型?为什么还要用teacher去训练student?
GPT模型decoder部分最后一步都是softmax,输出vocab中每个token的概率值。传统的training过程是训练语料中token作为one-hot形式的ground truth,让GPT的softmax输出和ground truth计算KL散度的差异,用这个差异做BP调整模型的参数。由于ground truth的token都是one-hot形式,也就是当前token的概率是1,其他token的概率是0,所以这种目标称为hard targets,流行的teacher模型本身就是通过这种hard target训练出来的!
问题是teacher模型decoder的输出是所有token的概率组成的向量,比如[0.1,0.6,0.05,0.15,0.04,0.06]这种,不是one-hot的hard targets,这种soft targets能被用于训练student么?为啥不直接用原始训练数据的hard targets去微调student了?
还是以小时候上学读书为例:其实各种教材资料在市场上自己都能买到,为啥每天还要辛苦跑去学校读书了?为啥不自己在家里自学了?核心原因之一:教材内容展示的知识有限,有很多隐藏的知识是教材纸面上无法展示的!比如英语的发音,教材只能标识音标,具体怎么发音还是要靠老师教授和纠正;又比如数学定理的推导:有些教材推理过程并不详细,自学的时候可能看不懂为什么会从某些条件得到某些结论,期间还是要经验丰富的老师具体细化整个推导过程!总结一下,就是各种教材里面承载的明面知识有限,还有很多隐藏知识(dark knowledge)需要老师教授的!具体到LLM的知识蒸馏和训练,举例如下:
上面这两图的数值,是2还是3了?是2还是7了?这就是hard target和soft target的本质区别!如果直接使用原始的数据训练student,用的就是hard target;如果使用teacher训练student,用的就是soft target!最核心的问题来了:为啥要用soft target训练student?soft target相比hard target,优势在哪?
仔细看上图,左边的数字确实是2,但是也有3的特征!右边数字确实是2,但也有7的特征,所以这两图也包含了其他数字的特征,所以如果直接简单粗暴地用hard target指定为2,那么3和7的特征是学不到的!teacher训练时用了大量的语料,其他语料也有3和7的特征,所以这里用hard target也没啥问题;但知识蒸馏的场景下,训练语料是有限的,如果用hard target,student是无法提取3和7的特征的,会严重影响其他类别的判断!所以使用soft target最大的好处:
soft-target 指明类别之间的相对关系,可以让student学到其他类别的特征,大大提升模型的泛化性和鲁棒性!这不就是所有机器学习的终极目标么?
4、(1)确定了使用soft target后,就要确定loss的具体表达式了。为了提升泛化性和鲁棒性,对于负类不能像one-hot编码那样“赶尽杀绝”,需要适当给予一些概率,利于student模型提取特征,具体操作方式如下:
T参数全名tempareture,用来调节概率的平滑性的。直观感受tempareture参数的作用:以logits = [-1,1,3,2,0.5]为例,不同的temperature对应不同的class probability,图示如下:
看吧,T越大,各个不同类别的概率越接近,概率分布越平滑!
(2)整个知识蒸馏的全流程:
- teacher模型对input做feed forward计算,得到的结果经过softmax(t)后得到soft labels;
- student模型同样对input做feed forward计算,然后分叉:
- 和teacher一样,得到的结果经过softmax(t)后得到soft predictions;
- 设置T=1,和原来的softmax效果一样,得到hard predictions;
- soft labels和soft predictions,用于衡量teacher和student之间的差异
- hard prediction和hard label,用于衡量student和ground truth之间的差异
问题又特么来了:为啥要计算两个loss?这两个loss之间怎么取舍?
- teacher虽然学识远远超过student,但是仍然有出错的可能,而这时候如果student在teacher的教授之外,可以同时参考到标准答案,就可以有效地降低被teacher偶尔“带偏”的可能性。
- 既然又两个loss,那就人为分别设置权重呗,重要的loss权重高点,另一个权重低点!两个loss的权重是超参数,可以自由设置;
(3)核心代码如下:
import json from datasets import Dataset from transformers import GPT2LMHeadModel, AutoModelForCausalLM, BertTokenizer import torch from transformers import Trainer, TrainingArguments # 加载数据 data_path = "/root/huggingface/data/" data_files = ["distil.json"] data = [] for file in data_files: with open(data_path + file, 'r', encoding='utf-8') as f: for line in f: data.append(json.loads(line)) # 将数据转换为 Hugging Face Dataset 格式 dataset = Dataset.from_list(data) # 加载tokenizer tokenizer = BertTokenizer.from_pretrained("/root/huggingface/gpt2-distil-chinese-cluecorpussmall") # 添加pad_token tokenizer.add_special_tokens({'pad_token': '[PAD]'}) def preprocess_function(examples): inputs = tokenizer(examples["query"], truncation=True, padding="max_length", max_length=512) outputs = tokenizer(examples["positive"], truncation=True, padding="max_length", max_length=512) return { "input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"], "labels": outputs["input_ids"], "labels_attention_mask": outputs["attention_mask"] } tokenized_dataset = dataset.map(preprocess_function, batched=True) split_datasets = tokenized_dataset.train_test_split(test_size=0.3) train_dataset = split_datasets['train'] eval_dataset = split_datasets['test'] teacher_model = AutoModelForCausalLM.from_pretrained("/root/huggingface/secgpt", trust_remote_code=True, device_map="cpu") student_model = GPT2LMHeadModel.from_pretrained("/root/huggingface/gpt2-distil-chinese-cluecorpussmall") # 确保教师模型和学生模型的词汇表大小一致 student_model.resize_token_embeddings(len(tokenizer)) teacher_model.resize_token_embeddings(len(tokenizer)) training_args = TrainingArguments( output_dir="/root/huggingface/SecGPT_distil", evaluation_strategy="epoch", learning_rate=2e-5, per_device_train_batch_size=1, num_train_epochs=10, weight_decay=0.01, logging_dir='./logs', ) class DistillationTrainer(Trainer): def __init__(self, teacher_model, *args, **kwargs): super().__init__(*args, **kwargs) self.teacher_model = teacher_model def compute_loss(self, model, inputs, return_outputs=False): outputs = model(**inputs) with torch.no_grad(): teacher_outputs = self.teacher_model(**inputs) student_logits = outputs.logits teacher_logits = teacher_outputs.logits.detach() # 确保 student_logits 和 teacher_logits 的形状一致 if student_logits.shape != teacher_logits.shape: raise ValueError( f"Student logits shape {student_logits.shape} does not match teacher logits shape {teacher_logits.shape}") loss_fct = torch.nn.KLDivLoss(reduction="batchmean") temperature = 2.0 alpha = 0.5 beta = 0.5 student_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1) teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1) distillation_loss = loss_fct(student_probs, teacher_probs) * (temperature ** 2) # 计算student loss labels = inputs["labels"] shift_logits = student_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct_student = torch.nn.CrossEntropyLoss() student_loss = loss_fct_student(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # 结合loss loss = alpha * distillation_loss + beta * student_loss return (loss, outputs) if return_outputs else loss trainer = DistillationTrainer( model=student_model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer, teacher_model=teacher_model, ) trainer.train() student_model.save_pretrained("/root/huggingface/SecGPT_distil") tokenizer.save_pretrained("/root/huggingface/SecGPT_distil")
训练样本的数据格式:
{"query": "frida是什么?", "positive": "Frida是一款基于python + javascript 的hook框架,适用于android/ios/linux/win/osx等平台。Frida的动态代码执行功能,主要是在它的核心引擎Gum中用C语言来实现的"} {"query": "怎么使用IDA?", "positive": "1、安装IDA 2、用IDA打开二进制文件,可以使用F5将汇编反编译成C语言伪代码 3、可以直接调试伪代码了解二进制代码逻辑"} {"query": "怎么脱壳?", "positive": "对于一代、二代壳,可以直接使用frida dexdump从内存把正常的dex代码dump到磁盘"}
我这里的样本少,明显不够,效果不也好,后续还要继续努力收集数据啊.......
(4) https://techdiylife.github.io/blog/topic.html?category2=t05&blogid=0031 这里有模型推理和训练需要耗费显存的现成方案:直接用accelerate评估。命令如下:
accelerate estimate-memory baichuan-inc/Baichuan-13B-Base --trust-remote-code
列举了4种不同量化单位的耗费显存大小:
参考:
1、https://blog.csdn.net/qq_52572775/article/details/138467295?spm=1001.2014.3001.5501 知识蒸馏Knowledge Distillation
2、https://www.jiqizhixin.com/articles/2024-03-18 LLM知识蒸馏最新综述
3、https://zhuanlan.zhihu.com/p/102038521 知识蒸馏经典之作
4、https://blog.csdn.net/jclian91/article/details/133896540 使用知识蒸馏提升模型推理性能
5、https://intellabs.github.io/distiller/knowledge_distillation.html Knowledge Distillation
6、常见推理优化技术方案: