LLM的推理部署:vLLM
vLLM是一个快速且易于使用的LLM推理和服务库
vLLM的快速性体现在:
- 最先进的服务吞吐量
- 通过PagedAttention有效管理注意力机制KV的内存
- 连续的批次处理请求
- 使用CUDA/HIP图快速执行模型
- 量化:GPTQ、AWQ、INT4、INT8、FP8
- CUDA内核优化,包括FlashAttention和FlashInfer的集成
- 推测行解码
- 分块预填充
vLLM的灵活、易使用体现在:
- 与HuggingFace模型无缝集成
- 高吞吐量服务与各种解码算法,包括并行采样、波束搜索等
- 用于分布式推理的张量并行性和管道并行性支持
- 流式输出
- OpenAI兼容的API服务器
- 支持NVIDIA GPU、AMD CPU和GPU、Intel CPU和GPU、PowerPC CPU、TPU和AWS Neuron
- 前缀缓存支持
- 多lora支持
vLLM的安装
安装需求
- Linux操作环境,目前vLLM仅支持Linux
- Python:3.8 - 3.12
- GPU:算力7.0以上的设备(V100、T4、A100、L4、H100等)
使用pip安装
conda create -n my_env python=3.10 conda activate my_env pip install vLLM
vLLM推理
使用vLLM库实现LLM的简单推理部署,代码如下
从vLLM类中导入LLM和SamplingParams,LLM类是使用vLLM引擎进行离线推理的主要类,SamplingParams类用于指定采样过程的参数
from vllm import LLM, SamplingParams llm = LLM(model=model_dir) # 模型名称或存储路径
# 定义输入提示列表和生成的采样参数,采样温度设置为0.8,核采样概率设置为0.95, 最大输出长度设置为128
prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_token=128) outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
vLLM + outlines结构化生成
https://github.com/outlines-dev/outlines
LLM在一些场景下要求生成的文本以指定格式输出,可以使用outlines库实现基于regex、json、grammar等实现结构化生成
pip install outlines
使用正则化表达式约束生成格式
from outlines import models, generate from vllm import LLM, SampleParams model_dir = r"" llm = LLM(model=model_dir, tokenizer=model_dir) tokenizer = llm.get_tokenizer() sampling_params = SamplingParams(temperature=1.0, top_p=0.9, max_tokens=64) text = "question1: ..." system_prompt = "从三个角度分析如下问题" user_prompt = f"从三个角度分析如下问题{text},输出格式为:角度一: xxx; 角度二:xxx; 角度三:xxx" messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] llm = models.VLLM(llm) prompt = tokenizer.apply_chat_template( messages, tokenizer=False, add_generation_prompt=True ) regex_str = r"角度一:(.*?)'\n'角度二:(.*?)'\n'角度三:(.*?)'\n'" generator = generate.regex(llm, regex_str) output = generator(prompt, sampling_params=sampling_params) print(output)