langchain接入星火大模型(其他模型也可以参考)

首先来说LangChain是什么?不了解的可以点击下面的链接来查看下。

LangChain入门指南_故里_的博客-CSDN博客

然后在介绍一下星火认知大模型相关:
讯飞星火认知大模型
感兴趣的小伙伴可以了解一下,国内比较成熟的类GPT(我自己定义的,也不知道对不对)模型。

说一下大概需求,首先我是要用到功能是文章摘要,之前接入的是OpenAI的api接口(langchain中已经封装好了相关内容),其实只对模型的好用程度来说OpenAI确实要相较于市面上其他的模型都要更智能一点,哪怕是对中文来说,而且因为是自己调试,都没用到GPT-4,仅是3.5系列模型都更加优秀一些。但是对于开发者来说尤其是在公司进行开发还是有一些弊端的。首先是收费,对于企业合作不知道收费具体怎么样,但对开发者自己来说确实收费还是比较高的(而且我一直也没搞懂这个收费是个怎么个收法,虽然他们说是按照token,但是总感觉有时候收费少了有时候收费多了)。其次是网络环境,比较优秀的解法是外部亚马逊服务器部署相关服务。最后是token数量,这个是比较硬伤的东西。

于是国内的大模型就成了我较好的选择,我的需求不仅仅是简单的问答,而是需要结合prompt来使用,同时因为我的输入内容比较大需要借助langchain内部的map_reduce来对我的整个提问流程进行一个整合,所以进行了星火Spark 接入langchain,其余模型也可以参考这个模板进行函数替换,之后会尝试自己部署模型调试以及优化,希望有喜欢机器学习的小伙伴一起成长。话不多说,上代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import ssl
import websocket
import langchain
import logging
from config import SPARK_APPID, SPARK_API_KEY, SPARK_API_SECRET
from urllib.parse import urlparse
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
from typing import Optional, List, Dict, Mapping, Any
from langchain.llms.base import LLM
from langchain.cache import InMemoryCache
  
logging.basicConfig(level=logging.INFO)
# 启动llm的缓存
langchain.llm_cache = InMemoryCache()
result_list = []
  
  
def _construct_query(prompt, temperature, max_tokens):
    data = {
        "header": {
            "app_id": SPARK_APPID,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": "general",
                "random_threshold": temperature,
                "max_tokens": max_tokens,
                "auditing": "default"
            }
        },
        "payload": {
            "message": {
                "text": [
                    {"role": "user", "content": prompt}
                ]
            }
        }
    }
    return data
  
  
def _run(ws, *args):
    data = json.dumps(
        _construct_query(prompt=ws.question, temperature=ws.temperature, max_tokens=ws.max_tokens))
    # print (data)
    ws.send(data)
  
  
def on_error(ws, error):
    print("error:", error)
  
  
def on_close(ws):
    print("closed...")
  
  
def on_open(ws):
    thread.start_new_thread(_run, (ws,))
  
  
def on_message(ws, message):
    data = json.loads(message)
    code = data['header']['code']
    # print(data)
    if code != 0:
        print(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        result_list.append(content)
        if status == 2:
            ws.close()
            setattr(ws, "content", "".join(result_list))
            print(result_list)
            result_list.clear()
  
  
class Spark(LLM):
    '''
    根据源码解析在通过LLMS包装的时候主要重构两个部分的代码
    _call 模型调用主要逻辑,输入问题,输出模型相应结果
    _identifying_params 返回模型描述信息,通常返回一个字典,字典中包括模型的主要参数
    '''
  
    gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"  # spark官方模型提供api接口
    host = urlparse(gpt_url).netloc  # host目标机器解析
    path = urlparse(gpt_url).path  # 路径目标解析
    max_tokens = 1024
    temperature = 0.5
  
    # ws = websocket.WebSocketApp(url='')
  
    @property
    def _llm_type(self) -> str:
        # 模型简介
        return "Spark"
  
    def _get_url(self):
        # 获取请求路径
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
  
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"
  
        signature_sha = hmac.new(SPARK_API_SECRET.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
  
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  
        authorization_origin = f'api_key="{SPARK_API_KEY}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        url = self.gpt_url + '?' + urlencode(v)
        return url
  
    def _post(self, prompt):
        #模型请求响应
        websocket.enableTrace(False)
        wsUrl = self._get_url()
        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error,
                                    on_close=on_close, on_open=on_open)
        ws.question = prompt
        setattr(ws, "temperature", self.temperature)
        setattr(ws, "max_tokens", self.max_tokens)
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        return ws.content if hasattr(ws, "content") else ""
  
    def _call(self, prompt: str,
              stop: Optional[List[str]] = None) -> str:
        # 启动关键的函数
        content = self._post(prompt)
        # content = "这是一个测试"
        return content
  
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """
        Get the identifying parameters.
        """
        _param_dict = {
            "url": self.gpt_url
        }
        return _param_dict
  
  
if __name__ == "__main__":
    llm = Spark(temperature=0.9)
    # data =json.dumps(llm._construct_query(prompt="你好啊", temperature=llm.temperature, max_tokens=llm.max_tokens))
    # print (data)
    # print (type(data))
    result = llm("你好啊", stop=["you"])
    print(result)

  

posted @   TopJocker  阅读(2176)  评论(2编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
点击右上角即可分享
微信分享提示