ES RAG向量搜索示例,使用BAAI BGE创建embedding
准备:
拉取 ES 7.6.2 镜像：
docker pull docker.elastic.co/elasticsearch/elasticsearch:7.6.2
7.6.2: Pulling from elasticsearch/elasticsearch
c808caf183b6: Pull complete
d6caf8e15a64: Pull complete
b0ba5f324e82: Pull complete
d7e8c1e99b9a: Pull complete
85c4d6c81438: Pull complete
3119218fac98: Pull complete
914accf214bb: Pull complete
Digest: sha256:59342c577e2b7082b819654d119f42514ddf47f0699c8b54dc1f0150250ce7aa
Status: Downloaded newer image for docker.elastic.co/elasticsearch/elasticsearch:7.6.2
docker.elastic.co/elasticsearch/elasticsearch:7.6.2
What's Next?
  View a summary of image vulnerabilities and recommendations → docker scout quickview docker.elastic.co/elasticsearch/elasticsearch:7.6.2

安装 Python 客户端（客户端大版本需与服务端一致，这里是 7.x；直接 pip install elasticsearch 现在会装到 8.x，建议改用 pip install "elasticsearch>=7.6,<8"）：
PS D:\source\pythonProject> pip install elasticsearch
Requirement already satisfied: elasticsearch in d:\python\python312\lib\site-packages (7.6.0)
Requirement already satisfied: urllib3>=1.21.1 in d:\python\python312\lib\site-packages (from elasticsearch) (1.26.18)
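启动容器（下文的 esid 指容器 ID 或容器名，这里以单节点模式启动并命名为 esid）：
docker run -d --name esid -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" docker.elastic.co/elasticsearch/elasticsearch:7.6.2

进入容器修改配置：
docker exec -it esid bash
cd config/
vi elasticsearch.yml

在 elasticsearch.yml 中增加：
http.cors.enabled: true
http.cors.allow-origin: "*"
discovery.zen.minimum_master_nodes: 1
（注：7.x 中 discovery.zen.minimum_master_nodes 已不再生效，单节点部署依靠上面的 discovery.type=single-node 即可。）

重启服务：
docker restart esid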
在浏览器中访问以下地址，确认服务已正常启动：
http://<ip>:9200
编写代码:
from elasticsearch import Elasticsearch

# Connect to Elasticsearch with the client's default host
# (presumably localhost:9200 — pass hosts=[...] to reach a remote node).
es = Elasticsearch()

# Index settings and mapping: one shard, no replicas (single-node demo),
# plus a 5-dimensional dense_vector field holding the embeddings.
index_name = "vector_search_example"
index_settings = {
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0,
    },
    "mappings": {
        "properties": {
            "title": {"type": "text"},
            "embedding": {
                "type": "dense_vector",  # vector field read by cosineSimilarity below
                "dims": 5,  # must equal the length of every stored and query vector
            },
        }
    },
}

# Create the index only if it does not exist yet, so the script can be rerun.
if not es.indices.exists(index=index_name):
    es.indices.create(index=index_name, body=index_settings)

# Example documents with hand-written 5-dimensional vectors; the third one is
# deliberately far from the first two.
docs = [
    {"title": "Hello World Document", "embedding": [0.1, 0.2, 0.3, 0.4, 0.5]},
    {"title": "Another Document Example", "embedding": [0.2, 0.35, 0.45, 0.55, 0.6]},
    {"title": "Yet Another Hello", "embedding": [0.7, 0.6, 0.5, 0.4, 0.3]},
]
for doc_id, doc in enumerate(docs, start=1):
    # Fixed ids: a rerun overwrites the documents ("updated") instead of adding copies.
    response = es.index(index=index_name, id=doc_id, body=doc)
    print(f"Indexed document: {response['result']}")

# Search is near-real-time: freshly indexed documents are not visible until the
# index is refreshed, so without this a first run can return no hits at all.
es.indices.refresh(index=index_name)

# Find the stored vectors most similar to the query vector.
query_vector = [0.2, 0.3, 0.4, 0.5, 0.6]
script_query = {
    "script_score": {
        "query": {"match_all": {}},
        "script": {
            # Cosine similarity lies in [-1, 1]; +1.0 keeps the score non-negative.
            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
            "params": {"query_vector": query_vector},
        },
    }
}
response = es.search(index=index_name, body={"query": script_query}, size=2)

# Print the top hits.
for hit in response["hits"]["hits"]:
    print(f"Document ID: {hit['_id']}, Score: {hit['_score']}, Title: {hit['_source']['title']}")
返回结果（这是重复运行时的输出：文档 id 已存在，所以显示 updated；首次运行会显示 created）：
Indexed document: updated
Indexed document: updated
Indexed document: updated
Document ID: 2, Score: 1.9982954, Title: Another Document Example
Document ID: 1, Score: 1.9949367, Title: Hello World Document
再进一步：使用 BGE 模型（bge-base-zh-v1.5）把文本编码成向量，实现语义检索：
from elasticsearch import Elasticsearch
from FlagEmbedding import FlagModel
from collections import defaultdict
from time import time

# Connect to Elasticsearch with the client's default host
# (presumably localhost:9200 — pass hosts=[...] to reach a remote node).
es = Elasticsearch()

# Index settings and mapping for the API-description vectors.
index_name = "vector_search_sec_tool"
index_settings = {
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0,
    },
    "mappings": {
        "properties": {
            "description": {"type": "text"},
            "embedding": {
                "type": "dense_vector",
                # Must equal the embedding size of the model loaded below
                # (assumed 768 for bge-base-zh-v1.5 — confirm if the model changes).
                "dims": 768,
            },
        }
    },
}

# Create the index only if it does not exist yet, so the script can be rerun.
if not es.indices.exists(index=index_name):
    es.indices.create(index=index_name, body=index_settings)


def search_sectool_knowledge_base(descriptions):
    """Embed API descriptions, store them in Elasticsearch and run sample queries.

    Args:
        descriptions: list of dicts, each with a "path" string and a "methods"
            dict mapping an HTTP method to a human-readable description.

    Side effects:
        Indexes one document per (method, description) pair into ``index_name``
        with ids 1..N (overwriting documents that already use those ids), then
        prints the top-3 matches for each sample query and the elapsed time.
    """
    # The ES documents store only the description text, so keep an in-memory
    # map from description back to its method/path to resolve search hits.
    corpus = []
    index = defaultdict(dict)
    for item in descriptions:
        for method, description in item['methods'].items():
            index[description] = {"method": method, "path": item["path"]}
            corpus.append(description)

    embedder = FlagModel('bge-base-zh-v1.5/',
                         query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",)
    # Corpus entries are embedded without the retrieval instruction.
    corpus_embeddings = embedder.encode(corpus)

    # Store the vectors in Elasticsearch.
    for i, description in enumerate(corpus):
        doc = {
            "description": description,
            "embedding": corpus_embeddings[i].tolist(),  # numpy array -> JSON-serializable list
        }
        response = es.index(index=index_name, id=i + 1, body=doc)
        print(f"Indexed document: {response['result']}")

    # Search is near-real-time: refresh so the documents just indexed are
    # searchable; otherwise a first run can return no hits.
    es.indices.refresh(index=index_name)

    # Query sentences:
    queries = [
        '搜索告警列表',
        '查询漏洞']

    now = time()
    times = 1
    for _ in range(times):
        for query in queries:
            # encode_queries() prepends query_instruction_for_retrieval; plain
            # encode() would silently ignore the instruction configured above.
            query_embedding = embedder.encode_queries(query).tolist()
            script_query = {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        # +1.0 keeps the score non-negative (cosine is in [-1, 1]).
                        "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                        "params": {"query_vector": query_embedding},
                    },
                }
            }
            response = es.search(index=index_name, body={"query": script_query}, size=3)

            print("\n\n======================\n\n")
            print("Query:", query)
            print("\nTop 3 most similar sentences in corpus:")
            for hit in response["hits"]["hits"]:
                description = hit["_source"]["description"]
                score = hit["_score"]
                # .get(): a stale document left by an earlier, larger run is not
                # in this run's map; report {} without inserting into the map.
                print(f"{description} (Score: {score:.4f}) ==> {index.get(description, {})}")
    print(f"{times} {time() - now} seconds elapsed")


if __name__ == '__main__':
    descriptions = [{'path': '/v1/{project_id}/subscriptions/version', 'methods': {'GET': '获取视图订购信息'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/sa/reports', 'methods': {'GET': '分析报管理获取报告列表'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/siem/alert-rules', 'methods': {'GET': 'corss-workspace智能建模聚合列表接口'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/siem/alert-rules/metrics', 'methods': {'GET': 'cross-workspace智能建模可用模型指标接口'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/alerts/search', 'methods': {'POST': '搜索告警列表'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/incidents/search', 'methods': {'POST': '搜索事件列表'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/indicators/search', 'methods': {'POST': '威胁情报列表查询'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/vulnerability/search', 'methods': {'POST': '查询漏洞列表'}}]
    search_sectool_knowledge_base(descriptions)
运行结果:
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated

======================

Query: 搜索告警列表

Top 3 most similar sentences in corpus:
搜索告警列表 (Score: 2.0000) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/alerts/search'}
搜索事件列表 (Score: 1.9030) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/incidents/search'}
威胁情报列表查询 (Score: 1.8769) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/indicators/search'}

======================

Query: 查询漏洞

Top 3 most similar sentences in corpus:
查询漏洞列表 (Score: 1.9688) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/vulnerability/search'}
威胁情报列表查询 (Score: 1.8580) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/indicators/search'}
搜索告警列表 (Score: 1.8370) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/alerts/search'}
1 0.060091257095336914 seconds elapsed
还可以继续优化：把 path、methods 等完整信息一并存入 ES，检索结果直接从 _source 读取，不再依赖内存中的映射表：
from elasticsearch import Elasticsearch
from FlagEmbedding import FlagModel
from collections import defaultdict
from time import time

# Connect to Elasticsearch with the client's default host
# (presumably localhost:9200 — pass hosts=[...] to reach a remote node).
es = Elasticsearch()

# Index settings and mapping: besides the description and its vector, each
# document stores its API path and {method: description}, so search hits are
# self-contained and need no in-memory lookup.
index_name = "vector_search_example222"
index_settings = {
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0,
    },
    "mappings": {
        "properties": {
            "path": {"type": "text"},
            "methods": {"type": "object"},
            "description": {"type": "text"},
            "embedding": {
                "type": "dense_vector",
                "dims": 768,  # assumes FlagModel produces 768-dim vectors — must match the model
            },
        }
    },
}

# Create the index only if it does not exist yet, so the script can be rerun.
if not es.indices.exists(index=index_name):
    es.indices.create(index=index_name, body=index_settings)


def search_sec_knowledge_base(descriptions):
    """Embed API descriptions, store them with their metadata in Elasticsearch, and run sample queries.

    Args:
        descriptions: list of dicts, each with a "path" string and a "methods"
            dict mapping an HTTP method to a human-readable description.

    Side effects:
        Indexes one document per (method, description) pair into ``index_name``
        with ids 1..N (overwriting documents that already use those ids), then
        prints the top-5 matches for each sample query and the elapsed time.
    """
    # Flatten the input into a list of descriptions, remembering the
    # method/path each one belongs to.
    corpus = []
    index = defaultdict(dict)
    for item in descriptions:
        for method, description in item['methods'].items():
            index[description] = {"method": method, "path": item["path"]}
            corpus.append(description)

    embedder = FlagModel('bge-base-zh-v1.5/',
                         query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",)
    # Corpus entries are embedded without the retrieval instruction.
    corpus_embeddings = embedder.encode(corpus)

    # Store the vectors together with their path/method metadata.
    for i, description in enumerate(corpus):
        meta = index[description]
        doc = {
            "path": meta["path"],
            "methods": {meta["method"]: description},
            "description": description,
            "embedding": corpus_embeddings[i].tolist(),  # numpy array -> list of Python floats
        }
        response = es.index(index=index_name, id=i + 1, body=doc)
        print(f"Indexed document: {response['result']}")

    # Search is near-real-time: refresh so the documents just indexed are
    # searchable; otherwise a first run can return no hits.
    es.indices.refresh(index=index_name)

    # Query sentences (the last two are deliberately unrelated to the corpus):
    queries = [
        '搜索告警列表',
        '查询漏洞',
        'Someone in a gorilla costume is playing a set of drums.',
        'A cheetah chases prey on across a field.']

    now = time()
    times = 1
    for _ in range(times):
        for query in queries:
            # encode_queries() prepends query_instruction_for_retrieval; plain
            # encode() would silently ignore the instruction configured above.
            query_embedding = embedder.encode_queries(query).tolist()
            script_query = {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        # +1.0 keeps the score non-negative (cosine is in [-1, 1]).
                        "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                        "params": {"query_vector": query_embedding},
                    },
                }
            }
            response = es.search(index=index_name, body={"query": script_query}, size=5)

            print("\n\n======================\n\n")
            print("Query:", query)
            print("\nTop 5 most similar sentences in corpus:")
            for hit in response["hits"]["hits"]:
                source = hit["_source"]
                description = source["description"]
                score = hit["_score"]
                print(f"{description} (Score: {score:.4f}) ==> Path: {source['path']}, Methods: {source['methods']}")
    print(f"{times} {time() - now} seconds elapsed")


if __name__ == '__main__':
    descriptions = [{'path': '/v1/{project_id}/subscriptions/version', 'methods': {'GET': '获取视图订购信息'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/sa/reports', 'methods': {'GET': '分析报管理获取报告列表'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/siem/alert-rules', 'methods': {'GET': 'corss-workspace智能建模聚合列表接口'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/siem/alert-rules/metrics', 'methods': {'GET': 'cross-workspace智能建模可用模型指标接口'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/alerts/search', 'methods': {'POST': '搜索告警列表'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/incidents/search', 'methods': {'POST': '搜索事件列表'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/indicators/search', 'methods': {'POST': '威胁情报列表查询'}},
                    {'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/vulnerability/search', 'methods': {'POST': '查询漏洞列表'}}]
    search_sec_knowledge_base(descriptions)