LLM大模型:推理优化-知识蒸馏

    1、有些模型比较大,推理时的效果还不错,但非常耗费计算资源;并且产生token的速度也很慢,大概1秒1个token(我的RAG在最后一步使用的secGPT-13B大概就是这个速度),一个问题回答完毕要耗费分钟级别的时间,用户直接抓狂,继续提升推理的速度!

  大模型本质是大量的矩阵运算,想要提高效率,就要想办法提升矩阵运算的效率,大致的思路如下:

  • 知识蒸馏distillation:大模型去掉“水分”,保留“精华”后得到小模型
  • 模型剪枝:矩阵中某些元素毫无卵用,留着纯属“占着茅坑不拉屎”
  • 模型量化:FP32、FP16用INT8、INT4替代,减少存储和计算
  • 参数共享:部分层级之间共享参数,减少存储空间,提升计算效率
  • 低秩分解:原理类同Lora,把大矩阵分解成low -rank 小矩阵,减少存储空间,提升计算效率
  • 参数搜索:使用算法或启发式方法来确定最佳的参数配置

   这么多方法,相比之下知识蒸馏是比较流行的,效果也是比较好的,这里尝试一下对secGPT-13B做做知识蒸馏(他家已经有secGPT-mini了,具体怎么的来的还不清楚);

  2、不论是现在的LLM,还是传统的机器学习,最终的目的都是提升泛化性能,提高鲁棒性,让模型经过训练后,在新的数据上也能有很好的表现。同理,知识蒸馏的最终目的也是让student在新数据集上的表现接近teacher,该怎么去模仿学习teacher了?

        

   所有的神经网络简化图如上,让student的输出逼近teacher,有这么三种方式:

  • 直接让student的output接近teacher,其他的不care(只看最终的结果,不管中间过程),这就是所谓的response-based knowledge
  • 为了更好地让student逼近teacher,只看结果可能还不够,还要严控过程,让hidden layer的效果也逼近teacher,这就是feather-based knowledge
  • 再进一步,融合了前面两项,再加上input layer,对整个全流程(input->hidden-output)做系统性地模仿学习,捕捉样本之间的关系和teacher模型地全局结构信息,叫relation-based knowledge;

    第一种response-based的方式最简单,不需要考虑student和teacher的网络结构是否相同,只看结果两个模型的输出loss,根据loss反向调整sudent的参数即可,所以完全可以使用现成的模型作为student继续fine-tune!我搜寻了一遍小模型,知名度高、对中文支持又比较好的:gpt2-chinese-cluecorpussmall(1.2亿参数)、gpt2-distil-chinese-cluecorpussmall、t5-base-chinese-cluecorpussmall(2.4亿参数),这里最终选择gpt2-distil-chinese-cluecorpussmall(从名称看,可能已经蒸馏过了,应该验证了蒸馏效果还行)作为student,和secGPT-13B组成cp做distillation!确定好模型后,接下来就是怎么实操落地啦!换句话说,怎么让teacher把所有的knowleage都准确无误、毫无保留地传授给student了?

  3、大家回忆一下自己小时候上学的场景:坐在教室里,有各种教材课程,然后听老师讲课。课后自己做作业,老师批改作业,做错的题还要重新做,直到做对为止,整个流程经年累月后自己可以从老师那学到大量的knowledge,这一整个流程在LLM的knowledge distil中能不能被借鉴了?同样的训练数据,分别经过student和teacher模型做前向传播计算,然后对比双方的输出,差异作为loss,student根据这个loss调整自己的参数,直到loss变小为止,原理是不是很简单?接着的问题又来了:

  • teacher和sutdent之间怎么计算loss?换句话说loss函数怎么设计?
  • 既然都有训练数据了,为什么不直接用这些数据fine-tune student模型?为什么还要用teacher去训练student?

    GPT模型decoder部分最后一步都是softmax,输出vocab中每个token的概率值。传统的training过程是训练语料中token作为one-hot形式的ground truth,让GPT的softmax输出和ground truth计算KL散度的差异,用这个差异做BP调整模型的参数。由于ground truth的token都是one-hot形式,也就是当前token的概率是1,其他token的概率是0,所以这种目标称为hard targets,流行的teacher模型本身就是通过这种hard target训练出来的!

     

  问题是teacher模型decoder的输出是所有token的概率组成的向量,比如[0.1,0.6,0.05,0.15,0.04,0.06]这种,不是one-hot的hard targets,这种soft targets能被用于训练student么?为啥不直接用原始训练数据的hard targets去微调student了

  还是以小时候上学读书为例:其实各种教材资料在市场上自己都能买到,为啥每天还要辛苦跑去学校读书了?为啥不自己在家里自学了?核心原因之一:教材内容展示的知识有限,有很多隐藏的知识是教材纸面上无法展示的!比如英语的发音,教材只能标识音标,具体怎么发音还是要靠老师教授和纠正;又比如数学定理的推导:有些教材推理过程并不详细,自学的时候可能看不懂为什么会从某些条件得到某些结论,期间还是要经验丰富的老师具体细化整个推导过程!总结一下,就是各种教材里面承载的明面知识有限,还有很多隐藏知识(dark knowledge)需要老师教授的!具体到LLM的知识蒸馏和训练,举例如下:

        

   上面这两图的数值,是2还是3了?是2还是7了?这就是hard target和soft target的本质区别!如果直接使用原始的数据训练student,用的就是hard target;如果使用teacher训练student,用的就是soft target!最核心的问题来了:为啥要用soft target训练student?soft target相比hard target,优势在哪

  仔细看上图,左边的数字确实是2,但是也有3的特征!右边数字确实是2,但也有7的特征,所以这两图也包含了其他数字的特征,所以如果直接简单粗暴地用hard target指定为2,那么3和7的特征是学不到的teacher训练时用了大量的语料,其他语料也有3和7的特征,所以这里用hard target也没啥问题;但知识蒸馏的场景下,训练语料是有限的,如果用hard target,student是无法提取3和7的特征的,会严重影响其他类别的判断所以使用soft target最大的好处:

  soft-target 指明类别之间的相对关系,可以让student学到其他类别的特征,大大提升模型的泛化性和鲁棒!这不就是所有机器学习的终极目标么?

  4、(1)确定了使用soft target后,就要确定loss的具体表达式了。为了提升泛化性和鲁棒性,对于负类不能像one-hot编码那样“赶尽杀绝”,需要适当给予一些概率,利于student模型提取特征,具体操作方式如下:

        

   T参数全名tempareture,用来调节概率的平滑性的。直观感受tempareture参数的作用:以logits = [-1,1,3,2,0.5]为例,不同的temperature对应不同的class probability,图示如下:

      

   看吧,T越大,各个不同类别的概率越接近,概率分布越平滑!

  (2)整个知识蒸馏的全流程:

         

  • teacher模型对input做feed forward计算,得到的结果经过softmax(t)后得到soft labels;
  • student模型同样对input做feed forward计算,然后分叉:
    • 和teacher一样,得到的结果经过softmax(t)后得到soft predictions;
    • 设置T=1,和原来的softmax效果一样,得到hard predictions;
  • soft labels和soft predictions,用于衡量teacher和student之间的差异
  • hard prediction和hard label,用于衡量student和ground truth之间的差异

  问题又特么来了:为啥要计算两个loss?这两个loss之间怎么取舍?

  • teacher虽然学识远远超过student,但是仍然有出错的可能,而这时候如果student在teacher的教授之外,可以同时参考到标准答案,就可以有效地降低被teacher偶尔“带偏”的可能性。
  • 既然又两个loss,那就人为分别设置权重呗,重要的loss权重高点,另一个权重低点!两个loss的权重是超参数,可以自由设置;

     (3)核心代码如下:

