Fork me on GitHub

如何蒸馏 Deepseek-R1

如何蒸馏 Deepseek-R1

深度学习模型已经彻底改变了人工智能领域,但其庞大的规模和计算需求可能成为现实世界应用的瓶颈。模型蒸馏是一种强大的技术,通过将知识从大型复杂模型(教师)转移到较小、更高效的模型(学生)来解决这一挑战。

在这篇博客中,这里将介绍如何使用 LoRA (Low-Rank Adaptation)等专门技术,将 DeepSeek-R1 的推理能力蒸馏成一个更小的模型,比如微软的 Phi-3-Mini。

什么是蒸馏?

蒸馏是一种机器学习技术,其中一个较小的模型(“学生”)被训练来模仿一个较大的预训练模型(“老师”)的行为。其目标是在显著降低计算成本和内存占用的同时,保留大部分教师的表现。

这个想法最早是在Geoffrey Hinton 关于知识蒸馏的开创性论文中提出的。它不是直接在原始数据上训练学生模型,而是从教师模型的输出或中间表示中学习。这实际上是受到了人类教育的启发。

为什么重要:

  • 成本效率:较小的模型需要更少的计算资源。
  • 速度:非常适合对延迟敏感的应用(例如 api、边缘设备)。
  • 专业化:在不重新训练巨头的情况下为特定领域量身定制模型。

蒸馏类型

有几种方法可以模拟蒸馏,每种方法都有自己的优点:

  1. 数据蒸馏
  • 在数据蒸馏中,教师模型生成合成数据或伪标签,然后用于训练学生模型。
  • 这种方法可以应用于广泛的任务,甚至是那些逻辑信息较少的任务(例如,开放式推理任务)。
  • Logits蒸馏
  • Logits 是应用 softmax 函数之前神经网络的原始输出分数。
  • 在logits 蒸馏中,学生模型被训练成匹配老师的logits,而不仅仅是最终的预测。
  • 这种方法保留了更多关于教师信心水平和决策过程的信息。
  • 特征蒸馏
  • 特征蒸馏涉及到将知识从教师模型的中间层传递给学生。
  • 通过对齐两个模型的隐表征,学生可以学习到更丰富、更抽象的特征。

Deepseek 的蒸馏模型

DeepSeek AI 发布了六个基于流行架构的蒸馏模型,如 Qwen

