langchain Chatchat 学习实践(二)——运行机制源码解析

langchain chatchat的简介就不多说了,大家可以去看github官网介绍,虽然当前版本停止了更新,下个版本还没有出来,但作为学习还是很好的。

一、关键启动过程:

1、start_main_server 入口

2、run_controller 启动fastchat controller 端口20001

3、run_openai_api启动fastchat对外提供的类似openai接口的服务,端口20000

4、run_model_worker 创建fastchat的model_worker,其中又执行了以下过程:

         4.1、create_model_worker_app,根据配置文件,创建并初始化对应的model_workder,初始化过程中,model_worker会通过self.init_heart_beat()将自己注册到fastchat controller中,以供fastchat管理调用。在创建每个model_worker之前,都会执行一次from fastchat.serve.base_model_worker import app,由于是多进程创建,因而每次都会执行base_model_worker中的代码:

worker = None
logger = None

app = FastAPI()
这样每次创建的app都是不同的,最后create_model_worker_app方法取出model_work对应的fastaip对象app,将app返回。

        4.2 、uvicorn.run(app, host=host, port=port, log_level=log_level.lower()),启动模型对应的model_workder服务,这里的app来自model_workder的app。

二、chat过程

1、app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat)
2、本地模型LLM对话
model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=callbacks,
        )
get_ChatOpenAI:
model = ChatOpenAI(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        openai_proxy=config.get("openai_proxy"),
        **kwargs
    )
在这里指定了fastchat的openai_api接口地址,这样就获得了指定接口地址的langchain ChatOpenAI对象
然后创建LLMChain:
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
后面省略
3、在线模型LLM对话
在线模型的调用并没有直接发起,还是和上面一样,通过获取ChatOpenAI对象,来和fastchat进行交互,但是fastchat是不支持自定义调用在线模型的,langchain chatchat是怎么实现的呢?
原来,对应在线模型调用,langchain chatchat还是通过类似创建本地模型一样创建model_worker,但是对model_worker进行了继承,交互部分进行了重写,如qwen在线调用:
class QwenWorker(ApiModelWorker):
而ApiModelWorker来自BaseModelWorker,BaseModelWorker就是fastchat的worker_model的基类。(本地模型实例化时用的ModelWorker本身也是继承自BaseModelWorker)
class ApiModelWorker(BaseModelWorker):
    DEFAULT_EMBED_MODEL: str = None # None means not support embedding

    def __init__(
        self,
        model_names: List[str],
        controller_addr: str = None,
        worker_addr: str = None,
        context_len: int = 2048,
        no_register: bool = False,
        **kwargs,
    ):
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                        controller_addr=controller_addr,
                        worker_addr=worker_addr,
                        **kwargs)
        import fastchat.serve.base_model_worker
        import sys
        self.logger = fastchat.serve.base_model_worker.logger
        # 恢复被fastchat覆盖的标准输出
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)

        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.version = None

        if not no_register and self.controller_addr:
            self.init_heart_beat()


    def count_token(self, params):
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params: Dict):
        self.call_ct += 1

        try:
            prompt = params["prompt"]
            if self._is_chat(prompt):
                messages = self.prompt_to_messages(prompt)
                messages = self.validate_messages(messages)
            else: # 使用chat模仿续写功能,不支持历史消息
                messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]

            p = ApiChatParams(
                messages=messages,
                temperature=params.get("temperature"),
                top_p=params.get("top_p"),
                max_tokens=params.get("max_new_tokens"),
                version=self.version,
            )
            for resp in self.do_chat(p):
                yield self._jsonify(resp)
        except Exception as e:
            yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"})

    def generate_gate(self, params):
        try:
            for x in self.generate_stream_gate(params):
                ...
            return json.loads(x[:-1].decode())
        except Exception as e:
            return {"error_code": 500, "text": str(e)}


    # 需要用户自定义的方法

    def(self, params: ApiChatParams) -> Dict:
        '''
        执行Chat的方法,默认使用模块里面的chat函数。
        要求返回形式:{"error_code": int, "text": str}
        '''
        return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}

    # def do_completion(self, p: ApiCompletionParams) -> Dict:
    #     '''
    #     执行Completion的方法,默认使用模块里面的completion函数。
    #     要求返回形式:{"error_code": int, "text": str}
    #     '''
    #     return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        '''
        执行Embeddings的方法,默认使用模块里面的embed_documents函数。
        要求返回形式:{"code": int, "data": List[List[float]], "msg": str}
        '''
        return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}

    def get_embeddings(self, params):
        # fastchat对LLM做Embeddings限制很大,似乎只能使用openai的。
        # 在前端通过OpenAIEmbeddings发起的请求直接出错,无法请求过来。
        print("get_embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        raise NotImplementedError

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        '''
        有些API对mesages有特殊格式,可以重写该函数替换默认的messages。
        之所以跟prompt_to_messages分开,是因为他们应用场景不同、参数不同
        '''
        return messages


    # help methods
    @property
    def user_role(self):
        return self.conv.roles[0]

    @property
    def ai_role(self):
        return self.conv.roles[1]

    def _jsonify(self, data: Dict) -> str:
        '''
        将chat函数返回的结果按照fastchat openai-api-server的格式返回
        '''
        return json.dumps(data, ensure_ascii=False).encode() + b"\0"

    def _is_chat(self, prompt: str) -> bool:
        '''
        检查prompt是否由chat messages拼接而来
        TODO: 存在误判的可能,也许从fastchat直接传入原始messages是更好的做法
        '''
        key = f"{self.conv.sep}{self.user_role}:"
        return key in prompt

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        '''
        将prompt字符串拆分成messages.
        '''
        result = []
        user_role = self.user_role
        ai_role = self.ai_role
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result

    @classmethod
    def can_embedding(cls):
        return cls.DEFAULT_EMBED_MODEL is not None

  从代码中可以看到ApiModelWorker重写了generate_stream_gate,并且调用了do_chat方法,该方法要求子类去实现实际的chat过程。我们再回到class QwenWorker(ApiModelWorker):

def do_chat(self, params: ApiChatParams) -> Dict:
        import dashscope
        params.load_config(self.model_names[0])
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:params: {params}')

        gen = dashscope.Generation()
        responses = gen.call(
            model=params.version,
            temperature=params.temperature,
            api_key=params.api_key,
            messages=params.messages,
            result_format='message',  # set the result is message format.
            stream=True,
        )

        for resp in responses:
            if resp["status_code"] == 200:
                if choices := resp["output"]["choices"]:
                    yield {
                        "error_code": 0,
                        "text": choices[0]["message"]["content"],
                    }
            else:
                data = {
                    "error_code": resp["status_code"],
                    "text": resp["message"],
                    "error": {
                        "message": resp["message"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"请求千问 API 时发生错误:{data}")
                yield data

  至此,qwen在线模型完成了调用。

三、总结

不得不说,这种设计还是很精妙的,借助fastchat,不仅实现了fastchat支持的几个本地大模型的调用,对于在线模型,即使不同的在线模型有不同的api接口定义,但只需要去定义实现一个新的继承ApiModelWorker的类,就可以屏蔽掉接口之间的差异,通过fastchat对齐接口,统一对外提供类openai api接口服务,这样在langchain不做修改的情况下,langchain就可以正常调用市面上各类接口迥异的在线大模型。

三、后续计划

1、Agent应用实践

posted @ 2024-03-19 17:30  郑某  阅读(1387)  评论(0编辑  收藏  举报