import json
from datasets import Dataset
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, BertTokenizer
import torch
from transformers import Trainer, TrainingArguments

# 加载数据
data_path = "/root/huggingface/data/"
data_files = ["distil.json"]

data = []
for file in data_files:
    with open(data_path + file, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))

# 将数据转换为 Hugging Face Dataset 格式
dataset = Dataset.from_list(data)

# 加载tokenizer
tokenizer = BertTokenizer.from_pretrained("/root/huggingface/gpt2-distil-chinese-cluecorpussmall")

# 添加pad_token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})


def preprocess_function(examples):
    inputs = tokenizer(examples["query"], truncation=True, padding="max_length", max_length=512)
    outputs = tokenizer(examples["positive"], truncation=True, padding="max_length", max_length=512)
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": outputs["input_ids"],
        "labels_attention_mask": outputs["attention_mask"]
    }


tokenized_dataset = dataset.map(preprocess_function, batched=True)
split_datasets = tokenized_dataset.train_test_split(test_size=0.3)
train_dataset = split_datasets['train']
eval_dataset = split_datasets['test']

teacher_model = AutoModelForCausalLM.from_pretrained("/root/huggingface/secgpt", trust_remote_code=True,
                                                     device_map="cpu")
student_model = GPT2LMHeadModel.from_pretrained("/root/huggingface/gpt2-distil-chinese-cluecorpussmall")

