langchain Chatchat 学习实践(四)——实现对Text2Sql的支持

这里记录一下langchain chatchat项目中的text2sql的实现思路。

1、SQLDatabaseChain链

SQLDatabaseChain是langchain框架自带的数据库自然语言交互工具,其内部通过sqlalchemy来获取数据库的表名和表结构、字段信息,然后将数据库的信息和用户的自然语言请求一起发送给大模型进行分析,让大模型返回sql语句后,执行sql,并返回执行结果。
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
result = db_chain.invoke({"query":query,"table_names_to_use":table_names})

2、SQLDatabaseSequentialChain

SQLDatabaseChain会将数据库所有的表信息发送给大模型,如果数据库的表比较多且复杂,构造的prompt会很长,这样会超过一些大模型的token长度限制,而有些问题其实不需要把所有表发送给大模型,也会造成资源浪费。

SQLDatabaseSequentialChain会先进性一次判断,根据需求,结合表名,预测会用到哪些表,然后在将相关表和需求转给SQLDatabaseChain,这样可以更加精确的实现需求,节约token。

3、表名作用强调

SQLDatabaseSequentialChain有一个比较麻烦的问题,就是其预测过程严重依赖表名,因此如果你设计的表名难于被大模型理解用途,就会导致预测失败,后续的SQLDatabaseChain也必然会导致错误的结果。
即使只考虑SQLDatabaseChain,其执行的过程也会将表名构造到prompt中,因而表名如果难于理解,会对执行效果有很大负面影响。

由于langchain框架内部封装这两个工具,不能增加prompt模板输入参数变量,除非修改langchain源码。但是我们可以通过对query进行补充,明确告知大模型某些表的实际含义:

    #如果发现大模型判断用什么表出现问题,尝试给langchain提供额外的表说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判
    #由于langchain固定了输入参数,所以只能通过query传递额外的表说明
    if table_comments:
            TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n"
            table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])            
            query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n"

4、read-only模式

text2sql固然带来了全新的人与数据库交互体验,但是也带来一定安全风险,在某些生产环境中,我们只希望提供只读权限给用户,禁止写入。
解决方法有三个:

(1)开发者自己将设置数据库只读账户,并提供给text2sql(强烈推荐)

(2)因为SQLDatabaseChain本身是用sqlalchemy实现的,因此可以通过添加拦截器,对写入操作进行拦截:

# 定义一个拦截器函数来检查SQL语句,以支持read-only,可修改下面的write_operations,以匹配你使用的数据库写操作关键字
def intercept_sql(conn, cursor, statement, parameters, context, executemany):
    # List of SQL keywords that indicate a write operation
    write_operations = ("insert", "update", "delete", "create", "drop", "alter", "truncate", "rename")
    # Check if the statement starts with any of the write operation keywords
    if any(statement.strip().lower().startswith(op) for op in write_operations):
        raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None)

event.listen(db._engine, "before_cursor_execute", intercept_sql)

(3)但是拦截器这种方法会导致异常抛出,给用户交互带来不好的体验,解决思路就是在read-only模式下,让大模型先进行预测,预测该请求是否涉及写操作,如果预测会用到写操作,那么直接返回相关提示即可,后续流程终止,这样可以带来更加友好的体验:

READ_ONLY_PROMPT_TEMPLATE="""You are a MySQL expert. The database is currently in read-only mode. 
Given an input question, determine if the related SQL can be executed in read-only mode.
If the SQL can be executed normally, return Answer:'SQL can be executed normally'.
If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'.
Use the following format:

Answer: Final answer here

Question: {query}
"""

if read_only:
        # 在read_only下,先让大模型判断只读模式是否能满足需求,避免后续执行过程报错,返回友好提示。
        READ_ONLY_PROMPT = PromptTemplate(
            input_variables=["query"],
            template=READ_ONLY_PROMPT_TEMPLATE,
        )
        read_only_chain = LLMChain(
            prompt=READ_ONLY_PROMPT,
            llm=llm,
        )
        read_only_result = read_only_chain.invoke(query)
        if "SQL cannot be executed normally" in read_only_result["text"]:
            return "当前数据库为只读状态,无法满足您的需求!"

完整示例代码:
Langchain-Chatchat/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py at dev · chatchat-space/Langchain-Chatchat (github.com)

posted @ 2024-06-14 15:34  郑某  阅读(106)  评论(0编辑  收藏  举报