langchain Chatchat 学习实践(二)——运行机制源码解析

langchain chatchat的简介就不多说了,大家可以去看github官网介绍,虽然当前版本停止了更新,下个版本还没有出来,但作为学习还是很好的。

一、关键启动过程:

1、start_main_server 入口

2、run_controller 启动fastchat controller 端口20001

3、run_openai_api启动fastchat对外提供的类似openai接口的服务,端口20000

4、run_model_worker 创建fastchat的model_worker,其中又执行了以下过程:

         4.1、create_model_worker_app,根据配置文件,创建并初始化对应的model_workder,初始化过程中,model_worker会通过self.init_heart_beat()将自己注册到fastchat controller中,以供fastchat管理调用。在创建每个model_worker之前,都会执行一次from fastchat.serve.base_model_worker import app,由于是多进程创建,因而每次都会执行base_model_worker中的代码:

worker = None
logger = None

app = FastAPI()
这样每次创建的app都是不同的,最后create_model_worker_app方法取出model_work对应的fastaip对象app,将app返回。

        4.2 、uvicorn.run(app, host=host, port=port, log_level=log_level.lower()),启动模型对应的model_workder服务,这里的app来自model_workder的app。

二、chat过程

1、app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat)
2、本地模型LLM对话
model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=callbacks,
        )
get_ChatOpenAI:
model = ChatOpenAI(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        openai_proxy=config.get("openai_proxy"),
        **kwargs
    )
在这里指定了fastchat的openai_api接口地址,这样就获得了指定接口地址的langchain ChatOpenAI对象
然后创建LLMChain:
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
后面省略
3、在线模型LLM对话
在线模型的调用并没有直接发起,还是和上面一样,通过获取ChatOpenAI对象,来和fastchat进行交互,但是fastchat是不支持自定义调用在线模型的,langchain chatchat是怎么实现的呢?
原来,对应在线模型调用,langchain chatchat还是通过类似创建本地模型一样创建model_worker,但是对model_worker进行了继承,交互部分进行了重写,如qwen在线调用:
class QwenWorker(ApiModelWorker):
而ApiModelWorker来自BaseModelWorker,BaseModelWorker就是fastchat的worker_model的基类。(本地模型实例化时用的ModelWorker本身也是继承自BaseModelWorker)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
class ApiModelWorker(BaseModelWorker):
    DEFAULT_EMBED_MODEL: str = None # None means not support embedding
 
    def __init__(
        self,
        model_names: List[str],
        controller_addr: str = None,
        worker_addr: str = None,
        context_len: int = 2048,
        no_register: bool = False,
        **kwargs,
    ):
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                        controller_addr=controller_addr,
                        worker_addr=worker_addr,
                        **kwargs)
        import fastchat.serve.base_model_worker
        import sys
        self.logger = fastchat.serve.base_model_worker.logger
        # 恢复被fastchat覆盖的标准输出
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
 
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
 
        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.version = None
 
        if not no_register and self.controller_addr:
            self.init_heart_beat()
 
 
    def count_token(self, params):
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}
 
    def generate_stream_gate(self, params: Dict):
        self.call_ct += 1
 
        try:
            prompt = params["prompt"]
            if self._is_chat(prompt):
                messages = self.prompt_to_messages(prompt)
                messages = self.validate_messages(messages)
            else: # 使用chat模仿续写功能,不支持历史消息
                messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]
 
            p = ApiChatParams(
                messages=messages,
                temperature=params.get("temperature"),
                top_p=params.get("top_p"),
                max_tokens=params.get("max_new_tokens"),
                version=self.version,
            )
            for resp in self.do_chat(p):
                yield self._jsonify(resp)
        except Exception as e:
            yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"})
 
    def generate_gate(self, params):
        try:
            for x in self.generate_stream_gate(params):
                ...
            return json.loads(x[:-1].decode())
        except Exception as e:
            return {"error_code": 500, "text": str(e)}
 
 
    # 需要用户自定义的方法
 
    def(self, params: ApiChatParams) -> Dict:
        '''
        执行Chat的方法,默认使用模块里面的chat函数。
        要求返回形式:{"error_code": int, "text": str}
        '''
        return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}
 
    # def do_completion(self, p: ApiCompletionParams) -> Dict:
    #     '''
    #     执行Completion的方法,默认使用模块里面的completion函数。
    #     要求返回形式:{"error_code": int, "text": str}
    #     '''
    #     return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}
 
    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        '''
        执行Embeddings的方法,默认使用模块里面的embed_documents函数。
        要求返回形式:{"code": int, "data": List[List[float]], "msg": str}
        '''
        return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
 
    def get_embeddings(self, params):
        # fastchat对LLM做Embeddings限制很大,似乎只能使用openai的。
        # 在前端通过OpenAIEmbeddings发起的请求直接出错,无法请求过来。
        print("get_embedding")
        print(params)
 
    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        raise NotImplementedError
 
    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        '''
        有些API对mesages有特殊格式,可以重写该函数替换默认的messages。
        之所以跟prompt_to_messages分开,是因为他们应用场景不同、参数不同
        '''
        return messages
 
 
    # help methods
    @property
    def user_role(self):
        return self.conv.roles[0]
 
    @property
    def ai_role(self):
        return self.conv.roles[1]
 
    def _jsonify(self, data: Dict) -> str:
        '''
        将chat函数返回的结果按照fastchat openai-api-server的格式返回
        '''
        return json.dumps(data, ensure_ascii=False).encode() + b"\0"
 
    def _is_chat(self, prompt: str) -> bool:
        '''
        检查prompt是否由chat messages拼接而来
        TODO: 存在误判的可能,也许从fastchat直接传入原始messages是更好的做法
        '''
        key = f"{self.conv.sep}{self.user_role}:"
        return key in prompt
 
    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        '''
        将prompt字符串拆分成messages.
        '''
        result = []
        user_role = self.user_role
        ai_role = self.ai_role
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result
 
    @classmethod
    def can_embedding(cls):
        return cls.DEFAULT_EMBED_MODEL is not None

  从代码中可以看到ApiModelWorker重写了generate_stream_gate,并且调用了do_chat方法,该方法要求子类去实现实际的chat过程。我们再回到class QwenWorker(ApiModelWorker):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def do_chat(self, params: ApiChatParams) -> Dict:
        import dashscope
        params.load_config(self.model_names[0])
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:params: {params}')
 
        gen = dashscope.Generation()
        responses = gen.call(
            model=params.version,
            temperature=params.temperature,
            api_key=params.api_key,
            messages=params.messages,
            result_format='message'# set the result is message format.
            stream=True,
        )
 
        for resp in responses:
            if resp["status_code"] == 200:
                if choices := resp["output"]["choices"]:
                    yield {
                        "error_code": 0,
                        "text": choices[0]["message"]["content"],
                    }
            else:
                data = {
                    "error_code": resp["status_code"],
                    "text": resp["message"],
                    "error": {
                        "message": resp["message"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"请求千问 API 时发生错误:{data}")
                yield data

  至此,qwen在线模型完成了调用。

三、总结

不得不说,这种设计还是很精妙的,借助fastchat,不仅实现了fastchat支持的几个本地大模型的调用,对于在线模型,即使不同的在线模型有不同的api接口定义,但只需要去定义实现一个新的继承ApiModelWorker的类,就可以屏蔽掉接口之间的差异,通过fastchat对齐接口,统一对外提供类openai api接口服务,这样在langchain不做修改的情况下,langchain就可以正常调用市面上各类接口迥异的在线大模型。

三、后续计划

1、Agent应用实践

posted @   郑某  阅读(3179)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
点击右上角即可分享
微信分享提示