构建RAG应用-day03: Chroma入门 本地embedding 智谱embedding

Chroma入门

使用chroma构建向量数据库。使用了两种embedding模型,可供自己选择。
本地embedding:SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
封装智谱embedding使得其可以在langchain中使用。

import os
from dotenv import load_dotenv, find_dotenv
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from embed import ZhipuAIEmbeddings

_ = load_dotenv(find_dotenv())


# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
# os.environ["HTTP_PROXY"] = 'http://127.0.0.1:7890'

# 获取folder_path下所有文件路径,储存在file_paths里
def generate_path(folder_path: str = '../data_base/knowledge_db') -> list:
    file_paths = []
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            file_path = os.path.join(root, file)
            file_paths.append(file_path)
    return file_paths


def generate_loaders(file_paths: list) -> list:
    loaders = []
    for file_path in file_paths:
        file_type = file_path.split('.')[-1]
        if file_type == 'pdf':
            loaders.append(PyMuPDFLoader(file_path))
        elif file_type == 'md':
            loaders.append(UnstructuredMarkdownLoader(file_path))
    return loaders


def exec_load(loaders: list) -> list:
    texts = []
    for loader in loaders:
        texts.extend(loader.load())
    return texts


def slice_docs(texts):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50)
    return text_splitter.split_documents(texts)


class VectorDB:
    # embedding = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    embedding = ZhipuAIEmbeddings()
    persist_directory = '../data_base/vector_db/chroma'
    slice = 20

    def __init__(self, sliced_docs: list = None):
        assert sliced_docs is not None
        self.vectordb = Chroma.from_documents(
            documents=sliced_docs[:self.slice],  # 为了速度,只选择前 20 个切分的 doc 进行生成;使用千帆时因QPS限制,建议选择前 5 个doc
            embedding=self.embedding,
            persist_directory=self.persist_directory  # 允许我们将persist_directory目录保存到磁盘上
        )

    def persist(self):
        self.vectordb.persist()
        print(f"向量库中存储的数量:{self.vectordb._collection.count()}")

    def sim_search(self, query, k=3):
        sim_docs = self.vectordb.similarity_search(query, k=k)
        for i, sim_doc in enumerate(sim_docs, start=1):
            print(f"检索到的第{i}个内容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
        return sim_docs

    def mmr_search(self, query, k=3):
        mmr_docs = self.vectordb.max_marginal_relevance_search(query, k=k)
        for i, sim_doc in enumerate(mmr_docs, start=1):
            print(f"MMR 检索到的第{i}个内容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
        return mmr_docs


if __name__ == '__main__':
    # 读取目录下的所有文件路径
    file_paths = generate_path()
    # 根据文件生成加载器
    loaders = generate_loaders(file_paths)
    # 执行文档加载
    texts = exec_load(loaders)
    # 切分文档
    sliced_docs = slice_docs(texts)
    # 构建向量数据库
    vdb = VectorDB(sliced_docs)
    # 向量持久化存储
    vdb.persist()
    # 定义问题
    question = "什么是大语言模型"
    # 相似度检索
    vdb.sim_search(question)
    # 最大边际相关性(MMR) 检索
    vdb.mmr_search(question)

langchain embedding封装

需要一个智谱APIkey,官网注册并且实名认证即可:智谱AI开放平台 (bigmodel.cn)

from __future__ import annotations
import logging
from typing import Dict, List, Any
from dotenv import load_dotenv, find_dotenv
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator

logger = logging.getLogger(__name__)
_ = load_dotenv(find_dotenv())


# 在 Python 中,root_validator 是 Pydantic 模块中一个用于自定义数据校验的装饰器函数。root_validator 用于在校验整个数据模型之前对整个数据模型进行自定义校验,以确保所有的数据都符合所期望的数据结构。
class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """`Zhipuai Embeddings` embedding models."""

    client: Any
    """`zhipuai.ZhipuAI"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """
        实例化ZhipuAI为values["client"]

        Args:

            values (Dict): 包含配置信息的字典,必须包含 client 的字段.
        Returns:

            values (Dict): 包含配置信息的字典。如果环境中有zhipuai库,则将返回实例化的ZhipuAI类;否则将报错 'ModuleNotFoundError: No module named 'zhipuai''.
        """
        from zhipuai import ZhipuAI
        values["client"] = ZhipuAI()
        return values

    def embed_query(self, text: str) -> List[float]:
        """
        生成输入文本的 embedding.

        Args:
            texts (str): 要生成 embedding 的文本.

        Return:
            embeddings (List[float]): 输入文本的 embedding,一个浮点数值列表.
        """
        embeddings = self.client.embeddings.create(
            model="embedding-2",
            input=text
        )
        return embeddings.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        生成输入文本列表的 embedding.
        Args:
            texts (List[str]): 要生成 embedding 的文本列表.

        Returns:
            List[List[float]]: 输入列表中每个文档的 embedding 列表。每个 embedding 都表示为一个浮点值列表。
        """
        return [self.embed_query(text) for text in texts]

posted @ 2024-04-23 00:24  passion2021  阅读(1002)  评论(0编辑  收藏  举报