graphrag api调用
"""
Query a GraphRAG index through its Python API (local search).

Reference: https://microsoft.github.io/graphrag/posts/get_started/

1. Initialise the project root: python -m graphrag.index --init --root ./ragtest
2. Build the index:             python -m graphrag.index --root ./ragtest

Place this script inside the ragtest directory: the project root is taken to be
the directory that contains this file.
"""
import os
import re
from pathlib import Path
from typing import cast, Union, Tuple

import pandas as pd
from graphrag.config import (
    GraphRagConfig,
    create_graphrag_config,
)
from graphrag.index.progress import PrintProgressReporter
from graphrag.query.input.loaders.dfs import (
    store_entity_semantic_embeddings,
)
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.query.factories import get_local_search_engine
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)

reporter = PrintProgressReporter("")

# Matches inline citations such as "[Data: Sources (82, 14, 42, 98)]" plus any
# spaces/tabs directly before them, so removing a citation does not leave a
# dangling gap before the following punctuation.
_SOURCE_CITATION_RE = re.compile(r"[ \t]*\[Data: [^\]]+\]")


class LocalSearchEngine:
    """
    Thin wrapper around GraphRAG local search, adapted from the official code.

    The search agent is built once, at construction time, so the index is not
    reloaded on every query; callers only need ``run_search``.

    Accepted ``response_type`` values: Multiple Paragraphs, Single Paragraph,
    Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report.
    """

    def __init__(self, data_dir: Union[str, None], root_dir: Union[str, None]):
        """
        :param data_dir: directory holding the ``create_final_*.parquet``
            artifacts; inferred from ``root_dir`` when None.
        :param root_dir: GraphRAG project root (where settings.yaml lives).
        :raises ValueError: if both are None, or data_dir cannot be inferred.
        """
        self.data_dir, self.root_dir, self.config = self._configure_paths_and_settings(
            data_dir, root_dir
        )
        # Built from the configured vector-store settings and reused by
        # search_agent, instead of connecting a second, default-configured store.
        self.description_embedding_store = self._build_description_embedding_store()
        self.agent = self.search_agent(
            community_level=2, response_type="Single Paragraph"
        )

    def _configure_paths_and_settings(
        self, data_dir: Union[str, None], root_dir: Union[str, None]
    ) -> Tuple[str, Union[str, None], GraphRagConfig]:
        """Resolve the artifact directory and load the GraphRAG config."""
        if data_dir is None and root_dir is None:
            msg = "Either data_dir or root_dir must be provided."
            raise ValueError(msg)
        if data_dir is None:
            data_dir = self._infer_data_dir(cast(str, root_dir))
        config = self._create_graphrag_config(root_dir, data_dir)
        return data_dir, root_dir, config

    @staticmethod
    def _infer_data_dir(root: str) -> str:
        """
        Return ``<root>/output/<most recently modified run>/artifacts``.

        Only sub-directories of ``output`` are considered, so stray files
        (e.g. logs) cannot be mistaken for an indexing run.

        :raises ValueError: if ``output`` is missing or holds no run directory.
        """
        output = Path(root) / "output"
        if output.is_dir():
            runs = sorted(
                (entry for entry in output.iterdir() if entry.is_dir()),
                key=os.path.getmtime,
                reverse=True,
            )
            if runs:
                return str((runs[0] / "artifacts").absolute())
        msg = f"Could not infer data directory from root={root}"
        raise ValueError(msg)

    def _create_graphrag_config(
        self, root: Union[str, None], data_dir: Union[str, None]
    ) -> GraphRagConfig:
        """Read settings from ``root``; fall back to ``data_dir`` only when root is None."""
        # Explicit None check: an empty-string root means "current directory"
        # and must not silently fall through to the artifacts directory.
        settings_root = root if root is not None else data_dir
        return self._read_config_parameters(cast(str, settings_root))

    @staticmethod
    def _read_config_parameters(root: str) -> GraphRagConfig:
        """
        Load GraphRAG settings for the project rooted at ``root``.

        Lookup order: ``settings.yaml``, ``settings.yml``, ``settings.json``;
        if none exists, settings come from environment variables only.
        """
        _root = Path(root)
        settings_yaml = _root / "settings.yaml"
        if not settings_yaml.exists():
            settings_yaml = _root / "settings.yml"
        settings_json = _root / "settings.json"

        if settings_yaml.exists():
            reporter.info(f"Reading settings from {settings_yaml}")
            with settings_yaml.open("rb") as file:
                import yaml

                # safe_load: settings are plain data, no object construction.
                data = yaml.safe_load(
                    file.read().decode(encoding="utf-8", errors="strict")
                )
                return create_graphrag_config(data, root)

        if settings_json.exists():
            reporter.info(f"Reading settings from {settings_json}")
            with settings_json.open("rb") as file:
                import json

                data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
                return create_graphrag_config(data, root)

        reporter.info("Reading settings from environment variables")
        return create_graphrag_config(root_dir=root)

    def _vector_store_args(self) -> dict:
        """Return a private copy of the configured vector-store settings (or {})."""
        return dict(self.config.embeddings.vector_store or {})

    def _build_description_embedding_store(self):
        """Connect the entity-description store described by the loaded config."""
        vector_store_args = self._vector_store_args()
        vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
        return self._get_embedding_description_store(
            vector_store_type=vector_store_type,
            config_args=vector_store_args,
        )

    @staticmethod
    def _get_embedding_description_store(
        vector_store_type: str = VectorStoreType.LanceDB,
        config_args: Union[dict, None] = None,
    ):
        """
        Create and connect a vector store for entity description embeddings.

        The collection name is taken from ``query_collection_name``, then
        ``collection_name``, defaulting to ``"description_embedding"``.
        ``config_args`` is copied, never modified, so a dict owned by the
        GraphRAG config is left untouched.
        """
        store_args = dict(config_args) if config_args else {}
        store_args["collection_name"] = store_args.get(
            "query_collection_name",
            store_args.get("collection_name", "description_embedding"),
        )
        description_embedding_store = VectorStoreFactory.get_vector_store(
            vector_store_type=vector_store_type, kwargs=store_args
        )
        description_embedding_store.connect(**store_args)
        return description_embedding_store

    def search_agent(self, community_level: int, response_type: str):
        """
        Build a local search engine from the index artifacts in ``self.data_dir``.

        :param community_level: community level passed to the entity and
            report readers.
        :param response_type: desired answer format (see class docstring).
        :return: the engine returned by ``get_local_search_engine``.
        """
        data_path = Path(self.data_dir)
        final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
        final_community_reports = pd.read_parquet(
            data_path / "create_final_community_reports.parquet"
        )
        final_text_units = pd.read_parquet(
            data_path / "create_final_text_units.parquet"
        )
        final_relationships = pd.read_parquet(
            data_path / "create_final_relationships.parquet"
        )
        final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
        # Covariates (claims) are optional; the file only exists if claim
        # extraction was enabled during indexing.
        final_covariates_path = data_path / "create_final_covariates.parquet"
        final_covariates = (
            pd.read_parquet(final_covariates_path)
            if final_covariates_path.exists()
            else None
        )

        description_embedding_store = self.description_embedding_store

        entities = read_indexer_entities(final_nodes, final_entities, community_level)
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
        covariates = (
            read_indexer_covariates(final_covariates)
            if final_covariates is not None
            else []
        )

        return get_local_search_engine(
            self.config,
            reports=read_indexer_reports(
                final_community_reports, final_nodes, community_level
            ),
            text_units=read_indexer_text_units(final_text_units),
            entities=entities,
            relationships=read_indexer_relationships(final_relationships),
            covariates={"claims": covariates},
            description_embedding_store=description_embedding_store,
            response_type=response_type,
        )

    def run_search(self, query: str):
        """
        Search entry point.

        :param query: the question to ask.
        :return: the answer text with ``[Data: ...]`` citations removed.
        """
        result = self.agent.search(query=query)
        return self.remove_sources(result.response)

    @staticmethod
    def remove_sources(text):
        """
        Strip citation markers such as ``[Data: Sources (82, 14, 42, 98)]``.

        Spaces/tabs immediately preceding a marker are removed with it.

        :param text: answer text produced by the search engine.
        :return: the text without citation markers.
        """
        return _SOURCE_CITATION_RE.sub("", text)


# Example usage: the project root is the directory containing this script
# (made absolute so it does not depend on how the script was launched).
BASEDIR = os.path.dirname(os.path.abspath(__file__))
# Built once at import time so the index is loaded a single time.
local_search_engine = LocalSearchEngine(data_dir=None, root_dir=BASEDIR)

if __name__ == '__main__':
    local_res = local_search_engine.run_search(
        query="如何添加设备",
    )
    print(local_res)
搜索方式有 global 和 local 两种。如果想通过 API 调用 global 搜索，只需修改几个关键字即可（例如把 get_local_search_engine 换成 get_global_search_engine，并按其签名传入所需的数据）。
本文来自博客园,作者:一石数字欠我15w!!!,转载请注明原文链接:https://www.cnblogs.com/52-qq/p/18332835