【LangChain】How to create a custom Memory class 如何自定义一个记忆类
How to create a custom Memory class 如何自定义一个记忆类
本文主要自定义了一个在LangChain中使用的Memory类
原文:How to create a custom Memory class
翻译
尽管在LangChain中有了一些预定义好的记忆类型,但是还是很有可能会有人想为自己的应用添加自己的记忆类型。这个笔记会介绍怎么添加。
在这个笔记中,我们会给ConversationChain
添加一个自定义的记忆类型。为了添加这个自定义记忆类,我们需要import基础的记忆类然后创建它的子类。
from langchain import OpenAI, ConversationChain
from langchain.schema import BaseMemory
from pydantic import BaseModel
from typing import List, Dict, Any
在这个例子中,我们将写一个自定义记忆类,这个类使用spacy提取实体,并将实体信息保存在一个简单的哈希表中。接着,在进行会话的过程中,我们会关注input的文本,提取所有实体,并将有关他们的所有信息放入上下文中。
- 请注意,该实现非常简单和脆弱,在生产环境中可能没有用处。其目的是展示您可以添加自定义记忆。
为此,我们需要spaCy
spaCy(简单介绍)
一个可以快速上手的nlp开发库,简单来讲,这里用到的spaCy就是先加载一个语言模型,之后把一个句子放进去跑一遍,同时完成了好几个nlp任务,包括分词、词性标注等,之后结果放在了一个类doc类中。
快速入门教程:使用spaCy做进阶自然语言处理
自定义记忆例子
# !pip install spacy
# !python -m spacy download en_core_web_lg
import spacy
nlp = spacy.load("en_core_web_lg")
class SpacyEntityMemory(BaseMemory, BaseModel):
"""为了保存实体信息的记忆类"""
# 定义用来保存实体信息的字典
entities: dict = {}
# 定义用来筛选添加到prompt中的实体信息的key
memory_key: str = "entities"
def clear(self):
self.entities = {}
@property
def memory_variables(self) -> List[str]:
"""定义加进prompt的变量"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""加载记忆变量,本例子中为entities"""
# 得到input文本并且用spacy跑一遍
doc = nlp(inputs[list(inputs.keys())[0]])
# 提取实体的已知信息,如果存在的话
entities = [
self.entities[str(ent)] for ent in doc.ents if str(ent) in self.entities
]
# 返回实体的综合信息,用来放进上下文中
return {self.memory_key: "\n".join(entities)}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""将本次对话的上下文保存到缓冲区"""
# 得到input文本并且用spacy跑一遍
text = inputs[list(inputs.keys())[0]]
doc = nlp(text)
# 对于提到的每个实体,将信息保存到字典中
for ent in doc.ents:
ent_str = str(ent)
if ent_str in self.entities:
self.entities[ent_str] += f"\n{text}"
else:
self.entities[ent_str] = text
现在我们定义一个prompt来接收实体信息和用户input
from langchain.prompts.prompt import PromptTemplate
template = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. You are provided with information about entities the Human mentions, if relevant.
Relevant entity information:
{entities}
Conversation:
Human: {input}
AI:"""
prompt = PromptTemplate(input_variables=["entities", "input"], template=template)
然后现在我们把它们放在一起!
llm = OpenAI(temperature=0)
conversation = ConversationChain(
llm=llm, prompt=prompt, verbose=True, memory=SpacyEntityMemory()
)
在第一个例子中,对于没有任何预先知识的Harrison,"Relevant entity information"字段是空的。
conversation.predict(input="Harrison likes machine learning")
> Entering new ConversationChain chain...
Prompt after formatting:
The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. You are provided with information about entities the Human mentions, if relevant.
Relevant entity information:
Conversation:
Human: Harrison likes machine learning
AI:
> Finished ConversationChain chain.
" That's great to hear! Machine learning is a fascinating field of study. It involves using algorithms to analyze data and make predictions. Have you ever studied machine learning, Harrison?"
现在是第二个例子,我们可以看到它存入了Harrison的信息。
conversation.predict(
input="What do you think Harrison's favorite subject in college was?"
)
> Entering new ConversationChain chain...
Prompt after formatting:
The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. You are provided with information about entities the Human mentions, if relevant.
Relevant entity information:
Harrison likes machine learning
Conversation:
Human: What do you think Harrison's favorite subject in college was?
AI:
> Finished ConversationChain chain.
' From what I know about Harrison, I believe his favorite subject in college was machine learning. He has expressed a strong interest in the subject and has mentioned it often.'
请再次注意,该实现非常简单和脆弱,在生产环境中可能并不实用。其目的是展示您可以添加自定义记忆。
个人测试例子
我希望得到一个自定义记忆能够将每次GPT3.5生成的JSON中的某一键值。
在本例子中的JSON格式为:
{
"current_step": 当前步骤名;
"result_info":
{
"step": 默认步骤中有几个步骤,step列表中就生成几个json
[{
"name":步骤名,默认步骤中生成的步骤名;
"ok": 是否向用户确认完毕,是/否;
}, ...];
"is_over": 是否都向用户确认完毕,是/否;
"next_step": 若"is_over"为"是",则写下一步骤名,若"is_over"为"否",则为"";
};
"reply": 会议促进者的发言,必须有内容;
}
我想得到的键值为 "reply"
,故创建新的 RecentKConversationMemory
类。
*注意此处的 RecentKConversationMemory
类泛用性十分差,仅限于在我这个任务中可以使用。
from langchain.schema import BaseMemory
from typing import List, Dict, Any
class RecentKConversationMemory(BaseMemory):
"""保存最近5轮对话中JSON文件特定键值的记忆类"""
# 定义用来保存最近对话的列表,默认为空
recent_conversations: list = []
# 定义用来筛选添加到prompt中的实体信息的key,默认为"recent_conversations"
memory_key: str = "recent_conversations"
# 定义前缀,默认为"Human"和"AI"
human_prefix: str = "Human"
ai_prefix: str = "AI"
# 定义保留轮数,默认5
k: int = 5
def clear(self):
self.recent_conversations = []
@property
def memory_variables(self) -> List[str]:
"""定义加入prompt的变量"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""加载记忆变量,这里为特定键值的内容"""
recent_json_value = ""
# 遍历最近k轮对话
for conversation in self.recent_conversations[-self.k:]:
# 检查是否存在键值为"reply"的内容
if "reply" in eval(conversation["output"]["response"]).keys():
human = f"{self.human_prefix}: " + conversation["input"]['input']
ai = f"{self.ai_prefix}: " + eval(conversation["output"]["response"])["reply"]
recent_json_value += "\n" + "\n".join([human, ai])
return {self.memory_key: recent_json_value}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
"""将本次对话的上下文保存到缓冲区"""
# 每轮对话的输入和输出作为一个会话
conversation = {
"input": inputs,
"output": outputs
}
# 如果超出5个就把最开始的conversation pop出去
self.recent_conversations.append(conversation)
if len(self.recent_conversations) > self.k:
self.recent_conversations.pop(0)
测试
from langchain.prompts.prompt import PromptTemplate
template = """
你将扮演一个优秀的判断精准的会议促进者,现在需要你参与进会议中来
此处为自己的prompt
重点是接下来的格式所以就将过于长的prompt省掉了
【生成JSON格式】
[
"current_step": 当前步骤名;
"result_info":类型为JSON
[
"end_condition": 参加者发言中提到的终了条件;
"reason": 能否达标的理由;
"is_ok": 终了条件是否达标,是/否;
"meeting_type": 如果"is_ok"为"是",就把判断好的会议类型写下来;
"is_next": 可否进行下一步骤,在本步骤,如果'can_type'为'是'则为'是',是/否;
"next_step": 若"is_next"为"是",则写下一步骤名,若"is_next"为"否",则为"";
]
"reply": 会议促进者的发言,必须有内容;
]
###停止生成###
【对话历史】
{recent_conversations}
【当前对话】
参会者:{input}
促进者JSON:"""
prompt = PromptTemplate(input_variables=["recent_conversations", "input"], template=template)
chat = OpenAI(temperature=0, model_name="gpt-3.5-turbo")
llm_chain = ConversationChain(
llm=chat,
prompt=prompt,
# 此处设置为2,目的是方便快速看到结果以及检验对错
memory=RecentKConversationMemory(ai_prefix="会议促进者", human_prefix="参会者", k=2),
verbose=True,
)
llm_chain.predict(input="开始吧")
此处只截取了【对话历史】以后的结果
【对话历史】
【当前对话】
参会者:开始吧
促进者JSON:
> Finished chain.
{\n "current_step": "设置终了条件",\n "result_info": {\n "end_condition": "",\n "reason": "",\n "is_ok": "",\n "meeting_type": "",\n "is_next": "",\n "next_step": ""\n },\n "reply": "好的,请问在讨论研究室的规则制定时,你认为应该设置什么样的终了条件呢?"\n}
llm_chain.predict(input="不知道")
【对话历史】
参会者: 开始吧
会议促进者: 好的,请问在讨论研究室的规则制定时,你认为应该设置什么样的终了条件呢?
【当前对话】
参会者:不知道
促进者JSON:
> Finished chain.
{\n "current_step": "设置终了条件",\n "result_info": {\n "end_condition": "不知道",\n "reason": "无法明确表明结束的时候的状态",\n "is_ok": "否",\n "meeting_type": "",\n "is_next": "是",\n "next_step": "明确参会者"\n },\n "reply": "终了条件需要能够明确表明结束的时候的状态,请继续讨论并尝试给出一个明确的终了条件。"\n}
此处【对话历史】中正如一开始所想的那样只把reply给截了出来放在了记忆中,那么下边验证是否可以做到只保留两轮对话
llm_chain.predict(input="希望能获得几个关于研究室规则的看法")
【对话历史】
参会者: 开始吧
会议促进者: 好的,请问在讨论研究室的规则制定时,你认为应该设置什么样的终了条件呢?
参会者: 不知道
会议促进者: 终了条件需要能够明确表明结束的时候的状态,请继续讨论并尝试给出一个明确的终了条件。
【当前对话】
参会者:希望能获得几个关于研究室规则的看法
促进者JSON:
> Finished chain.
{\n "current_step": "设置终了条件",\n "result_info": {\n "end_condition": "希望能获得几个关于研究室规则的看法",\n "reason": "明确说明要达成人的某种状态",\n "is_ok": "是",\n "meeting_type": "信息收集",\n "is_next": "是",\n "next_step": "明确参会者"\n },\n "reply": "非常好,你提出了一个明确的终了条件,即希望能获得几个关于研究室规则的看法。这符合信息收集类型的会议。接下来,我们需要明确参会者,请问在场有多少人参加这次会议呢?"\n}
此时【对话历史】中已经有2轮对话的消息了,看接着对话是否会将最开始的“开始吧”给去掉
llm_chain.predict(input="3人")
【对话历史】
参会者: 不知道
会议促进者: 终了条件需要能够明确表明结束的时候的状态,请继续讨论并尝试给出一个明确的终了条件。
参会者: 希望能获得几个关于研究室规则的看法
会议促进者: 非常好,你提出了一个明确的终了条件,即希望能获得几个关于研究室规则的看法。这符合信息收集类型的会议。接下来,我们需要明确参会者,请问在场有多少人参加这次会议呢?
【当前对话】
参会者:3人
促进者JSON:
所以是成功了