Implementing semantic cache to improve a RAG system

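1. Introduction to caching

In this notebook, we will explore a typical RAG solution that uses an open-source model and the vector database ChromaDB. However, we will integrate a semantic cache system that stores various user queries and decides whether to build the prompt with information from the vector database or from the cache.

A semantic cache system aims to identify similar or identical user requests. When a matching request is found, the system retrieves the corresponding information from the cache, reducing the need to fetch it from the original source.

Because the comparison takes the semantic meaning of the requests into account, they don't have to be identical for the system to recognize them as the same question. They can be phrased differently or contain inaccuracies, whether typos or errors in sentence structure, and we can still tell that the user is actually asking for the same information.

For example, queries like "What is the capital of France?", "Tell me the name of the capital of France?", and "What The capital of France is?" all convey the same intent and should be identified as the same question.

Although the model's final answer may vary with how each request is worded, the information retrieved from the vector database should be the same. This is why I place the cache system between the user and the vector database, not between the user and the large language model.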

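Most tutorials that guide you through building a RAG system are designed for a single user and meant to run in a test environment: in a notebook, interacting with a local vector database, and making API calls or using a locally stored model.

This architecture quickly becomes insufficient when you try to take one of these models to production, where it may face anywhere from tens to thousands of recurring requests.

One way to improve performance is with one or more semantic caches. The cache keeps the results of previous requests, and before resolving a new request it checks whether a similar one has been received before. If so, instead of re-executing the process, it retrieves the information from the cache.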

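In a RAG system, there are two points that are time-consuming:

  1. Retrieving the information used to build the enriched prompt
  2. Calling the large language model to obtain the response

A semantic cache can be implemented at either point; we could even have two caches, one for each.

Placing it at the model's response point can take away the user's influence over the response. Our cache system could treat "Explain the French Revolution in 10 words" and "Explain the French Revolution in 100 words" as the same query. If the cache stored model responses, users might think their instructions were not being followed accurately.

However, both requests need the same information to enrich the prompt. This is the main reason I chose to place the semantic cache between the user's request and the retrieval of information from the vector database.

That said, this is a design decision. Depending on the type of responses and requests the system handles, the cache can go at one point or the other. Caching model responses clearly saves the most time, but, as explained above, it comes at the cost of losing the user's influence over the response.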
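The placement chosen here can be sketched in a few lines. This is a minimal illustration, not code from this notebook: answer_with_retrieval_cache, toy_key, and the retrieve/generate callables are hypothetical, and toy_key (which simply ignores case and digits) stands in for real semantic matching. Only the retrieved context is cached, so the model still receives each user's exact wording.

```python
from typing import Callable, Dict


def answer_with_retrieval_cache(
    question: str,
    retrieve: Callable[[str], str],
    generate: Callable[[str], str],
    cache: Dict[str, str],
    cache_key: Callable[[str], str],
) -> str:
    """Cache the retrieved context, never the model's answer (hypothetical helper)."""
    key = cache_key(question)
    if key in cache:
        context = cache[key]           # cache hit: skip retrieval
    else:
        context = retrieve(question)   # cache miss: go to the vector database
        cache[key] = context
    # The prompt always carries the user's original question, so instructions
    # such as "in 10 words" still reach the model
    prompt = f"Relevant context: {context}\n\n The user's question: {question}"
    return generate(prompt)


def toy_key(question: str) -> str:
    # Toy stand-in for semantic similarity: ignore case and digits
    return "".join(ch for ch in question.lower() if not ch.isdigit())
```

With toy_key, "Explain the French Revolution in 10 words" and "Explain the French Revolution in 100 words" share one cache entry, so retrieval runs once, yet each prompt keeps its own word limit.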

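2. Importing and loading the libraries

First, we need to install the necessary Python packages.

sentence-transformers: this library is needed to convert sentences into fixed-length vectors, also known as embeddings.
xformers: a package that provides libraries and utilities that make working with transformer models easier. We install it to avoid errors when working with the model and the embeddings.
chromadb: our vector database. ChromaDB is easy to use, open source, and one of the most widely used vector databases for storing embeddings.
accelerate: needed to run the model on the GPU.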

!pip install -q transformers==4.38.1
!pip install -q accelerate==0.27.2
!pip install -q sentence-transformers==2.5.1
!pip install -q xformers==0.0.24
!pip install -q chromadb==0.4.24
!pip install -q datasets==2.17.1

import numpy as np
import pandas as pd

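3. Loading the dataset

Since we are working in a free, limited environment with only a few GB of memory, I limited the number of dataset rows to use with the MAX_ROWS variable.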

# Log in to Hugging Face. This is required to use the Gemma model,
# and recommended for accessing public models and datasets.
from getpass import getpass
if 'hf_key' not in locals():
  hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key
Your Hugging Face API Key: ··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful
from datasets import load_dataset

data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split='train')
data = data.to_pandas()
data["id"]=data.index
data.head(10)
  | qtype | Question | Answer | id
0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti… | LCMV infections can occur after exposure to fr… | 0
1 | symptoms | What are the symptoms of Lymphocytic Choriomen… | LCMV is most commonly recognized as causing ne… | 1
2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti… | Individuals of all ages who come into contact … | 2
3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (… | During the first phase of the disease, the mos… | 3
4 | treatment | What are the treatments for Lymphocytic Chorio… | Aseptic meningitis, encephalitis, or meningoen… | 4
5 | prevention | How to prevent Lymphocytic Choriomeningitis (L… | LCMV infection can be prevented by avoiding co… | 5
6 | information | What is (are) Parasites – Cysticercosis ? | Cysticercosis is an infection caused by the la… | 6
7 | susceptibility | Who is at risk for Parasites – Cysticercosis? ? | Cysticercosis is an infection caused by the la… | 7
8 | exams and tests | How to diagnose Parasites – Cysticercosis ? | If you think that you may have cysticercosis, … | 8
9 | treatment | What are the treatments for Parasites – Cystic… | Some people with cysticercosis do not need to … | 9
MAX_ROWS = 15000
DOCUMENT="Answer"
TOPIC="qtype"

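ChromaDB requires each record to have a unique identifier. The statement data["id"]=data.index, run above, takes care of this by creating a new column called id.

# Because this is just a sample, we select a small portion of the dataset.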
subset_data = data.head(MAX_ROWS)

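4. Importing and configuring the vector database

I will use ChromaDB, one of the most popular open-source vector databases.

First, we need to import ChromaDB, and then the Settings class from the chromadb.config module. This class lets us change ChromaDB's settings and customize its behavior.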

import chromadb
from chromadb.config import Settings

Now we just need to specify the path where the vector database will be stored.

chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")

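5. Populating and querying the ChromaDB database

Data in ChromaDB is stored in collections. If the collection already exists, we need to delete it.

In the next few lines, we create the collection by calling the create_collection function of the chroma_client created above.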

collection_name = "news_collection"
# Delete the collection if it already exists (check every collection, not just the first one)
if collection_name in [c.name for c in chroma_client.list_collections()]:
    chroma_client.delete_collection(name=collection_name)

collection = chroma_client.create_collection(name=collection_name)

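It's time to add the data to the collection. With the add function, we need to provide at least documents, metadatas, and ids.

In documents, we store the main text; which column that is depends on the dataset (here, Answer).
In metadatas, we can provide the topic of each row (here, the qtype column).
In ids, we need to provide an identifier for each row. It must be unique! I'm creating the ids from the MAX_ROWS range.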

collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    ids=[f"id{x}" for x in range(MAX_ROWS)],
)
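The three arguments to collection.add are parallel lists that must have the same length, and every id must be unique. Below is a sketch of a hypothetical helper (build_add_payload is not part of ChromaDB) that builds them from plain Python lists, deriving the ids from the number of documents rather than from MAX_ROWS, so the lengths always match even if the dataset has fewer rows than requested.

```python
from typing import Dict, List, Sequence


def build_add_payload(documents: Sequence[str],
                      topics: Sequence[str],
                      topic_key: str = "qtype") -> Dict[str, list]:
    """Build the documents/metadatas/ids lists for collection.add (hypothetical helper)."""
    if len(documents) != len(topics):
        raise ValueError("documents and topics must have the same length")
    # One id per document, so the three lists always have equal length
    ids: List[str] = [f"id{i}" for i in range(len(documents))]
    # One metadata dictionary per document
    metadatas = [{topic_key: topic} for topic in topics]
    return {"documents": list(documents), "metadatas": metadatas, "ids": ids}
```

With the notebook's variables, the call would look like collection.add(**build_add_payload(subset_data[DOCUMENT].tolist(), subset_data[TOPIC].tolist(), TOPIC)).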

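Once the information is in the database, we can query it and request the data that matches our needs. The search runs over the content of the documents and does not look for exact words or phrases; the results are based on the similarity between the search terms and the document content.

Metadata is not used for the similarity search itself, but it can be used to filter or refine the results.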
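To illustrate the difference between similarity search and metadata filtering, here is a toy NumPy sketch with made-up 2-D embeddings; it does not reproduce Chroma's internals. Documents are ranked by distance to the query embedding, and an optional metadata filter only narrows the set of candidates. In ChromaDB itself, such a filter is passed through the where argument of collection.query, for example where={"qtype": "treatment"}.

```python
from typing import Dict, List, Optional

import numpy as np


def toy_search(doc_vectors: np.ndarray,
               metadatas: List[Dict[str, str]],
               query_vector: np.ndarray,
               n_results: int = 2,
               where: Optional[Dict[str, str]] = None) -> List[int]:
    """Return indices of the closest documents, optionally restricted by metadata."""
    # Keep only the rows whose metadata matches every key/value pair in `where`
    candidates = [i for i, meta in enumerate(metadatas)
                  if where is None or all(meta.get(k) == v for k, v in where.items())]
    if not candidates:
        return []
    # Rank the remaining rows by Euclidean distance to the query (smaller = more similar)
    dists = np.linalg.norm(doc_vectors[candidates] - query_vector, axis=1)
    order = np.argsort(dists)[:n_results]
    return [candidates[j] for j in order]
```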

Let's define a function to query the ChromaDB database.

def query_database(query_text, n_results=10):
    results = collection.query(query_texts=query_text, n_results=n_results )
    return results
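collection.query returns a dictionary in which each field holds one list per query text, ordered from the closest match to the farthest, so the best document for the first (and here only) question sits at ['documents'][0][0]; that is the indexing the ask method relies on later. A small hypothetical helper that makes this explicit, shown against a hand-built dictionary with that shape rather than a live collection:

```python
from typing import Any, Dict, Optional


def top_document(results: Dict[str, Any], query_index: int = 0) -> Optional[str]:
    """Return the best-matching document for one query, or None if there is none."""
    docs_per_query = results.get("documents") or []
    if query_index >= len(docs_per_query) or not docs_per_query[query_index]:
        return None
    # Results for each query are sorted by distance, so index 0 is the closest match
    return docs_per_query[query_index][0]


# Hand-built example with the same nesting as a collection.query result
fake_results = {
    "ids": [["id3", "id7"]],
    "distances": [[0.2, 0.5]],
    "documents": [["doc A", "doc B"]],
    "metadatas": [[{"qtype": "treatment"}, {"qtype": "symptoms"}]],
}
```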

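6. Creating the semantic cache system

To implement the cache system, we will use Faiss, a library for efficient similarity search that keeps its index of embeddings in memory. It does something very similar to Chroma, but without persistence.

To do this, we will create a class called semantic_cache that works with its own encoder and gives the user the functions needed to run queries.

In this class, we first query Faiss (the cache). If the distance to the closest cached question is within the specified threshold, the result is returned from the cache; otherwise, it is fetched from the Chroma database.

The cache is stored in a .json file.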

!pip install -q faiss-cpu==1.8.0
import faiss
from sentence_transformers import SentenceTransformer
import time
import json

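This function initializes the semantic cache.

It uses a FlatL2 index, which may not be the fastest but is well suited to small datasets. Depending on the characteristics of the data to be cached and the expected dataset size, other indexes such as HNSW or IVF could be used.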
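IndexFlatL2 performs exact, brute-force search: it compares the query against every stored vector. The following NumPy function is a rough equivalent written for understanding, not Faiss code (flat_l2_search is a hypothetical name). It reproduces two details the semantic_cache class depends on: Faiss reports squared L2 distances, and when there is no stored vector to return it reports -1 as the row id.

```python
from typing import Tuple

import numpy as np


def flat_l2_search(stored: np.ndarray, queries: np.ndarray,
                   k: int = 1) -> Tuple[np.ndarray, np.ndarray]:
    """Exact k-nearest-neighbour search returning (squared distances, row ids)."""
    queries = np.atleast_2d(queries).astype(np.float32)
    n_queries = queries.shape[0]
    D = np.full((n_queries, k), np.inf, dtype=np.float32)
    I = np.full((n_queries, k), -1, dtype=np.int64)
    if stored.shape[0] == 0:
        return D, I  # empty index: no neighbours, row id -1
    # Squared Euclidean distance from every query to every stored vector
    diff = queries[:, None, :] - stored[None, :, :].astype(np.float32)
    sq = (diff ** 2).sum(axis=2)
    top = min(k, stored.shape[0])
    order = np.argsort(sq, axis=1)[:, :top]
    D[:, :top] = np.take_along_axis(sq, order, axis=1)
    I[:, :top] = order
    return D, I
```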

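def init_cache():
  # Initialize the Sentence Transformer model used to embed the questions
  encoder = SentenceTransformer('all-mpnet-base-v2')

  # Exact-search index with Euclidean (L2) distance; its dimension must match
  # the encoder's output size (768 for all-mpnet-base-v2)
  index = faiss.IndexFlatL2(encoder.get_sentence_embedding_dimension())
  if index.is_trained:
      print('Index trained')

  return index, encoder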

The retrieve_cache function loads the .json file from disk, in case the cache needs to be reused across sessions.

def retrieve_cache(json_file):
      try:
          with open(json_file, 'r') as file:
              cache = json.load(file)
      except FileNotFoundError:
          cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}

      return cache

The store_cache function saves the file containing the cache data to disk.

def store_cache(json_file, cache):
  with open(json_file, 'w') as file:
        json.dump(cache, file)
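A quick round trip through the two helpers, using a temporary file and a made-up two-dimensional embedding (the helpers are repeated so the snippet runs on its own). It also shows why the stored embeddings matter: the Faiss index lives only in memory, so after the JSON file is reloaded the saved embeddings are the only copy, and they have to be re-added to a fresh index (for example with index.add) for Faiss row ids to line up with the cache lists again.

```python
import json
import os
import tempfile

import numpy as np


def retrieve_cache(json_file):
    try:
        with open(json_file, 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        return {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}


def store_cache(json_file, cache):
    with open(json_file, 'w') as file:
        json.dump(cache, file)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cache_file.json")
    cache = retrieve_cache(path)            # file does not exist yet: empty cache
    cache['questions'].append("How does a vaccine work?")
    cache['embeddings'].append([0.6, 0.8])  # made-up embedding
    cache['answers'].append({})
    cache['response_text'].append("Vaccines contain germs that have been killed or weakened.")
    store_cache(path, cache)

    reloaded = retrieve_cache(path)
    # float32 matrix that could be passed to index.add() to rebuild the Faiss index
    matrix = np.array(reloaded['embeddings'], dtype=np.float32)
    print(reloaded['questions'], matrix.shape)
```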

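These functions are used in the semantic_cache class, which contains the search function (ask) and its initialization function.

Although the ask function contains a fair amount of code, its purpose is very simple: it looks up the cached question closest to the one the user just asked.

It then checks whether that question is within the specified threshold. If so, it returns the response directly from the cache; otherwise, it calls the query_database function to retrieve the data from ChromaDB.

I used Euclidean distance rather than cosine, which is widely used for comparing vectors. The choice is based on the fact that Euclidean distance is the default metric in Faiss (strictly speaking, IndexFlatL2 reports squared Euclidean distances, and that is the value compared with the threshold). Cosine distance could also be computed, but doing so adds complexity and may not have a significant impact on the final result.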
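A worked check of the last point, assuming unit-length embeddings (the all-mpnet-base-v2 pipeline ends with a normalization step, but verify this for any other encoder). For unit vectors a and b, the squared Euclidean distance and the cosine similarity are related by ||a - b||² = 2 - 2·cos(a, b), so both metrics rank neighbours in the same order, and a squared-distance threshold of 0.35 corresponds to a cosine similarity of at least 1 - 0.35/2 = 0.825. If cosine is preferred anyway, Faiss supports it by normalizing the vectors with faiss.normalize_L2 and searching an IndexFlatIP.

```python
import numpy as np


def sq_l2_threshold_to_cosine(threshold: float) -> float:
    """For unit vectors, ||a - b||^2 = 2 - 2*cos(a, b), so cos >= 1 - threshold / 2."""
    return 1.0 - threshold / 2.0


# Two random 768-dimensional vectors, normalized to unit length
rng = np.random.default_rng(0)
a, b = rng.normal(size=(2, 768))
a /= np.linalg.norm(a)
b /= np.linalg.norm(b)

sq_dist = float(np.sum((a - b) ** 2))
cosine = float(a @ b)
print(np.isclose(sq_dist, 2 - 2 * cosine))
print(sq_l2_threshold_to_cosine(0.35))
```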

class semantic_cache:
  def __init__(self, json_file="cache_file.json", threshold=0.35):
      # Initialize the Faiss index (Euclidean distance) and the encoder
      self.index, self.encoder = init_cache()

      # Distance threshold. IndexFlatL2 returns squared Euclidean distances:
      # 0 means identical sentences, and we only return an answer from the
      # cache when the distance is at or below this threshold
      self.euclidean_threshold = threshold

      self.json_file = json_file
      self.cache = retrieve_cache(self.json_file)

      # The Faiss index lives only in memory: re-add the embeddings loaded
      # from disk so that index row ids match the rows of the cache lists
      if self.cache['embeddings']:
          self.index.add(np.array(self.cache['embeddings'], dtype=np.float32))

  def ask(self, question: str) -> str:
      # Method to retrieve an answer from the cache or, failing that, from ChromaDB
      start_time = time.time()
      try:
          # First we obtain the embedding corresponding to the user question
          embedding = self.encoder.encode([question])

          # Search for the nearest neighbor in the index
          D, I = self.index.search(embedding, 1)

          # I[0][0] is -1 while the index is still empty
          if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
              row_id = int(I[0][0])

              print('Answer recovered from Cache. ')
              print(f'{D[0][0]:.3f} smaller than {self.euclidean_threshold}')
              print(f'Found cache in row: {row_id} with score {D[0][0]:.3f}')
              print(f"response_text: {self.cache['response_text'][row_id]}")

              end_time = time.time()
              elapsed_time = end_time - start_time
              print(f"Time taken: {elapsed_time:.3f} seconds")
              return self.cache['response_text'][row_id]

          # Either the cache is empty or the closest question is farther than
          # the threshold, so we ask ChromaDB
          answer = query_database([question], 1)
          response_text = answer['documents'][0][0]

          self.cache['questions'].append(question)
          self.cache['embeddings'].append(embedding[0].tolist())
          self.cache['answers'].append(answer)
          self.cache['response_text'].append(response_text)

          print('Answer recovered from ChromaDB. ')
          print(f'response_text: {response_text}')

          self.index.add(embedding)
          store_cache(self.json_file, self.cache)
          end_time = time.time()
          elapsed_time = end_time - start_time
          print(f"Time taken: {elapsed_time:.3f} seconds")

          return response_text
      except Exception as e:
          raise RuntimeError(f"Error during 'ask' method: {e}") from e
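The decision logic in ask can be exercised without Faiss, SentenceTransformer, or ChromaDB. The following is a simplified sketch, not a drop-in replacement: MiniSemanticCache is a hypothetical class, encode stands in for the encoder, fetch stands in for query_database, and the nearest-neighbour search is done with NumPy using squared L2 distance, the same quantity IndexFlatL2 reports.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np


class MiniSemanticCache:
    """Simplified, dependency-light version of the semantic_cache lookup logic."""

    def __init__(self, encode: Callable[[str], Sequence[float]],
                 fetch: Callable[[str], str], threshold: float = 0.35):
        self.encode = encode        # stand-in for SentenceTransformer.encode
        self.fetch = fetch          # stand-in for query_database
        self.threshold = threshold  # maximum squared L2 distance for a cache hit
        self.embeddings: List[np.ndarray] = []
        self.responses: List[str] = []

    def ask(self, question: str) -> Tuple[str, str]:
        """Return (response, origin), where origin is "cache" or "source"."""
        q = np.asarray(self.encode(question), dtype=np.float32)
        if self.embeddings:
            # Squared L2 distance from the new question to every cached question
            sq_dist = ((np.stack(self.embeddings) - q) ** 2).sum(axis=1)
            row = int(np.argmin(sq_dist))
            if sq_dist[row] <= self.threshold:
                return self.responses[row], "cache"
        # Cache miss: fetch from the source and remember the question's embedding
        response = self.fetch(question)
        self.embeddings.append(q)
        self.responses.append(response)
        return response, "source"
```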

6.1 Testing the semantic_cache class

# Initialize the cache.
cache = semantic_cache('4cache_file.json')
Index trained
results = cache.ask("How work a vaccine?")
Answer recovered from ChromaDB. 
response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children.    Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system "remembers" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity.     Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune.     NIH: National Institute of Allergy and Infectious Diseases
Time taken: 0.655 seconds

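As expected, the response was obtained from ChromaDB. The class then stores it in the cache.

Now, if we send a second, completely different question, the response should also come from ChromaDB, because the question stored earlier is so different that its distance exceeds the specified threshold.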

results = cache.ask("Explain briefly what is a Periodic Paralyses")
Answer recovered from ChromaDB. 
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.083 seconds

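Perfect, the semantic cache behaves as expected.

Let's continue testing it with a question very similar to the one we just asked.

In this case, the response should come directly from the cache, without accessing the ChromaDB database.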

results = cache.ask("Briefly explain me what is a periodic paralyses")
Answer recovered from Cache. 
0.018 smaller than 0.35
Found cache in row: 1 with score 0.018
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.015 seconds

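The two questions are so similar that their Euclidean distance is tiny, almost as if they were identical.

Now let's try another question, this time a bit more distinct, and observe how the system behaves.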

question_def = "Write in 20 words what is a periodic paralyses"
results = cache.ask(question_def)
Answer recovered from Cache. 
0.220 smaller than 0.35
Found cache in row: 1 with score 0.220
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.017 seconds

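The Euclidean distance has increased, but it is still within the specified threshold, so the response is again returned directly from the cache.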

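7. Loading the model and creating the prompt

Time to use the transformers library, Hugging Face's best-known library for working with language models.

We are importing:

AutoTokenizer: a utility class that tokenizes text inputs for a wide range of pretrained language models.

AutoModelForCausalLM: provides an interface to pretrained language models designed for text generation with causal language modeling, such as GPT models or Gemma-2b-it, the model used in this notebook.

The chosen model is Gemma-2b-it.

Feel free to test different models; you will need to look for NLP models trained for text generation.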

!pip install torch
import torch
from torch import cuda
# On an Apple Silicon Mac, the device must be 'mps'
# device = torch.device('mps')  # to use with Apple Silicon
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

from transformers import AutoTokenizer, AutoModelForCausalLM

#model_id = "databricks/dolly-v2-3b"
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Place the whole model on the device selected above
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             device_map={"": device},
                                             torch_dtype=torch.bfloat16)

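8. Creating the extended prompt

To build the prompt, we use the result of querying the semantic_cache class and the question the user asked.

The prompt has two parts: the relevant context, which is the information retrieved from the database, and the user's question.

We just need to put the two parts together to create the prompt and then send it to the model.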

prompt_template = f"Relevant context: {results}\n\n The user's question: {question_def}"
prompt_template
"Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.\n                \nThe two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.\n\n The user's question: Write in 20 words what is a periodic paralyses"
# Move the tokenized prompt to the same device as the model
input_ids = tokenizer(prompt_template, return_tensors="pt").to(model.device)

Now all that's left is to send the prompt to the model and wait for its response!

outputs = model.generate(**input_ids,
                         max_new_tokens=256)
print(tokenizer.decode(outputs[0]))
<bos>Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.

 The user's question: Write in 20 words what is a periodic paralyses?

Answer: A group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells.<eos>

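9. Conclusion

In the runs above, a question answered from the cache took about 0.015 seconds, versus 0.083 seconds for one that had to go to ChromaDB (the first query, at 0.655 seconds, likely includes one-time warm-up). In larger projects this difference grows, and improvements of 90-95% can be reached.

We have very little data in Chroma and only one instance of the cache class. Typically, the data behind a cache system is much larger, and it may come from several sources, not just queries to a vector database.

It is also common to have several instances of the cache class, often one per user type, since questions tend to repeat more among users who share common characteristics.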
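One way to organize several cache instances by user type is a small registry that creates each cache lazily, the first time a user type is seen. This is a sketch with a hypothetical CacheRouter class; the user-type names in the usage note are illustrative only.

```python
from typing import Callable, Dict, Generic, List, TypeVar

CacheT = TypeVar("CacheT")


class CacheRouter(Generic[CacheT]):
    """Keep one cache instance per user type, created on first use (hypothetical)."""

    def __init__(self, factory: Callable[[str], CacheT]):
        self._factory = factory               # builds a new cache for a user type
        self._caches: Dict[str, CacheT] = {}

    def for_user_type(self, user_type: str) -> CacheT:
        # Create the cache lazily the first time this user type appears
        if user_type not in self._caches:
            self._caches[user_type] = self._factory(user_type)
        return self._caches[user_type]

    def user_types(self) -> List[str]:
        return sorted(self._caches)
```

With the class from this notebook, the factory could be something like lambda t: semantic_cache(f"cache_{t}.json"), giving each user type its own JSON file, and a request would then go through router.for_user_type("clinician").ask(question).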

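In summary, we built a very simple RAG (Retrieval-Augmented Generation) system and enhanced it with a semantic cache layer placed between the user's question and the retrieval of the information needed to build the enriched prompt.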

posted @ 2024-11-28 16:06  bonelee