(Qwen,2024b)和 Llama (AI@Meta,2024他们直接使用 DeepSeek-R1

收集的 80 万样本对开源模型进行微调。

尽管比 DeepSeek-R1 小得多,但经过蒸馏过的模型在各种基准测试中表现出了令人印象深刻的性能,通常与更大的模型相匹配甚至超越。如下图所示

Deepseek 蒸馏模型基准(https://arxiv.org/html/2501.12948v1)

蒸馏自己的模型

  1. 特定任务优化
    预蒸馏模型在广泛的数据集上进行训练,以便在广泛的任务中表现良好。然而,现实世界的应用往往需要专业化
    示例场景:
    你正在构建一个财务预测聊天机器人。
    在这种情况下,使用 DeepSeek-R1 来生成金融数据集的推理痕迹(例如,股票价格预测,风险分析),并将这些知识蒸馏到一个已经知道金融细微差别的较小模型中(例如:financial - llm)。
  2. 规模成本效益
    虽然预蒸馏模型是有效的,但对于你的特定工作量来说,它们可能仍然是多余的。蒸馏自己的模型允许针对自己的确切资源约束进行优化
  3. 基准性能≠真实世界性能
    预蒸馏模型在基准测试上表现出色,但基准测试往往不能代表现实世界的任务。所以你经常需要一个模型,它在现实场景中的表现比任何预蒸馏模型都要好。
  4. 迭代改进

预蒸馏模型是静态的——它们不会随着时间的推移而改进。通过蒸馏自己的模型,你可以随着新数据的出现而不断完善它

蒸馏DeepSeek-R1 知识到自定义小模型

步骤 1:安装库

pip install -q torch transformers datasets accelerate bitsandbytes flash-attn --no-build-isolation

步骤 2:生成和格式化数据集

可以通过在你的环境中使用ollama 或任何其他部署框架部署deepseek-r1 来生成自定义的领域相关数据集。但是,对于本教程,这里将使用Magpie-Reasoning-V2数据集,其中包含由 DeepSeek-R1 生成的 250K CoT 推理样本。这些样本涵盖了数学推理、编码和一般问题解决等不同的任务。

数据集结构

每个样本包括:

  • 指令:任务描述(例如,“解决这道数学题”)。
  • 回应:DeepSeek-R1 的逐步推理(CoT)。例子:
{
 "instruction": "Solve for x: 2x + 5 = 15",
 "response": "<think>First, subtract 5 from both sides: 2x = 10. Then, divide by 2: x = 5.</think>"
}
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B", token="YOUR_HF_TOKEN")
dataset = dataset["train"]

# Format the dataset
def format_instruction(example):
 return {
 "text": (
 "<|user|>\n"
 f"{example['instruction']}\n"
 "<|end|>\n"
 "<|assistant|>\n"
 f"{example['response']}\n"
 "<|end|>"
        )
    }

formatted_dataset = dataset.map(format_instruction, batched=False, remove_columns=subset_dataset.column_names)
formatted_dataset = formatted_dataset.train_test_split(test_size=0.1)  # 90-10 train-test split

将数据集构造成 Phi-3 的聊天模板格式:

<|user|>:用户询问的开始。

<|assistant|>:模型响应的开始。

<|end|>:一轮结束。

每个 LLM 使用特定的指令跟随任务格式。将数据集与这种结构对齐可以确保模型学习到正确的会话模式。所以一定要根据你想要蒸馏的模型来格式化数据。

步骤 3:加载 Model 和 Tokenizer

为了增强模型的推理能力,这里向tokenizer添加特殊tokens <think>和</think>。

<think>:推理的开始。

</think>:推理结束。

这些tokens帮助模型学习生成结构化的、逐步的解决方案。

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "microsoft/phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Add custom tokens
CUSTOM_TOKENS = ["<think>", "</think>"]
tokenizer.add_special_tokens({"additional_special_tokens": CUSTOM_TOKENS})
tokenizer.pad_token = tokenizer.eos_token

# Load model with flash attention
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"
)
model.resize_token_embeddings(len(tokenizer))  # Resize for custom tokens

步骤 4:为高效微调配置 LoRA

LoRA 通过冻结基本模型和只训练小的适配器层来减少内存使用。

from peft import LoraConfig

peft_config = LoraConfig(
    r=8,  # Rank of the low-rank matrices
    lora_alpha=16,  # Scaling factor
    lora_dropout=0.2,  # Dropout rate
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Target attention layers
    bias="none",  # No bias terms
    task_type="CAUSAL_LM" # Task type
)

第 5 步:设置训练参数

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./phi-3-deepseek-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    learning_rate=2e-5,
    fp16=True,
    optim="paged_adamw_32bit",
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine"
)

第 6 步:训练模型

SFTTrainer 简化了指令遵循模型的监督微调。data_collator 对示例进行批处理,

peft_config 支持基于lora 的训练。

from trl import SFTTrainer
from transformers import DataCollatorForLanguageModeling

# Data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset["train"],
    eval_dataset=formatted_dataset["test"],
    data_collator=data_collator,
    peft_config=peft_config
)

# Start training
trainer.train()
trainer.save_model("./phi-3-deepseek-finetuned")
tokenizer.save_pretrained("./phi-3-deepseek-finetuned")

第 7 步:合并并保存最终模型

训练后,LoRA 适配器必须与base模型合并进行推理。这一步确保了模型可以在没有 PEFT 的情况下独立使用。

final_model = trainer.model.merge_and_unload()
final_model.save_pretrained("./phi-3-deepseek-finetuned-final")
tokenizer.save_pretrained("./phi-3-deepseek-finetuned-final")

第 8 步:推理

from transformers import pipeline

# Load fine-tuned model
model = AutoModelForCausalLM.from_pretrained(
 "./phi-3-deepseek-finetuned-final",
    device_map="auto",
    torch_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained("./phi-3-deepseek-finetuned-final")
model.resize_token_embeddings(len(tokenizer))

# Create chat pipeline
chat_pipeline = pipeline(
 "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto"
)

# Generate response
prompt = """<|user|>
What's the probability of rolling a 7 with two dice?
<|end|>
<|assistant|>
"""

output = chat_pipeline(
    prompt,
    max_new_tokens=5000,
    temperature=0.7,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id
)

print(output[0]['generated_text'])

下图可以看到 phi 模型在蒸馏前后的响应。

question: what’s the probability of rolling a 7 with two dice?

问题:用两个骰子摇到 7 的概率是多少?

蒸馏前的推理:回答直白简洁。直接提供了计算答案的步骤。

蒸馏后的推理:蒸馏后的回答引入了一种更详细和结构化的方法,包括一个明确的“思考”部分,概述了思维过程和推理,这将对复杂问题产生准确的回答非常有帮助。

 

posted @   石头木  阅读(4609)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示