from langchain_community.llms import Ollama
from langchain.chains.router import MultiPromptChain
from langchain.chains import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
# physics_template = """You are a very smart physics professor. \
# You are great at answering questions about physics in a concise and easy to understand manner. \
# When you don't know the answer to a question you admit that you don't know.
#
# Here is a question:
# {input}"""
#
# math_template = """You are a very good mathematician. You are great at answering math questions. \
# You are so good because you are able to break down hard problems into their component parts, \
# answer the component parts, and then put them together to answer the broader question.
#
# Here is a question:
# {input}"""
# prompt_infos = [
# {
# "name": "physics",
# "description": "Good for answering questions about physics",
# "prompt_template": physics_template,
# },
# {
# "name": "math",
# "description": "Good for answering math questions",
# "prompt_template": math_template,
# },
# ]
# Prompt templates used as routing destinations by the MultiPromptChain below.
# NOTE: the template bodies are runtime data sent verbatim to the LLM, so the
# original (Chinese) wording is preserved. Each template exposes a single
# {input} variable filled with the user's question.
# Destination for cigarette-related questions (instructs the model to answer
# in English when the question is about cigarettes).
cigerate_template = """
如果问题是关于香烟的,请用英文回答问题
下面是需要你来回答的问题:
{input}
"""
# Destination for general everyday-conversation questions (instructs the
# model to answer in fluent Chinese).
conversion_template = """
你是一位聊天大师,擅长解答日常生活中的问题,把答复翻译成语句通顺的中文。
下面是需要你来回答的问题:
{input}
"""
# Routing table for MultiPromptChain: one entry per destination prompt.
# "name" is the label the router LLM must emit, "description" is what the
# router LLM reads to choose a destination, and "prompt_template" is the
# prompt applied once the question has been routed there.
prompt_infos = [
    {"name": name, "description": description, "prompt_template": template}
    for name, description, template in (
        ("cig", "适合回答所有香烟的问题", cigerate_template),
        ("talk", "适合回答日常问题", conversion_template),
        # A "History" destination existed previously; add a triple here to
        # restore it (its history_template would need to be defined above).
    )
]
MULTI_PROMPT_ROUTER_TEMPLATE = """
将原始文本输入到语言模型中,选择最合适输入的提示词。
你将获得最适合的提示词名称以及相应的描述。
<< FORMATTING >>
返回一个Markdown代码片段,其中包含格式化为JSON对象的内容:
```json
{{{{
"destination": string \\ 要使用的prompt的name,或者"DEFAULT"
"next_inputs": string \\ 原始的input,或者可能的原始输入的修改版本
}}}}
```
记住:"destination"必须是下面指定的候选提示名之一,如果包含香烟二字"destination"是"cig",或者如果输入不适合任何候选提示,则可以是"DEFAULT"
记住:"next_inputs"如果你认为不需要进行任何修改,可以直接使用原始输入。
<< CANDIDATE PROMPTS >>
{destinations}
<<INPUT>>
{{input}}
<<OUTPUT(必须以```json开头作为回复)>>
<< OUTPUT (must end with ```) >>
"""
# Initialize the language model: a local Ollama server is assumed to be
# listening on 127.0.0.1:11434 with the phi3:3.8b model pulled — TODO confirm.
llm = Ollama(base_url='http://127.0.0.1:11434', model='phi3:3.8b')
# The commented-out code below is the manual wiring that
# MultiPromptChain.from_prompts performs internally: one LLMChain per
# prompt_info entry plus a ConversationChain fallback for "DEFAULT".
# destination_chains = {}
# for p_info in prompt_infos:
#     name = p_info["name"]
#     prompt_template = p_info["prompt_template"]
#     prompt = PromptTemplate(template=prompt_template, input_variables=["input"])
#     chain = LLMChain(llm=llm, prompt=prompt)
#     # chain = llm | prompt
#     destination_chains[name] = chain
# default_chain = ConversationChain(llm=llm, output_key="text")
# from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
# from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
# destinations = [f"{p['name']}: {p['description']}" for p in prompt_infos]
# destinations_str = "\n".join(destinations)
# Build the full routing chain in one call. NOTE(review): from_prompts uses
# langchain's built-in router template, so the custom
# MULTI_PROMPT_ROUTER_TEMPLATE defined above is NOT applied here.
test_chain = MultiPromptChain.from_prompts(llm,prompt_infos=prompt_infos,verbose=True)
# router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(destinations=test_chain.destinations_str)
# router_prompt = PromptTemplate(
# template=router_template,
# input_variables=["input"],
# output_parser=RouterOutputParser(),
# )
# router_chain = LLMRouterChain.from_llm(llm, router_prompt)
# chain = MultiPromptChain(
# router_chain=router_chain,
# destination_chains=destination_chains,
# default_chain=default_chain,
# verbose=True,
# )
# Send one question through the router chain and print the generated answer.
# Both the destination LLMChains and the default chain built by from_prompts
# store their result under the "text" key; .get() avoids a KeyError if the
# key were ever absent.
response = test_chain.invoke({'input': "中国美女最多的省份是哪个省?"})
print(response.get('text'))