GraphRAG API 调用

"""
参考:https://microsoft.github.io/graphrag/posts/get_started/
1. 初始化家目录:python -m graphrag.index --init --root ./ragtest
2. 初始化索引:python -m graphrag.index --root ./ragtest

脚本需要放置在ragtest目录下运行
"""

import os
import re
from pathlib import Path
from typing import cast, Union, Tuple

import pandas as pd

from graphrag.config import (
    GraphRagConfig,
    create_graphrag_config,
)
from graphrag.index.progress import PrintProgressReporter
from graphrag.query.input.loaders.dfs import (
    store_entity_semantic_embeddings,
)
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.query.factories import get_local_search_engine
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)

# Module-level progress reporter. LocalSearchEngine._read_config_parameters uses
# it to log which settings source (YAML, JSON or environment variables) was read.
reporter = PrintProgressReporter("")


class LocalSearchEngine:
    """Thin wrapper around GraphRAG's local search engine.

    Adapted from the official GraphRAG query code. The search agent is built
    once in ``__init__`` so the index artifacts are not reloaded on every query,
    and ``run_search`` is the only call interface meant for external use.

    Supported ``response_type`` values: Multiple Paragraphs, Single Paragraph,
    Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report.
    """

    def __init__(self, data_dir: Union[str, None], root_dir: Union[str, None]):
        """Resolve paths and config, then build the search agent.

        :param data_dir: directory containing the ``create_final_*.parquet``
            artifacts; inferred from ``root_dir`` when ``None``.
        :param root_dir: GraphRAG project root (holds settings.yaml / output/).
        :raises ValueError: if both arguments are ``None`` or the data
            directory cannot be inferred.
        """
        self.data_dir, self.root_dir, self.config = self._configure_paths_and_settings(
            data_dir, root_dir
        )
        # search_agent() creates the vector store from the configured settings and
        # records it in self.description_embedding_store, so no separate (unused)
        # default store is created here.
        self.agent = self.search_agent(
            community_level=2, response_type="Single Paragraph"
        )

    def _configure_paths_and_settings(
        self, data_dir: Union[str, None], root_dir: Union[str, None]
    ) -> Tuple[str, Union[str, None], GraphRagConfig]:
        """Return ``(data_dir, root_dir, config)``, inferring data_dir if needed.

        :raises ValueError: if neither ``data_dir`` nor ``root_dir`` is given.
        """
        if data_dir is None and root_dir is None:
            msg = "Either data_dir or root_dir must be provided."
            raise ValueError(msg)
        if data_dir is None:
            data_dir = self._infer_data_dir(cast(str, root_dir))
        config = self._create_graphrag_config(root_dir, data_dir)
        return data_dir, root_dir, config

    @staticmethod
    def _infer_data_dir(root: str) -> str:
        """Return the absolute ``artifacts`` path of the newest run under ``root/output``.

        :raises ValueError: if ``output`` is missing or contains no run folders.
        """
        output = Path(root) / "output"
        if output.exists():
            # Only run folders are candidates; a stray file in output/ would
            # otherwise win the mtime sort and yield a non-existent artifacts path.
            folders = sorted(
                (entry for entry in output.iterdir() if entry.is_dir()),
                key=os.path.getmtime,
                reverse=True,
            )
            if folders:
                return str((folders[0] / "artifacts").absolute())
        msg = f"Could not infer data directory from root={root}"
        raise ValueError(msg)

    def _create_graphrag_config(
        self, root: Union[str, None], data_dir: Union[str, None]
    ) -> GraphRagConfig:
        """Load the config from ``root``, falling back to ``data_dir`` when root is None."""
        # Explicit None check (consistent with _configure_paths_and_settings): an
        # empty-string root means "current directory", not "no root".
        config_root = root if root is not None else data_dir
        return self._read_config_parameters(cast(str, config_root))

    @staticmethod
    def _read_config_parameters(root: str) -> GraphRagConfig:
        """Build a GraphRagConfig from settings.yaml/.yml, settings.json or env vars.

        The files are tried in that order under ``root``; if none exists the
        config is created from environment variables only.
        """
        _root = Path(root)
        settings_yaml = _root / "settings.yaml"
        if not settings_yaml.exists():
            settings_yaml = _root / "settings.yml"
        settings_json = _root / "settings.json"

        if settings_yaml.exists():
            reporter.info(f"Reading settings from {settings_yaml}")
            with settings_yaml.open("rb") as file:
                import yaml

                # safe_load: never construct arbitrary Python objects from YAML.
                data = yaml.safe_load(
                    file.read().decode(encoding="utf-8", errors="strict")
                )
                return create_graphrag_config(data, root)

        if settings_json.exists():
            reporter.info(f"Reading settings from {settings_json}")
            with settings_json.open("rb") as file:
                import json

                data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
                return create_graphrag_config(data, root)

        reporter.info("Reading settings from environment variables")
        return create_graphrag_config(root_dir=root)

    @staticmethod
    def _get_embedding_description_store(
        vector_store_type: str = VectorStoreType.LanceDB,
        config_args: Union[dict, None] = None,
    ):
        """Create and connect the vector store used for entity description embeddings.

        :param vector_store_type: backend type passed to ``VectorStoreFactory``.
        :param config_args: vector store settings (e.g. the config's
            ``embeddings.vector_store`` dict). Not modified; a copy is used.
        :return: the connected vector store.
        """
        # Copy so the caller's dict (typically the live GraphRAG config) is not mutated.
        config_args = dict(config_args) if config_args else {}

        # "query_collection_name" takes precedence over "collection_name".
        config_args["collection_name"] = config_args.get(
            "query_collection_name",
            config_args.get("collection_name", "description_embedding"),
        )

        description_embedding_store = VectorStoreFactory.get_vector_store(
            vector_store_type=vector_store_type, kwargs=config_args
        )

        description_embedding_store.connect(**config_args)
        return description_embedding_store

    def search_agent(self, community_level: int, response_type: str):
        """Build a local search engine from the indexer's parquet artifacts.

        Also sets ``self.description_embedding_store`` to the vector store the
        returned engine uses.

        :param community_level: community level passed to the entity/report readers.
        :param response_type: desired answer format (see class docstring).
        :return: the engine created by ``get_local_search_engine``.
        """
        data_path = Path(self.data_dir)

        final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
        final_community_reports = pd.read_parquet(
            data_path / "create_final_community_reports.parquet"
        )
        final_text_units = pd.read_parquet(
            data_path / "create_final_text_units.parquet"
        )
        final_relationships = pd.read_parquet(
            data_path / "create_final_relationships.parquet"
        )
        final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
        # Covariates (claims) are optional: the file is only read if it exists.
        final_covariates_path = data_path / "create_final_covariates.parquet"
        final_covariates = (
            pd.read_parquet(final_covariates_path)
            if final_covariates_path.exists()
            else None
        )

        vector_store_args = self.config.embeddings.vector_store or {}
        vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

        description_embedding_store = self._get_embedding_description_store(
            vector_store_type=vector_store_type,
            config_args=vector_store_args,
        )
        self.description_embedding_store = description_embedding_store

        entities = read_indexer_entities(final_nodes, final_entities, community_level)
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
        covariates = (
            read_indexer_covariates(final_covariates)
            if final_covariates is not None
            else []
        )

        return get_local_search_engine(
            self.config,
            reports=read_indexer_reports(
                final_community_reports, final_nodes, community_level
            ),
            text_units=read_indexer_text_units(final_text_units),
            entities=entities,
            relationships=read_indexer_relationships(final_relationships),
            covariates={"claims": covariates},
            description_embedding_store=description_embedding_store,
            response_type=response_type,
        )

    def run_search(self, query: str) -> str:
        """Run a local search and return the answer with citation markers removed.

        :param query: the question to ask.
        :return: the engine's response text, cleaned by ``remove_sources``.
        """
        result = self.agent.search(query=query)
        return self.remove_sources(result.response)

    @staticmethod
    def remove_sources(text: str) -> str:
        """Strip citation markers like ``[Data: Sources (82, 14, 42, 98)]`` from text.

        Every ``[Data: ...]`` span (up to the first closing bracket) is removed.
        """
        cleaned_text = re.sub(r'\[Data: [^]]+\]', '', text)
        return cleaned_text


# Example usage
# Resolve to an absolute path so the project root does not depend on how the
# script was launched (with a bare relative __file__, dirname() returns "").
BASEDIR = os.path.dirname(os.path.abspath(__file__))  # Set your base directory path here

# Built at import time so the index is loaded once and reused by every query
# (including by modules that import local_search_engine).
local_search_engine = LocalSearchEngine(data_dir=None, root_dir=BASEDIR)
if __name__ == '__main__':
    # Sample question: "How do I add a device?"
    local_res = local_search_engine.run_search(
        query="如何添加设备",
    )
    print(local_res)

搜索方式有 global 和 local 两种。如果想通过 API 调用 global 搜索,只需修改几个关键字即可。

posted @ 2024-07-30 16:43  一石数字欠我15w!!!  阅读(51)  评论(0)  编辑  收藏  举报