构建RAG应用-day03: Chroma入门 本地embedding 智谱embedding
Chroma入门
使用chroma构建向量数据库。使用了两种embedding模型,可供自己选择。
本地embedding:SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
封装智谱embedding使得其可以在langchain中使用。
import os
from dotenv import load_dotenv, find_dotenv
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from embed import ZhipuAIEmbeddings
_ = load_dotenv(find_dotenv())
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
# os.environ["HTTP_PROXY"] = 'http://127.0.0.1:7890'
# 获取folder_path下所有文件路径,储存在file_paths里
def generate_path(folder_path: str = '../data_base/knowledge_db') -> list:
file_paths = []
for root, dirs, files in os.walk(folder_path):
for file in files:
file_path = os.path.join(root, file)
file_paths.append(file_path)
return file_paths
def generate_loaders(file_paths: list) -> list:
loaders = []
for file_path in file_paths:
file_type = file_path.split('.')[-1]
if file_type == 'pdf':
loaders.append(PyMuPDFLoader(file_path))
elif file_type == 'md':
loaders.append(UnstructuredMarkdownLoader(file_path))
return loaders
def exec_load(loaders: list) -> list:
texts = []
for loader in loaders:
texts.extend(loader.load())
return texts
def slice_docs(texts):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500, chunk_overlap=50)
return text_splitter.split_documents(texts)
class VectorDB:
# embedding = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
embedding = ZhipuAIEmbeddings()
persist_directory = '../data_base/vector_db/chroma'
slice = 20
def __init__(self, sliced_docs: list = None):
assert sliced_docs is not None
self.vectordb = Chroma.from_documents(
documents=sliced_docs[:self.slice], # 为了速度,只选择前 20 个切分的 doc 进行生成;使用千帆时因QPS限制,建议选择前 5 个doc
embedding=self.embedding,
persist_directory=self.persist_directory # 允许我们将persist_directory目录保存到磁盘上
)
def persist(self):
self.vectordb.persist()
print(f"向量库中存储的数量:{self.vectordb._collection.count()}")
def sim_search(self, query, k=3):
sim_docs = self.vectordb.similarity_search(query, k=k)
for i, sim_doc in enumerate(sim_docs, start=1):
print(f"检索到的第{i}个内容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
return sim_docs
def mmr_search(self, query, k=3):
mmr_docs = self.vectordb.max_marginal_relevance_search(query, k=k)
for i, sim_doc in enumerate(mmr_docs, start=1):
print(f"MMR 检索到的第{i}个内容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
return mmr_docs
if __name__ == '__main__':
# 读取目录下的所有文件路径
file_paths = generate_path()
# 根据文件生成加载器
loaders = generate_loaders(file_paths)
# 执行文档加载
texts = exec_load(loaders)
# 切分文档
sliced_docs = slice_docs(texts)
# 构建向量数据库
vdb = VectorDB(sliced_docs)
# 向量持久化存储
vdb.persist()
# 定义问题
question = "什么是大语言模型"
# 相似度检索
vdb.sim_search(question)
# 最大边际相关性(MMR) 检索
vdb.mmr_search(question)
langchain embedding封装
需要一个智谱APIkey,官网注册并且实名认证即可:智谱AI开放平台 (bigmodel.cn)
from __future__ import annotations
import logging
from typing import Dict, List, Any
from dotenv import load_dotenv, find_dotenv
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
logger = logging.getLogger(__name__)
_ = load_dotenv(find_dotenv())
# 在 Python 中,root_validator 是 Pydantic 模块中一个用于自定义数据校验的装饰器函数。root_validator 用于在校验整个数据模型之前对整个数据模型进行自定义校验,以确保所有的数据都符合所期望的数据结构。
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""`Zhipuai Embeddings` embedding models."""
client: Any
"""`zhipuai.ZhipuAI"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
实例化ZhipuAI为values["client"]
Args:
values (Dict): 包含配置信息的字典,必须包含 client 的字段.
Returns:
values (Dict): 包含配置信息的字典。如果环境中有zhipuai库,则将返回实例化的ZhipuAI类;否则将报错 'ModuleNotFoundError: No module named 'zhipuai''.
"""
from zhipuai import ZhipuAI
values["client"] = ZhipuAI()
return values
def embed_query(self, text: str) -> List[float]:
"""
生成输入文本的 embedding.
Args:
texts (str): 要生成 embedding 的文本.
Return:
embeddings (List[float]): 输入文本的 embedding,一个浮点数值列表.
"""
embeddings = self.client.embeddings.create(
model="embedding-2",
input=text
)
return embeddings.data[0].embedding
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
生成输入文本列表的 embedding.
Args:
texts (List[str]): 要生成 embedding 的文本列表.
Returns:
List[List[float]]: 输入列表中每个文档的 embedding 列表。每个 embedding 都表示为一个浮点值列表。
"""
return [self.embed_query(text) for text in texts]