从零学习大模型——使用GLM-4-9B-Chat + BGE-M3 + langchain + chroma建立的本地RAG应用(二)——将GLM-4-9B-Chat接入langchain

第一篇介绍了如何配置最基本的环境并下载了GLM-4-9B-Chat到本地,接下来我们试着将GLM-4-9B-Chat接入LangChain。

LangChain 是一个基于大型语言模型(LLM)开发应用程序的框架。

LangChain 简化了LLM应用程序生命周期的每个阶段:

LangChain提供了很多LLM的封装,内置了 OpenAI、LLAMA 等大模型的调用接口。具体方法可自行查阅,本教程中使用本地模型接入LangChain。

为了接入本地LLM,我们需要继承Langchain.llms.base.LLM 中的一个子类,重写其中的几个关键函数。

还是在上一篇所使用的 /root/autodl-tmp 目录,新建glm4LLM.py文件:

from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

class ChatGLM4_LLM(LLM):
    # 基于本地 ChatGLM4 自定义 LLM 类
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None
    gen_kwargs: dict = None
        
    def __init__(self, model_name_or_path: str, gen_kwargs: dict = None):
        super().__init__()
        print("正在从本地加载模型...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        ).eval()
        print("完成本地模型的加载")
        
        if gen_kwargs is None:
            gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
        self.gen_kwargs = gen_kwargs
        
    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        messages = [{"role": "user", "content": prompt}]
        model_inputs = self.tokenizer.apply_chat_template(
            messages, tokenize=True, return_tensors="pt", return_dict=True, add_generation_prompt=True
        )
        
        # 将input_ids移动到与模型相同的设备
        device = next(self.model.parameters()).device
        model_inputs = {key: value.to(device) for key, value in model_inputs.items()}
        
        generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs['input_ids'], generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
    
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """返回用于识别LLM的字典,这对于缓存和跟踪目的至关重要。"""
        return {
            "model_name": "glm-4-9b-chat",
            "max_length": self.gen_kwargs.get("max_length"),
            "do_sample": self.gen_kwargs.get("do_sample"),
            "top_k": self.gen_kwargs.get("top_k"),
        }

    @property
    def _llm_type(self) -> str:
        return "glm-4-9b-chat"

然后就可以进行简单的测试了,新建一个python文件testLLM.py

from glm4LLM import ChatGLM4_LLM
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
llm = ChatGLM4_LLM(model_name_or_path="/root/autodl-tmp/ZhipuAI/glm-4-9b-chat", gen_kwargs=gen_kwargs)
print(llm.invoke("你是谁"))

运行该文件,如果输出了回答代表已成功将llm接入LangChain

posted @ 2024-07-10 14:20  YTARO  阅读(274)  评论(0编辑  收藏  举报