CustomLLM in LlamaIndex (Loading the Model as an Online Service)
1. Wrapping the Model as a REST API with Flask
The idea is to expose the complete() and stream_complete() methods as REST endpoints, as shown below:
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM

app = Flask(__name__)

class QwenModel:
    def __init__(self, pretrained_model_name_or_path):
        # self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, device_map="cpu", trust_remote_code=True)  # load the model on CPU
        # self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, device_map="cpu", trust_remote_code=True)  # load the model on CPU
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, device_map="cuda", trust_remote_code=True)  # load the model on GPU
        self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, device_map="cuda", trust_remote_code=True)  # load the model on GPU
        self.model = self.model.float()

    def generate_completion(self, prompt):
        # inputs = self.tokenizer.encode(prompt, return_tensors="pt")  # CPU: keep input ids on the CPU
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").cuda()  # GPU: move input ids onto the GPU
        outputs = self.model.generate(inputs, max_length=128)
        response = self.tokenizer.decode(outputs[0])
        return response

pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\20230925_Qwen\Qwen-1_8B'
qwen_model = QwenModel(pretrained_model_name_or_path)

@app.route('/complete', methods=['POST'])
def complete():
    data = request.get_json()
    prompt = data.get('prompt', '')
    result = qwen_model.generate_completion(prompt)
    return jsonify({'text': result})

@app.route('/stream_complete', methods=['POST'])
def stream_complete():
    data = request.get_json()
    prompt = data.get('prompt', '')
    # Split the full completion into a list of characters to simulate
    # a token stream for the client-side stream_complete().
    result = list(qwen_model.generate_completion(prompt))
    return jsonify(result)

if __name__ == "__main__":
    app.run(debug=False, port=5050, host='0.0.0.0')
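With the service running, both endpoints can be smoke-tested from any HTTP client before wiring up LlamaIndex. A minimal sketch using requests (the prompt string is arbitrary):

import requests

base_url = 'http://127.0.0.1:5050'

# /complete returns a JSON object of the form {'text': ...}
resp = requests.post(f'{base_url}/complete', json={'prompt': '您好'})
print(resp.json()['text'])

# /stream_complete returns a JSON list of single characters
resp = requests.post(f'{base_url}/stream_complete', json={'prompt': '您好'})
print(''.join(resp.json()))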
2. Calling the API via requests.post
On the client side, the complete() and stream_complete() methods are implemented as requests.post calls against the endpoints above, as shown below:
from typing import Any

import requests
from llama_index import ServiceContext, SimpleDirectoryReader, SummaryIndex
from llama_index.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata
from llama_index.llms.base import llm_completion_callback

class QwenCustomLLM(CustomLLM):
    context_window: int = 8192
    num_output: int = 128
    model_name: str = "Qwen-1_8B"
    base_url: str = "http://127.0.0.1:5050"
    tokenizer: object = None
    model: object = None

    def __init__(self):
        super().__init__()

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        data = {'prompt': prompt}
        response = requests.post(f'{self.base_url}/complete', json=data)
        result = response.json()
        return CompletionResponse(text=result['text'])

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        data = {'prompt': prompt}
        response = requests.post(f'{self.base_url}/stream_complete', json=data)
        result = response.json()
        text = ""
        for token in result:
            # text carries the accumulated response so far; delta is the new token
            text += token
            yield CompletionResponse(text=text, delta=token)

if __name__ == "__main__":
    llm = QwenCustomLLM()
    # Option 1: load the embedding model locally
    service_context = ServiceContext.from_defaults(llm=llm, embed_model="local:L:/20230713_HuggingFaceModel/BAAI_bge-large-zh")
    # TODO Option 2: switch embed_model to an online (REST-served) model as well.
    # Idea: subclass BaseEmbedding and wrap the embedding model behind a REST API;
    # the implementation of the OpenAIEmbedding class is a good reference.
    # service_context = ServiceContext.from_defaults(llm=llm, embed_model=BgeLargeZhEmbedding())

    documents = SimpleDirectoryReader("./data").load_data()
    index = SummaryIndex.from_documents(documents, service_context=service_context)
    query_engine = index.as_query_engine()

    # llm.complete and llm.stream_complete can be used as before
    response_complete = llm.complete("您好")
    print(response_complete)

    response_stream_complete = list(llm.stream_complete("您好"))
    print(response_stream_complete)

    response = query_engine.query("花未眠")
    print(response)
The code above still loads the Embedding model from local disk; that part can also be served through a REST API. The basic idea: subclass BaseEmbedding and wrap the embedding model behind a REST endpoint; the implementation of the OpenAIEmbedding class is a good reference.
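For illustration, a minimal sketch of that idea, reusing the BgeLargeZhEmbedding name from the commented-out line above. It assumes the Flask service additionally exposes an /embedding endpoint returning {'embedding': [...]}; that endpoint is not implemented in this article, so the request/response shape here is an assumption, not a verified implementation:

from typing import List

import requests
from llama_index.embeddings.base import BaseEmbedding

class BgeLargeZhEmbedding(BaseEmbedding):
    base_url: str = "http://127.0.0.1:5050"

    def _get_text_embedding(self, text: str) -> List[float]:
        # Assumed endpoint: POST /embedding with {'text': ...},
        # returning {'embedding': [...]} — not part of the server above.
        response = requests.post(f'{self.base_url}/embedding', json={'text': text})
        return response.json()['embedding']

    def _get_query_embedding(self, query: str) -> List[float]:
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)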
References
[1] https://docs.llamaindex.ai/en/stable/
[2] https://github.com/run-llama/llama_index
[3] https://github.com/run-llama/llama_index/blob/main/llama_index/embeddings/__init__.py
[4] QWenCustomLLMOnline (source code for this article): https://github.com/ai408/nlp-engineering/tree/main/知识工程-大语言模型/LlamaIndex 实战/自定义 LLM/QWenCustomLLMOnline