实施语义缓存以改进 RAG 系统
实施语义缓存以改进 RAG 系统
1.缓存介绍
在本笔记本中,我们将探索一个典型的 RAG 解决方案,其中我们将使用开源模型和向量数据库 Chroma DB。但是,我们将集成一个语义缓存系统,该系统将存储各种用户查询,并决定是否生成包含来自向量数据库或缓存的信息的提示。
语义缓存系统旨在识别相似或相同的用户请求。当找到匹配的请求时,系统会从缓存中检索相应的信息,从而减少从原始源获取它的需要。
由于比较考虑了请求的语义含义,因此它们不必完全相同,系统就可以将它们识别为同一个问题。它们可以以不同的方式表达或包含不准确之处,无论是印刷错误还是句子结构,我们都可以确定用户实际上正在请求相同的信息。
例如,像“法国的首都是什么?”、告诉我法国首都的名字?和“法国的首都是什么?”这样的查询都传达了相同的意图,应该被识别为同一个问题。
虽然模型的响应可能因第二个示例中对简洁答案的请求而有所不同,但从向量数据库检索到的信息应该是相同的。这就是为什么我将缓存系统放在用户和向量数据库之间,而不是用户和大型语言模型之间。
大多数指导您创建 RAG 系统的教程都是为单用户使用而设计的,旨在在测试环境中运行。换句话说,在笔记本中,与本地向量数据库交互并进行 API 调用或使用本地存储的模型。
当尝试将其中一个模型转换为生产时,这种架构很快就会变得不足,因为它们可能会遇到数十到数千个重复请求。
提高性能的一种方法是通过一个或多个语义缓存。此缓存保留先前请求的结果,在解析新请求之前,它会检查之前是否收到过类似的请求。如果是,它不会重新执行该过程,而是从缓存中检索信息。
在 RAG 系统中,有两点很耗时:
- 检索用于构建丰富提示的信息
- 调用大型语言模型以获取响应
在这两个点中,都可以实现语义缓存系统,我们甚至可以有两个缓存,每个点一个。
将其放置在模型的响应点可能会导致对所获得响应的影响丧失。我们的缓存系统可以将“用 10 个字解释法国大革命”和“用 100 个字解释法国大革命”视为相同的查询。如果我们的缓存系统存储模型响应,用户可能会认为他们的指令没有被准确遵循。
但这两个请求都需要相同的信息来丰富提示。这就是我选择将语义缓存系统放置在用户请求和从矢量数据库检索信息之间的主要原因。
但是,这是一个设计决策。根据响应和系统请求的类型,它可以放在一个点或另一个点。很明显,缓存模型响应可以节省最多的时间,但正如我已经解释过的,这是以失去用户对响应的影响为代价的。
2.导入并加载库。
首先,我们需要安装必要的 Python 包。
sentence transformers.:这个库对于将句子转换为固定长度的向量(也称为嵌入)是必需的。
xformers:它是一个提供库和实用程序的包,以方便使用转换器模型。我们需要安装它以避免在使用模型和嵌入时出现错误。
chromadb:这是我们的矢量数据库。ChromaDB 易于使用且开源,可能是用于存储嵌入的最常用的矢量数据库。
accelerate:需要在 GPU 中运行模型。
!pip install -q transformers==4.38.1
!pip install -q accelerate==0.27.2
!pip install -q sentence-transformers==2.5.1
!pip install -q xformers==0.0.24
!pip install -q chromadb==0.4.24
!pip install -q datasets==2.17.1
import numpy as np
import pandas as pd
3.加载数据集
由于我们在一个自由且有限的空间内工作,并且只能使用几 GB 的内存,因此我使用变量 MAX_ROWS 限制了数据集中要使用的行数。
#Login to Hugging Face. It is mandatory to use the Gemma Model,
#and recommended to acces public models and Datasets.
from getpass import getpass
if 'hf_key' not in locals():
hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key
Your Hugging Face API Key: ··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful
from datasets import load_dataset
data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split='train')
data = data.to_pandas()
data["id"]=data.index
data.head(10)
qtype | Question | Answer | id | |
---|---|---|---|---|
0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti… | LCMV infections can occur after exposure to fr… | 0 |
1 | symptoms | What are the symptoms of Lymphocytic Choriomen… | LCMV is most commonly recognized as causing ne… | 1 |
2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti… | Individuals of all ages who come into contact … | 2 |
3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (… | During the first phase of the disease, the mos… | 3 |
4 | treatment | What are the treatments for Lymphocytic Chorio… | Aseptic meningitis, encephalitis, or meningoen… | 4 |
5 | prevention | How to prevent Lymphocytic Choriomeningitis (L… | LCMV infection can be prevented by avoiding co… | 5 |
6 | information | What is (are) Parasites – Cysticercosis ? | Cysticercosis is an infection caused by the la… | 6 |
7 | susceptibility | Who is at risk for Parasites – Cysticercosis? ? | Cysticercosis is an infection caused by the la… | 7 |
8 | exams and tests | How to diagnose Parasites – Cysticercosis ? | If you think that you may have cysticercosis, … | 8 |
9 | treatment | What are the treatments for Parasites – Cystic… | Some people with cysticercosis do not need to … | 9 |
MAX_ROWS = 15000
DOCUMENT="Answer"
TOPIC="qtype"
ChromaDB 要求数据具有唯一标识符。我们可以使用此语句来实现,它将创建一个名为 Id 的新列。
#Because it is just a sample we select a small portion of News.
subset_data = data.head(MAX_ROWS)
4.导入和配置矢量数据库
我将使用最流行的开源矢量数据库 ChromaDB。
首先,我们需要导入 ChromaDB,然后从 chromadb.config 模块导入 Settings 类。该类允许我们更改 ChromaDB 系统的设置并自定义其行为。
import chromadb
from chromadb.config import Settings
现在我们只需要指明矢量数据库的存储路径。
chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")
5.填充和查询 ChromaDB 数据库
ChromaDB 中的数据存储在集合中。如果集合存在,我们需要删除它。
在接下来的几行中,我们将通过调用上面创建的 chroma_client 中的 create_collection 函数来创建集合。
collection_name = "news_collection"
if len(chroma_client.list_collections()) > 0 and collection_name in [chroma_client.list_collections()[0].name]:
chroma_client.delete_collection(name=collection_name)
collection = chroma_client.create_collection(name=collection_name)
是时候将数据添加到集合中了。使用 add 函数,我们至少需要通知文档、元数据和 ID。
在文档中,我们存储大文本,它是每个数据集中的不同列。
在元数据中,我们可以通知主题列表。
在 ID 中,我们需要为每行通知一个唯一的标识符。它必须是唯一的!我正在使用 MAX_ROWS 范围创建 ID。
collection.add(
documents=subset_data[DOCUMENT].tolist(),
metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
ids=[f"id{x}" for x in range(MAX_ROWS)],
)
一旦我们获得了数据库中的信息,我们就可以查询它,并请求符合我们需求的数据。搜索是在文档内容内进行的,它不会查找确切的单词或短语。结果将基于搜索词和文档内容之间的相似性。
元数据不用于搜索,但可用于在初始搜索后过滤或优化结果。
让我们定义一个函数来查询 ChromaDB 数据库。
def query_database(query_text, n_results=10):
results = collection.query(query_texts=query_text, n_results=n_results )
return results
6. 创建语义缓存系统
为了实现缓存系统,我们将使用 Faiss,这是一个允许将嵌入存储在内存中的库。它与 Chroma 所做的非常相似,但没有持久性。
为此,我们将创建一个名为 semantic_cache 的类,它将与其自己的编码器一起工作,并为用户提供执行查询所需的功能。
在这个类中,我们首先查询 Faiss(缓存),如果返回的结果高于指定的阈值,它将从缓存中返回结果。否则,它将从 Chroma 数据库中获取结果。
缓存存储在 .json 文件中。
!pip install -q faiss-cpu==1.8.0
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 27.0/27.0 MB 62.3 MB/s eta 0:00:00
import faiss
from sentence_transformers import SentenceTransformer
import time
import json
此函数初始化语义缓存。
它采用 FlatLS 索引,这可能不是最快的,但非常适合小型数据集。根据要缓存的数据的特征和预期的数据集大小,可以使用其他索引,例如 HNSW 或 IVF。
def init_cache():
index = faiss.IndexFlatL2(768)
if index.is_trained:
print('Index trained')
# Initialize Sentence Transformer model
encoder = SentenceTransformer('all-mpnet-base-v2')
return index, encoder
在retrieve_cache函数中,如果需要在会话间重用缓存,则会从磁盘检索.json文件。
def retrieve_cache(json_file):
try:
with open(json_file, 'r') as file:
cache = json.load(file)
except FileNotFoundError:
cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}
return cache
store_cache 函数将包含缓存数据的文件保存到磁盘。
def store_cache(json_file, cache):
with open(json_file, 'w') as file:
json.dump(cache, file)
这些函数将在 SemanticCache 类中使用,该类包括搜索函数及其初始化函数。
尽管 ask 函数有大量代码,但其目的却非常简单。它在缓存中查找与用户刚刚提出的问题最接近的问题。
然后,检查它是否在指定的阈值内。如果是肯定的,它直接从缓存中返回响应;否则,它调用 query_database 函数从 ChromaDB 中检索数据。
我使用了欧几里得距离而不是余弦,后者在向量比较中被广泛使用。这种选择是基于欧几里得距离是 Faiss 使用的默认度量这一事实。虽然也可以计算余弦距离,但这样做会增加复杂性,可能不会对最终结果产生重大影响。
class semantic_cache:
def __init__(self, json_file="cache_file.json", thresold=0.35):
# Initialize Faiss index with Euclidean distance
self.index, self.encoder = init_cache()
# Set Euclidean distance threshold
# a distance of 0 means identicals sentences
# We only return from cache sentences under this thresold
self.euclidean_threshold = thresold
self.json_file = json_file
self.cache = retrieve_cache(self.json_file)
def ask(self, question: str) -> str:
# Method to retrieve an answer from the cache or generate a new one
start_time = time.time()
try:
#First we obtain the embeddings corresponding to the user question
embedding = self.encoder.encode([question])
# Search for the nearest neighbor in the index
self.index.nprobe = 8
D, I = self.index.search(embedding, 1)
if D[0] >= 0:
if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
row_id = int(I[0][0])
print('Answer recovered from Cache. ')
print(f'{D[0][0]:.3f} smaller than {self.euclidean_threshold}')
print(f'Found cache in row: {row_id} with score {D[0][0]:.3f}')
print(f'response_text: ' + self.cache['response_text'][row_id])
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Time taken: {elapsed_time:.3f} seconds")
return self.cache['response_text'][row_id]
# Handle the case when there are not enough results
# or Euclidean distance is not met, asking to chromaDB.
answer = query_database([question], 1)
response_text = answer['documents'][0][0]
self.cache['questions'].append(question)
self.cache['embeddings'].append(embedding[0].tolist())
self.cache['answers'].append(answer)
self.cache['response_text']