# 确保教师模型和学生模型的词汇表大小一致
student_model.resize_token_embeddings(len(tokenizer))
teacher_model.resize_token_embeddings(len(tokenizer))

training_args = TrainingArguments(
    output_dir="/root/huggingface/SecGPT_distil",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir='./logs',
)


class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)

        student_logits = outputs.logits
        teacher_logits = teacher_outputs.logits.detach()

        # 确保 student_logits 和 teacher_logits 的形状一致
        if student_logits.shape != teacher_logits.shape:
            raise ValueError(
                f"Student logits shape {student_logits.shape} does not match teacher logits shape {teacher_logits.shape}")

        loss_fct = torch.nn.KLDivLoss(reduction="batchmean")
        temperature = 2.0
        alpha = 0.5
        beta = 0.5

        student_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)

        distillation_loss = loss_fct(student_probs, teacher_probs) * (temperature ** 2)

        # 计算student loss
        labels = inputs["labels"]
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct_student = torch.nn.CrossEntropyLoss()
        student_loss = loss_fct_student(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # 结合loss
        loss = alpha * distillation_loss + beta * student_loss

        return (loss, outputs) if return_outputs else loss


trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    teacher_model=teacher_model,
)

trainer.train()

student_model.save_pretrained("/root/huggingface/SecGPT_distil")
tokenizer.save_pretrained("/root/huggingface/SecGPT_distil")

  训练样本的数据格式:

{"query": "frida是什么?", "positive": "Frida是一款基于python + javascript 的hook框架,适用于android/ios/linux/win/osx等平台。Frida的动态代码执行功能,主要是在它的核心引擎Gum中用C语言来实现的"}
{"query": "怎么使用IDA?", "positive": "1、安装IDA   2、用IDA打开二进制文件,可以使用F5将汇编反编译成C语言伪代码   3、可以直接调试伪代码了解二进制代码逻辑"}
{"query": "怎么脱壳?", "positive": "对于一代、二代壳,可以直接使用frida dexdump从内存把正常的dex代码dump到磁盘"}

  我这里的样本少,明显不够,效果不也好,后续还要继续努力收集数据啊.......

  (4)  https://techdiylife.github.io/blog/topic.html?category2=t05&blogid=0031  这里有模型推理和训练需要耗费显存的现成方案:直接用accelerate评估。命令如下:

accelerate estimate-memory baichuan-inc/Baichuan-13B-Base --trust-remote-code

  列举了4种不同量化单位的耗费显存大小:

 

 

 

参考:

1、https://blog.csdn.net/qq_52572775/article/details/138467295?spm=1001.2014.3001.5501  知识蒸馏Knowledge Distillation

2、https://www.jiqizhixin.com/articles/2024-03-18   LLM知识蒸馏最新综述

3、https://zhuanlan.zhihu.com/p/102038521  知识蒸馏经典之作

4、https://blog.csdn.net/jclian91/article/details/133896540  使用知识蒸馏提升模型推理性能

5、https://intellabs.github.io/distiller/knowledge_distillation.html  Knowledge Distillation

6、常见推理优化技术方案:

 

posted @ 2024-07-16 17:00  第七子007  阅读(908)  评论(0编辑  收藏  举报