Building a RAG Application, Day 6: A Personal Knowledge Base Assistant Project
Personal Knowledge Base Assistant
This post is based on the Datawhale open-source learning project: llm-universe/docs/C6 at main · datawhalechina/llm-universe (github.com)
Obtaining the data
The llm-universe personal knowledge base assistant uses several classic Datawhale open-source courses and (parts of) videos as example sources, specifically:
- 《机器学习公式详解》 (PDF version)
- 《面向开发者的 LLM 入门教程 第一部分 Prompt Engineering》 (Markdown version)
- 《强化学习入门指南》 (MP4 version)
- the READMEs of all open-source projects in the Datawhale organization: https://github.com/datawhalechina
These knowledge base source files are placed under the /data_base/knowledge_db directory; you can also store files of your own there.
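To sanity-check what actually landed in the knowledge base, here is a small sketch (the path follows the layout above; adjust it to your own setup):

```python
import os

# Walk the knowledge base directory and print every file in it
for root, _, files in os.walk("./data_base/knowledge_db"):
    for name in files:
        print(os.path.join(root, name))
```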
First, let's look at how to fetch the READMEs of all open-source projects in the Datawhale organization. Run the project/database/test_get_all_repo.py file to download them; the code is as follows:
import json
import requests
import os
import base64
import loguru
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# GitHub token; better read from the environment than hard-coded (see the note below)
TOKEN = 'your github token'
# Fetch the list of repositories for a GitHub organization
def get_repos(org_name, token, export_dir):
headers = {
'Authorization': f'token {token}',
}
url = f'https://api.github.com/orgs/{org_name}/repos'
    # Note: the GitHub API caps per_page at 100 and pages start at 1,
    # so a single request returns at most 100 repositories.
    response = requests.get(url, headers=headers, params={'per_page': 100, 'page': 1})
if response.status_code == 200:
repos = response.json()
loguru.logger.info(f'Fetched {len(repos)} repositories for {org_name}.')
        # Save the repository names to a file under export_dir
        os.makedirs(export_dir, exist_ok=True)
        repositories_path = os.path.join(export_dir, 'repositories.txt')
with open(repositories_path, 'w', encoding='utf-8') as file:
for repo in repos:
file.write(repo['name'] + '\n')
return repos
else:
loguru.logger.error(f"Error fetching repositories: {response.status_code}")
loguru.logger.error(response.text)
return []
# Fetch the README file of a single repository
def fetch_repo_readme(org_name, repo_name, token, export_dir):
headers = {
'Authorization': f'token {token}',
}
url = f'https://api.github.com/repos/{org_name}/{repo_name}/readme'
response = requests.get(url, headers=headers)
if response.status_code == 200:
readme_content = response.json()['content']
        # Decode the base64-encoded content
readme_content = base64.b64decode(readme_content).decode('utf-8')
        # Build the save path for the README under export_dir
repo_dir = os.path.join(export_dir, repo_name)
if not os.path.exists(repo_dir):
os.makedirs(repo_dir)
readme_path = os.path.join(repo_dir, 'README.md')
with open(readme_path, 'w', encoding='utf-8') as file:
file.write(readme_content)
else:
loguru.logger.error(f"Error fetching README for {repo_name}: {response.status_code}")
loguru.logger.error(response.text)
# Main entry point
if __name__ == '__main__':
    # Organization name
    org_name = 'datawhalechina'
    # Directory to export the READMEs into
    export_dir = "./database/readme_db"  # replace with your actual directory path
    # Fetch the repository list
    repos = get_repos(org_name, TOKEN, export_dir)
    # Fetch the README of each repository
    if repos:
        for repo in repos:
            repo_name = repo['name']
            # Pull this repository's README
            fetch_repo_readme(org_name, repo_name, TOKEN, export_dir)
    # Clean up the temporary folder (requires `import shutil`)
    # if os.path.exists('temp'):
    #     shutil.rmtree('temp')
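Since the GitHub API returns at most 100 repositories per page, an organization with more repos needs pagination. A minimal sketch (the get_all_repos helper is my own illustration, not part of the original project):

```python
# Hypothetical pagination helper: keep requesting pages until one comes back empty.
def get_all_repos(org_name, token):
    headers = {'Authorization': f'token {token}'}
    url = f'https://api.github.com/orgs/{org_name}/repos'
    repos, page = [], 1
    while True:
        resp = requests.get(url, headers=headers,
                            params={'per_page': 100, 'page': page})
        batch = resp.json() if resp.status_code == 200 else []
        if not batch:
            break
        repos.extend(batch)
        page += 1
    return repos
```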
You will need your own GitHub token here. To create one:
1. Open the GitHub website and log in to your account.
2. From the menu in the top-right corner, choose "Settings".
3. On the settings page, choose "Developer settings".
4. In the left-hand menu, choose "Personal access tokens".
5. Click "Generate new token" to generate a new token.
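Since the script already calls load_dotenv(), a safer option than hard-coding the token is to keep it in a local .env file (the GITHUB_TOKEN variable name below is my own choice, not from the original project):

```python
import os
from dotenv import load_dotenv

# Reads KEY=value pairs from a local .env file into the process environment
load_dotenv()
# Raises KeyError if the variable is missing, which fails fast on misconfiguration
TOKEN = os.environ["GITHUB_TOKEN"]  # assumed variable name; put GITHUB_TOKEN=... in .env
```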
Summarizing the READMEs with an LLM
These README files contain a fair amount of irrelevant information, so we use an LLM to produce summaries:
(The original document uses openai==0.28; here we use the newer openai package.)
import os
from dotenv import load_dotenv
import openai
from get_data import get_repos
from bs4 import BeautifulSoup
import markdown
import re
import time
from openai import OpenAI
# Load environment variables
load_dotenv()
# GitHub token, used by get_repos below
TOKEN = 'your github token'
# Set up the OpenAI API client
openai_api_key = os.environ["OPENAI_API_KEY"]
openai.base_url = 'https://api.chatanywhere.tech/v1'
# Strip links from the text, to avoid tripping the LLM provider's content filters
def remove_urls(text):
    # Regex pattern matching URLs
    url_pattern = re.compile(r'https?://[^\s]*')
    # Replace every matched URL with an empty string
    text = re.sub(url_pattern, '', text)
    # Regex pattern matching specific boilerplate phrases (the Chinese phrases are
    # the ones actually found in the READMEs, so they are kept as-is)
    specific_text_pattern = re.compile(r'扫描下方二维码关注公众号|提取码|关注|回复关键词|侵权|版权|致谢|引用|LICENSE'
                                       r'|组队打卡|任务打卡|组队学习的那些事|学习周期|开源内容|打卡|组队学习|链接')
    # Replace every matched phrase with an empty string
    text = re.sub(specific_text_pattern, '', text)
    return text
# Extract plain text from Markdown
def extract_text_from_md(md_content):
# Convert Markdown to HTML
html = markdown.markdown(md_content)
# Use BeautifulSoup to extract text
soup = BeautifulSoup(html, 'html.parser')
return remove_urls(soup.get_text())
def generate_llm_summary(repo_name, readme_content, model):
    prompt = f"1:这个仓库名是 {repo_name}. 此仓库的readme全部内容是: {readme_content}\
        2:请用约200字以内的中文概括这个仓库readme的内容,返回的概括格式要求:这个仓库名是...,这仓库内容主要是..."
    messages = [{"role": "system", "content": "你是一个人工智能助手"},
                {"role": "user", "content": prompt}]
    # New-style OpenAI client; pass the key explicitly instead of setting openai.api_key
    llm = OpenAI(api_key=openai_api_key, base_url=openai.base_url)
    response = llm.chat.completions.create(
        model=model,
        messages=messages,
    )
    return response.choices[0].message.content
def main(org_name, export_dir, summary_dir, model):
repos = get_repos(org_name, TOKEN, export_dir)
# Create a directory to save summaries
os.makedirs(summary_dir, exist_ok=True)
    for idx, repo in enumerate(repos):  # `idx` avoids shadowing the built-in `id`
repo_name = repo['name']
readme_path = os.path.join(export_dir, repo_name, 'README.md')
print(repo_name)
if os.path.exists(readme_path):
with open(readme_path, 'r', encoding='utf-8') as file:
readme_content = file.read()
# Extract text from the README
readme_text = extract_text_from_md(readme_content)
            # Generate a summary for the README
            # Rate-limited plans may allow e.g. one request per minute; if so, uncomment:
            # time.sleep(60)
            print(f'Starting summary #{idx}')
try:
summary = generate_llm_summary(repo_name, readme_text, model)
print(summary)
# Write summary to a Markdown file in the summary directory
summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary.md")
with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
summary_file.write(f"# {repo_name} Summary\n\n")
summary_file.write(summary)
            except openai.OpenAIError as e:
                # The Chinese "风控" (content-filtered) marker in the filename matters:
                # the vector-DB build step below skips files containing it.
                summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary风控.md")
                with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
                    summary_file.write(f"# {repo_name} Summary风控\n\n")
                    summary_file.write("README内容风控。\n")
                print(f"Error generating summary for {repo_name}: {e}")
# print(readme_text)
        else:
            print(f"README not found: {readme_path}")
            # If the README doesn't exist, create a stub Markdown file; the
            # "不存在" (missing) marker is likewise skipped at DB-build time.
            summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary不存在.md")
            with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
                summary_file.write(f"# {repo_name} Summary不存在\n\n")
                summary_file.write("README文件不存在。\n")
if __name__ == '__main__':
    # Organization name
    org_name = 'datawhalechina'
    # Directory containing the downloaded READMEs
    export_dir = "./database/readme_db"  # replace with your actual README directory
    summary_dir = "./data_base/knowledge_db/readme_summary"  # replace with your actual summary directory
    model = "gpt-3.5-turbo"  # alternatives: deepseek-chat, moonshot-v1-8k
    main(org_name, export_dir, summary_dir, model)
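If your API plan is rate-limited (the commented-out time.sleep(60) above hints at one request per minute), a retry with exponential backoff is gentler than a fixed sleep. A sketch reusing generate_llm_summary; the retry count and delays are illustrative choices, not from the original project:

```python
import time

def summarize_with_retry(repo_name, readme_text, model, max_retries=3):
    for attempt in range(max_retries):
        try:
            return generate_llm_summary(repo_name, readme_text, model)
        except openai.RateLimitError:
            # Back off 1s, 2s, 4s, ... before retrying
            time.sleep(2 ** attempt)
    raise RuntimeError(f"Rate limit persisted after {max_retries} retries for {repo_name}")
```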
This produces a summary for each README, 100 in total:
Building the vector database with Zhipu AI
import os
import sys
import re
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import tempfile
from dotenv import load_dotenv, find_dotenv
from embed import ZhipuAIEmbeddings
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.vectorstores import Chroma
# Basic configuration
DEFAULT_DB_PATH = "data_base/knowledge_db/readme_summary"
DEFAULT_PERSIST_PATH = "./vector_db/chroma"
def get_files(dir_path):
file_list = []
for filepath, dirnames, filenames in os.walk(dir_path):
for filename in filenames:
file_list.append(os.path.join(filepath, filename))
return file_list
def file_loader(file, loaders):
if isinstance(file, tempfile._TemporaryFileWrapper):
file = file.name
if not os.path.isfile(file):
[file_loader(os.path.join(file, f), loaders) for f in os.listdir(file)]
return
file_type = file.split('.')[-1]
if file_type == 'pdf':
loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        # Skip the stub summaries flagged "不存在" (missing) or "风控" (content-filtered)
        pattern = r"不存在|风控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
elif file_type == 'txt':
loaders.append(UnstructuredFileLoader(file))
return
def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
if embeddings == 'openai' or embeddings == 'm3e' or embeddings == 'zhipuai':
vectordb = create_db(files, persist_directory, embeddings)
return ""
def get_embedding(embedding: str, embedding_key: str = None, env_file: str = None):
if embedding == "zhipuai":
return ZhipuAIEmbeddings(zhipuai_api_key=os.environ['ZHIPUAI_API_KEY'])
else:
raise ValueError(f"embedding {embedding} not support ")
def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
    """
    Load the source files, split the documents, embed them, and create the vector database.
    Parameters:
        files: path(s) of the files to load.
        persist_directory: directory to persist the vector database in.
        embeddings: the model used to produce the embeddings.
    Returns:
        vectordb: the created database.
    """
    if files is None:
        return "can't load empty file"
    if not isinstance(files, list):
        files = [files]
loaders = []
[file_loader(file, loaders) for file in files]
docs = []
for loader in loaders:
if loader is not None:
docs.extend(loader.load())
    # Split the documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=150)
    split_docs = text_splitter.split_documents(docs)
    if isinstance(embeddings, str):
        embeddings = get_embedding(embeddings)
    # Persist under the directory passed in as `persist_directory`
    # Create the database
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory  # lets us save the database to disk
    )
vectordb.persist()
return vectordb
def persist_knowledge_db(vectordb):
    """
    Persist the vector database to disk.
    Parameters:
        vectordb: the vector database to persist.
    """
vectordb.persist()
def load_knowledge_db(path, embeddings):
    """
    Load a persisted vector database.
    Parameters:
        path: path of the vector database to load.
        embeddings: the embedding model the database was built with.
    Returns:
        vectordb: the loaded database.
    """
vectordb = Chroma(
persist_directory=path,
embedding_function=embeddings
)
return vectordb
if __name__ == "__main__":
create_db(embeddings="zhipuai")
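Once the database is persisted, you can sanity-check it by loading it back and running a similarity search. A quick sketch built on the helpers above (the query string is just an example):

```python
# Assumes ZHIPUAI_API_KEY is set in the environment, as in create_db
embedding = ZhipuAIEmbeddings(zhipuai_api_key=os.environ['ZHIPUAI_API_KEY'])
vectordb = load_knowledge_db("./vector_db/chroma", embedding)
# Retrieve the 3 chunks most similar to the query
docs = vectordb.similarity_search("机器学习入门", k=3)
for doc in docs:
    print(doc.metadata.get("source"), doc.page_content[:80])
```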
Building the Q&A chain with Zhipu AI
import os
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
import openai
from embed import ZhipuAIEmbeddings
openai.base_url = 'https://api.chatanywhere.tech/v1'
chatgpt = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, base_url=openai.base_url)
persist_directory = 'vector_db/chroma'
embedding = ZhipuAIEmbeddings(zhipuai_api_key=os.environ['ZHIPUAI_API_KEY'])
vectordb = Chroma(
    persist_directory=persist_directory,  # load the persisted database from disk
    embedding_function=embedding
)
class Chat_QA_chain_self:
    """
    Q&A chain with chat history.
    - model: name of the model to call
    - temperature: sampling temperature, controls randomness of generation
    - top_k: number of most-similar documents to retrieve
    - chat_history: the chat history, a list (empty by default)
    - history_len: number of most recent dialogue turns to keep
    - file_path: path of the files used to build the database
    - persist_path: persistence path of the vector database
    - appid: Spark (iFLYTEK) app id
    - api_key: required by Spark, Wenxin, OpenAI and Zhipu alike
    - Spark_api_secret: Spark API secret
    - Wenxin_secret_key: Wenxin secret key
    - embedding: the embedding model to use
    - embedding_key: API key for the embedding model (Zhipu or OpenAI)
    """
    def __init__(self, model: str, temperature: float = 0.0, top_k: int = 4, chat_history: list = None,
                 file_path: str = None, persist_path: str = None, appid: str = None, api_key: str = None,
                 Spark_api_secret: str = None, Wenxin_secret_key: str = None, embedding="openai",
                 embedding_key: str = None):
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        # Use None as the default to avoid the shared-mutable-default pitfall
        self.chat_history = chat_history if chat_history is not None else []
# self.history_len = history_len
self.file_path = file_path
self.persist_path = persist_path
self.appid = appid
self.api_key = api_key
self.Spark_api_secret = Spark_api_secret
self.Wenxin_secret_key = Wenxin_secret_key
self.embedding = embedding
self.embedding_key = embedding_key
# self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding, self.embedding_key)
self.vectordb = vectordb
    def clear_history(self):
        "Clear the chat history."
        return self.chat_history.clear()
    def change_history_length(self, history_len: int = 1):
        """
        Keep only the most recent turns of the chat history.
        Parameters:
            history_len: number of most recent dialogue turns to keep
        Returns: the last history_len turns of the chat history
        """
        n = len(self.chat_history)
        return self.chat_history[n - history_len:]
    def answer(self, question: str = None, temperature=None, top_k=4):
        """
        Core method: invoke the Q&A chain.
        Arguments:
            question: the user's question
        """
        if not question:
            return "", self.chat_history
        if temperature is None:
            temperature = self.temperature
        # llm = model_to_llm(self.model, temperature, self.appid, self.api_key,
        #                    self.Spark_api_secret, self.Wenxin_secret_key)
        llm = chatgpt
        # self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        retriever = self.vectordb.as_retriever(search_type="similarity",
                                               search_kwargs={'k': top_k})  # defaults: similarity search, k=4
        qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever
        )
        result = qa({"question": question, "chat_history": self.chat_history})  # result contains question, chat_history, answer
        answer = result['answer']
        self.chat_history.append((question, answer))  # append this turn to the history
        return self.chat_history  # return the updated history; its last entry is this answer
if __name__ == '__main__':
    question_1 = "给我介绍1个 Datawhale 的机器学习项目"
    qa_chain = Chat_QA_chain_self(model="gpt-3.5-turbo")
    result = qa_chain.answer(question=question_1)
    print("Answer to question_1 from the LLM + knowledge base:")
    print(result[-1][1])  # the answer from the most recent turn
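Because the chain threads chat_history back into each call, a follow-up question can refer to the previous answer. A hypothetical second turn (the question text is illustrative):

```python
result = qa_chain.answer(question="这个项目适合初学者吗?")
print(result[-1][1])  # latest answer; the history now holds both turns
```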
Check the result: