SuperAGI 核心源码分析

main.py

import requests
from fastapi import FastAPI, HTTPException, Depends, Request, status, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from fastapi_sqlalchemy import DBSessionMiddleware, db
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

import superagi
from datetime import timedelta, datetime
from superagi.agent.workflow_seed import IterationWorkflowSeed, AgentWorkflowSeed
from superagi.config.config import get_config
from superagi.controllers.agent import router as agent_router
from superagi.controllers.agent_execution import router as agent_execution_router
from superagi.controllers.agent_execution_feed import router as agent_execution_feed_router
from superagi.controllers.agent_execution_permission import router as agent_execution_permission_router
from superagi.controllers.agent_template import router as agent_template_router
from superagi.controllers.agent_workflow import router as agent_workflow_router
from superagi.controllers.budget import router as budget_router
from superagi.controllers.config import router as config_router
from superagi.controllers.organisation import router as organisation_router
from superagi.controllers.project import router as project_router
from superagi.controllers.twitter_oauth import router as twitter_oauth_router
from superagi.controllers.google_oauth import router as google_oauth_router
from superagi.controllers.resources import router as resources_router
from superagi.controllers.tool import router as tool_router
from superagi.controllers.tool_config import router as tool_config_router
from superagi.controllers.toolkit import router as toolkit_router
from superagi.controllers.user import router as user_router
from superagi.controllers.agent_execution_config import router as agent_execution_config
from superagi.controllers.analytics import router as analytics_router
from superagi.controllers.models_controller import router as models_controller_router
from superagi.controllers.knowledges import router as knowledges_router
from superagi.controllers.knowledge_configs import router as knowledge_configs_router
from superagi.controllers.vector_dbs import router as vector_dbs_router
from superagi.controllers.vector_db_indices import router as vector_db_indices_router
from superagi.controllers.marketplace_stats import router as marketplace_stats_router
from superagi.controllers.api_key import router as api_key_router
from superagi.controllers.api.agent import router as api_agent_router
from superagi.controllers.webhook import router as web_hook_router
from superagi.helper.tool_helper import register_toolkits, register_marketplace_toolkits
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
from superagi.llms.llm_model_factory import build_model_with_api_key
from superagi.llms.openai import OpenAi
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
from superagi.models.agent_template import AgentTemplate
from superagi.models.models_config import ModelsConfig
from superagi.models.organisation import Organisation
from superagi.models.types.login_request import LoginRequest
from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
from superagi.models.user import User
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
from urllib.parse import urlparse
from urllib.parse import quote  # percent-encodes DB credentials when building db_url below

app = FastAPI()

# Database connection settings, read from SuperAGI's config layer.
db_host = get_config('DB_HOST', 'super__postgres')
db_url = get_config('DB_URL', None)
db_username = get_config('DB_USERNAME')
db_password = get_config('DB_PASSWORD')
db_name = get_config('DB_NAME')
env = get_config('ENV', "DEV")

if db_url is None:
    if db_username is None:
        db_url = f'postgresql://{db_host}/{db_name}'
    else:
        # Credentials are percent-encoded so characters such as '@', ':', '/' or '#' in a
        # password cannot corrupt the URL; SQLAlchemy decodes them again when parsing.
        # str() keeps non-string config values (e.g. an all-digit YAML password) working.
        # A missing password is omitted instead of being rendered as the text "None".
        # NOTE(review): a password that was already percent-encoded in config would now be
        # encoded twice — confirm no deployment relies on pre-encoding.
        db_credentials = quote(str(db_username), safe="")
        if db_password is not None:
            db_credentials += ":" + quote(str(db_password), safe="")
        db_url = f'postgresql://{db_credentials}@{db_host}/{db_name}'
else:
    # Rebuild the URL from scheme, netloc and path only; any query string
    # (e.g. "?sslmode=require"), params or fragment present in DB_URL are dropped.
    # NOTE(review): confirm that dropping the query string is intentional.
    db_url = urlparse(db_url)
    db_url = db_url.scheme + "://" + db_url.netloc + db_url.path

engine = create_engine(db_url,
                       pool_size=20,  # Connections kept open in the pool
                       max_overflow=50,  # Extra connections allowed beyond pool_size under load
                       pool_timeout=30,  # Seconds to wait for a free pooled connection before failing
                       pool_recycle=1800,  # Replace connections older than this many seconds
                       pool_pre_ping=False,  # Checkout health checks are disabled (True would enable them)
                       )

# The middleware receives db_url rather than `engine`, so it presumably builds its own
# engine and connection pool — verify against fastapi_sqlalchemy if pool sizing matters.
app.add_middleware(DBSessionMiddleware, db_url=db_url)

# Configure CORS middleware.
# `origins` is the single list of allowed origins; "*" allows every origin, so
# narrow this list for production deployments.
origins = [
    # Add more origins if needed
    "*",  # Allow all origins
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Creating requrired tables -- Now handled using migrations
# DBBaseModel.metadata.create_all(bind=engine, checkfirst=True)
# DBBaseModel.metadata.drop_all(bind=engine,checkfirst=True)


# Mount every feature router under its URL prefix, in this order.
# Each router appears exactly once: including the same router twice under one prefix
# registers every route twice and yields duplicate OpenAPI operation IDs.
_ROUTER_PREFIXES = (
    (user_router, "/users"),
    (tool_router, "/tools"),
    (organisation_router, "/organisations"),
    (project_router, "/projects"),
    (budget_router, "/budgets"),
    (agent_router, "/agents"),
    (agent_execution_router, "/agentexecutions"),
    (agent_execution_feed_router, "/agentexecutionfeeds"),
    (agent_execution_permission_router, "/agentexecutionpermissions"),
    (resources_router, "/resources"),
    (config_router, "/configs"),
    (toolkit_router, "/toolkits"),
    (tool_config_router, "/tool_configs"),
    (agent_template_router, "/agent_templates"),
    (agent_workflow_router, "/agent_workflows"),
    (twitter_oauth_router, "/twitter"),
    (agent_execution_config, "/agent_executions_configs"),
    (analytics_router, "/analytics"),
    (models_controller_router, "/models_controller"),
    (google_oauth_router, "/google"),
    (knowledges_router, "/knowledges"),
    (knowledge_configs_router, "/knowledge_configs"),
    (vector_dbs_router, "/vector_dbs"),
    (vector_db_indices_router, "/vector_db_indices"),
    (marketplace_stats_router, "/marketplace"),
    (api_key_router, "/api-keys"),
    (api_agent_router, "/v1/agent"),
    (web_hook_router, "/webhook"),
)

for _router, _prefix in _ROUTER_PREFIXES:
    app.include_router(_router, prefix=_prefix)

# in production you can use Settings management
# from pydantic to get secret key from .env
class Settings(BaseModel):
    """Configuration object returned to fastapi_jwt_auth by the ``@AuthJWT.load_config`` callback."""
    # The secret is read once, when this class body executes at import time.
    # NOTE(review): if JWT_SECRET_KEY is unset, get_config presumably returns None — confirm
    # that token creation then fails loudly rather than signing with an empty key.
    authjwt_secret_key: str = superagi.config.config.get_config("JWT_SECRET_KEY")


def create_access_token(email, Authorize: AuthJWT = Depends()):
    """Create a signed JWT access token whose subject is ``email``.

    The token lifetime comes from the ``JWT_EXPIRY`` config value, in hours. It may
    be an int or a numeric string; when it is missing or blank, 200 hours is used.

    Args:
        email: Subject stored in the token (callers pass the user's email).
        Authorize: AuthJWT instance used to sign the token; callers in this file
            pass it explicitly.

    Returns:
        The encoded access token produced by ``Authorize.create_access_token``.

    Raises:
        ValueError: If ``JWT_EXPIRY`` is a non-blank, non-integer string.
    """
    expiry_time_hours = superagi.config.config.get_config("JWT_EXPIRY")
    if isinstance(expiry_time_hours, str):
        # Values from the environment arrive as strings; treat a blank one as "unset".
        expiry_time_hours = int(expiry_time_hours) if expiry_time_hours.strip() else None
    if expiry_time_hours is None:
        expiry_time_hours = 200
    expires = timedelta(hours=expiry_time_hours)
    return Authorize.create_access_token(subject=email, expires_time=expires)


# callback to get your configuration
@AuthJWT.load_config
def get_config():
    """Hand fastapi_jwt_auth its settings (the JWT secret key).

    NOTE: this definition rebinds the module-level name ``get_config`` that was
    imported from ``superagi.config.config``. Code after this point must call
    ``superagi.config.config.get_config`` explicitly, as the rest of this file does.
    """
    jwt_settings = Settings()
    return jwt_settings


# exception handler for authjwt
# in production, you can tweak performance using orjson response
@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
    """Turn a fastapi_jwt_auth error into a JSON response.

    The HTTP status is the exception's own ``status_code``; the body is
    ``{"detail": <exception message>}``.
    """
    payload = {"detail": exc.message}
    return JSONResponse(status_code=exc.status_code, content=payload)


def replace_old_iteration_workflows(session):
    """Re-point agent templates from legacy iteration workflows to agent workflows.

    Each template whose ``agent_workflow_id`` resolves to an ``IterationWorkflow``
    with one of the legacy names below is switched to the id of the matching
    ``AgentWorkflow``. Templates with other workflows are left untouched, as are
    templates whose replacement workflow cannot be found (logged as a warning).

    Args:
        session: SQLAlchemy session. Pending changes are committed once at the
            end, and only if at least one template was updated.
    """
    legacy_to_agent_workflow = {
        "Fixed Task Queue": "Fixed Task Workflow",
        "Maintain Task Queue": "Dynamic Task Workflow",
        "Don't Maintain Task Queue": "Goal Based Workflow",
        "Goal Based Agent": "Goal Based Workflow",
    }
    # Replacement workflows looked up so far, so each name is queried at most once.
    resolved_workflows = {}
    changed = False

    for template in session.query(AgentTemplate).all():
        # NOTE(review): agent_workflow_id is resolved as an IterationWorkflow id here, so a
        # template already pointing at an AgentWorkflow could match an unrelated
        # IterationWorkflow that has the same numeric id — confirm the id spaces cannot collide.
        iter_workflow = IterationWorkflow.find_by_id(session, template.agent_workflow_id)
        if not iter_workflow:
            continue
        target_name = legacy_to_agent_workflow.get(iter_workflow.name)
        if target_name is None:
            continue
        if target_name not in resolved_workflows:
            resolved_workflows[target_name] = AgentWorkflow.find_by_name(session, target_name)
        agent_workflow = resolved_workflows[target_name]
        if agent_workflow is None:
            # Guard against the lookup returning None instead of crashing on `.id`.
            logger.warning(f"Agent workflow '{target_name}' not found; template {template.id} left unchanged")
            continue
        template.agent_workflow_id = agent_workflow.id
        changed = True

    if changed:
        session.commit()

@app.on_event("startup")
async def startup_event():
    """Run one-off initialisation when the application starts.

    The steps run in this order:

    1. Register local toolkits for the organisation of the default user
       ``super6@agi.com``, when both the user and the organisation exist.
    2. Seed the iteration workflows and the agent workflows.
    3. Delete agent workflows whose names are not in the currently seeded set.
    4. Migrate agent templates off legacy iteration workflows.
    5. Register toolkits. When ``ENV`` is not "PROD", they are registered for every
       organisation. In "PROD", marketplace toolkits are registered for the
       marketplace organisation.

    The session is always closed, even when one of the steps raises.
    """
    logger.info("Running Startup tasks")
    Session = sessionmaker(bind=engine)
    session = Session()

    def register_toolkit_for_all_organisation():
        """Register the local toolkits for every organisation."""
        organizations = session.query(Organisation).all()
        for organization in organizations:
            register_toolkits(session, organization)
        logger.info("Successfully registered local toolkits for all Organisations!")

    def register_toolkit_for_master_organisation():
        """Register marketplace toolkits for MARKETPLACE_ORGANISATION_ID, if that organisation exists."""
        marketplace_organisation_id = superagi.config.config.get_config("MARKETPLACE_ORGANISATION_ID")
        marketplace_organisation = session.query(Organisation).filter(
            Organisation.id == marketplace_organisation_id).first()
        if marketplace_organisation is not None:
            register_marketplace_toolkits(session, marketplace_organisation)

    try:
        default_user = session.query(User).filter(User.email == "super6@agi.com").first()
        logger.info(default_user)
        if default_user is not None:
            organisation = session.query(Organisation).filter_by(id=default_user.organisation_id).first()
            logger.info(organisation)
            # Skip registration when the user's organisation row is missing.
            if organisation is not None:
                register_toolkits(session, organisation)

        IterationWorkflowSeed.build_single_step_agent(session)
        IterationWorkflowSeed.build_task_based_agents(session)
        IterationWorkflowSeed.build_action_based_agents(session)
        IterationWorkflowSeed.build_initialize_task_workflow(session)

        AgentWorkflowSeed.build_goal_based_agent(session)
        AgentWorkflowSeed.build_task_based_agent(session)
        AgentWorkflowSeed.build_fixed_task_based_agent(session)
        AgentWorkflowSeed.build_sales_workflow(session)
        AgentWorkflowSeed.build_recruitment_workflow(session)
        AgentWorkflowSeed.build_coding_workflow(session)

        # NOTE: remove old workflows. Need to remove this changes later
        current_workflow_names = ["Sales Engagement Workflow", "Recruitment Workflow", "SuperCoder",
                                  "Goal Based Workflow", "Dynamic Task Workflow", "Fixed Task Workflow"]
        stale_workflows = session.query(AgentWorkflow).filter(
            AgentWorkflow.name.not_in(current_workflow_names)).all()
        for workflow in stale_workflows:
            session.delete(workflow)
        # session.delete only marks rows for deletion; commit so it does not depend on a later commit.
        session.commit()

        # AgentWorkflowSeed.doc_search_and_code(session)
        # AgentWorkflowSeed.build_research_email_workflow(session)
        replace_old_iteration_workflows(session)

        if env != "PROD":
            register_toolkit_for_all_organisation()
        else:
            register_toolkit_for_master_organisation()
    finally:
        session.close()


@app.post('/login')
def login(request: LoginRequest, Authorize: AuthJWT = Depends()):
    """Login API for email and password based login.

    Returns:
        dict: ``{"access_token": <JWT>}`` on success.

    Raises:
        HTTPException: 401 in any of these cases:

            - no user has the given email;
            - either side's password is missing (users created by the GitHub
              callback are created without one);
            - the passwords differ.
    """
    import hmac  # constant-time comparison for the password check

    user: User = db.session.query(User).filter(User.email == request.email).first()

    # SECURITY NOTE(review): the password is compared directly with the stored value, so
    # passwords appear to be stored unhashed; they should be hashed (e.g. bcrypt/argon2).
    # compare_digest only removes the timing side channel; the None checks stop a missing
    # password from ever matching another missing password.
    credentials_ok = (
        user is not None
        and request.email == user.email
        and request.password is not None
        and user.password is not None
        and hmac.compare_digest(request.password.encode("utf-8"), user.password.encode("utf-8"))
    )
    if not credentials_ok:
        raise HTTPException(status_code=401, detail="Bad username or password")

    # subject identifier for who this token is for example id or username from database
    access_token = create_access_token(user.email, Authorize)
    return {"access_token": access_token}


# def get_jwt_from_payload(user_email: str,Authorize: AuthJWT = Depends()):
#     access_token = Authorize.create_access_token(subject=user_email)
#     return access_token

@app.get('/github-login')
def github_login():
    """Redirect the browser to GitHub's OAuth authorization page.

    The client id is read from ``GITHUB_CLIENT_ID`` (whitespace stripped). This is
    the same value ``/github-auth`` uses for the token exchange and that
    ``/get/github_client_id`` exposes. An empty id is sent when it is not configured.
    """
    github_client_id = (superagi.config.config.get_config("GITHUB_CLIENT_ID") or "").strip()
    return RedirectResponse(f'https://github.com/login/oauth/authorize?scope=user:email&client_id={github_client_id}')


@app.get('/github-auth')
def github_auth_handler(code: str = Query(...), Authorize: AuthJWT = Depends()):
    """GitHub login callback.

    The handler proceeds in four steps:

    1. Exchange the OAuth ``code`` for a GitHub access token.
    2. Fetch the GitHub profile.
    3. Find the matching local user, or create one.
    4. Redirect to ``FRONTEND_URL`` with a JWT and a ``first_time_login`` flag.

    Any failure while talking to GitHub redirects to https://superagi.com/ instead.
    This covers an HTTP error, a missing token, a network error or timeout, and a
    non-JSON body.
    """
    redirect_url_failure = "https://superagi.com/"
    request_timeout = 10  # seconds; keeps a stalled GitHub call from hanging the worker

    github_token_url = 'https://github.com/login/oauth/access_token'
    github_client_id = superagi.config.config.get_config("GITHUB_CLIENT_ID")
    github_client_secret = superagi.config.config.get_config("GITHUB_CLIENT_SECRET")
    frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")

    params = {
        'client_id': github_client_id,
        'client_secret': github_client_secret,
        'code': code
    }
    try:
        token_response = requests.post(github_token_url, params=params,
                                       headers={'Accept': 'application/json'},
                                       timeout=request_timeout)
        if not token_response.ok:
            return RedirectResponse(url=redirect_url_failure)
        # The token endpoint can report problems (e.g. a bad code) in the JSON body
        # rather than the status code, so check that a token was actually issued.
        access_token = token_response.json().get('access_token')
        if not access_token:
            return RedirectResponse(url=redirect_url_failure)

        user_response = requests.get('https://api.github.com/user',
                                     headers={'Authorization': f'Bearer {access_token}'},
                                     timeout=request_timeout)
        if not user_response.ok:
            return RedirectResponse(url=redirect_url_failure)
        user_data = user_response.json()
    except (requests.RequestException, ValueError):
        # Connection error, timeout, or a body that is not valid JSON.
        return RedirectResponse(url=redirect_url_failure)

    user_email = user_data.get("email")
    if user_email is None:
        # The profile carries no email: derive a stable placeholder from the login name.
        user_email = user_data["login"] + "@github.com"

    db_user: User = db.session.query(User).filter(User.email == user_email).first()
    first_time_login = db_user is None
    if first_time_login:
        new_user = User(name=user_data.get("name"), email=user_email)
        db.session.add(new_user)
        db.session.commit()

    jwt_token = create_access_token(user_email, Authorize)
    # SECURITY NOTE(review): the JWT travels in the URL query string, so it can end up in
    # browser history and proxy/server logs; consider a fragment or a short-lived exchange code.
    redirect_url_success = f"{frontend_url}?access_token={jwt_token}&first_time_login={first_time_login}"
    return RedirectResponse(url=redirect_url_success)


@app.get('/user')
def user(Authorize: AuthJWT = Depends()):
    """Return the JWT subject (the email used at login) of the current caller.

    ``Authorize.jwt_required()`` rejects requests that carry no valid access token.
    """
    Authorize.jwt_required()
    return {"user": Authorize.get_jwt_subject()}


@app.get("/validate-access-token")
async def root(Authorize: AuthJWT = Depends()):
    """API to validate access token.

    Returns:
        The ``User`` row whose email is the token's subject.

    Raises:
        HTTPException: 401 when the token is missing or invalid, or when no user
            matches its subject. Database errors propagate unchanged instead of
            being reported as authentication failures.
    """
    try:
        Authorize.jwt_required()
        current_user_email = Authorize.get_jwt_subject()
    except AuthJWTException:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from None

    current_user = db.session.query(User).filter(User.email == current_user_email).first()
    if current_user is None:
        # A well-formed token for a user that no longer exists is not a valid session.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    return current_user


@app.post("/validate-llm-api-key")
async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJWT = Depends()):
    """Check whether an API key works for the requested model source.

    A model client is built with ``build_model_with_api_key`` and asked to verify
    the key. If no client can be built for the source, the key counts as invalid.

    NOTE(review): ``Authorize`` is injected but ``jwt_required()`` is never called, so
    this endpoint is reachable without a token — confirm that is intended.
    """
    model = build_model_with_api_key(request.model_source, request.model_api_key)
    if model is None or not model.verify_access_key():
        return {"message": "Invalid API Key", "status": "failed"}
    return {"message": "Valid API Key", "status": "success"}


@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
    """API to validate Open AI Key.

    Sends a one-message chat completion using the supplied key.

    Raises:
        HTTPException: 401 when the call raises, or returns an error payload.

    NOTE(review): the key travels in the URL path, so it can be recorded in access
    logs; ``POST /validate-llm-api-key`` takes the key in the request body instead.
    """
    try:
        llm = OpenAi(api_key=open_ai_key)
        response = llm.chat_completion([{"role": "system", "content": "Hey!"}])
    except Exception:
        # Narrower than a bare except: KeyboardInterrupt / task cancellation still propagate.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key") from None
    # NOTE(review): assumes OpenAi.chat_completion may report failures by returning a dict
    # with an "error" key instead of raising — confirm against superagi.llms.openai.
    if isinstance(response, dict) and response.get("error"):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")


# Protected route: jwt_required() rejects callers without a valid access token.
@app.get("/hello/{name}")
async def say_hello(name: str, Authorize: AuthJWT = Depends()):
    """Greet ``name`` once the caller's JWT has been validated."""
    Authorize.jwt_required()
    greeting = f"Hello {name}"
    return {"message": greeting}

@app.get('/get/github_client_id')
def github_client_id():
    """Expose the configured GitHub OAuth client id, whitespace-stripped.

    A falsy config value (unset or empty) is returned as-is.
    """
    configured_id = superagi.config.config.get_config("GITHUB_CLIENT_ID")
    return {"github_client_id": configured_id.strip() if configured_id else configured_id}

# # __________________TO RUN____________________________
# # uvicorn main:app --host 0.0.0.0 --port 8001 --reload

  

uvicorn main:app --host 0.0.0.0 --port 8001 --reload 是什么意思?有什么作用?

这条命令是用来启动 FastAPI 应用程序的,具体使用了 Uvicorn 作为 ASGI 服务器。命令的各部分含义如下:

  • uvicorn main:app: 这部分指定了 Uvicorn 服务器要运行的 ASGI 应用。main 是 Python 文件名(不包括 .py 扩展名),app 是在该文件中定义的 FastAPI 应用实例的变量名。
  • --host 0.0.0.0: 这个选项指定服务器监听的地址。0.0.0.0 表示监听本机的所有网络接口,因此其他机器可以通过本机的任意 IP 地址(局域网或公网,取决于网络与防火墙配置)访问该服务;如果只允许本机访问,可改用 127.0.0.1。
  • --port 8001: 指定服务器监听的端口号为 8001。客户端需要通过这个端口与服务器通信。
  • --reload: 这个选项让 Uvicorn 监视代码文件,文件发生变化时自动重启服务器,适合开发阶段使用,可以即时反映代码修改而无需手动重启;由于会额外监视文件并重启进程,不建议在生产环境中使用。

总的来说,这条命令的作用是在本地或服务器上启动一个 FastAPI 应用,在所有网络接口的 8001 端口上监听,并开启自动重载功能,便于开发过程中的调试和测试。

 

整个代码功能分析:

这份代码是一个使用 FastAPI 框架构建的 Web 应用程序,主要功能和关键点如下:

  1. 数据库配置:

    • 使用 SQLAlchemy 创建数据库引擎。
    • 配置数据库连接,包括连接池大小、超时时间等。
  2. 中间件配置:

    • 添加了数据库会话中间件 DBSessionMiddleware。
    • 配置了跨源资源共享(CORS)中间件,允许所有来源的请求。
  3. 路由和控制器:

    • 应用程序包含多个路由器,每个路由器管理不同的功能模块,如用户、工具、组织、项目等。
    • 例如,user_router 处理与用户相关的请求,tool_router 处理与工具相关的请求。
  4. 身份验证和授权:

    • 使用 AuthJWT 处理 JWT 相关的操作,如创建访问令牌、加载配置等。
    • 提供了登录接口,验证用户的邮箱和密码,成功后返回 JWT 访问令牌(注意:密码是直接与数据库中存储的值比较的,没有经过哈希处理)。
  5. GitHub OAuth:

    • 提供了 GitHub 登录和认证的接口,处理 OAuth 流程。
  6. API Key 验证:

    • 提供了接口来验证不同的 API Key,如 LLM API Key 和 OpenAI API Key。
  7. 启动事件:

    • 在应用启动时执行一系列初始化任务,如注册工具包、替换旧的工作流等。
  8. 异常处理:

    • 配置了针对 AuthJWTException 的异常处理器,用于处理 JWT 认证过程中的异常。
  9. 其他 API:

    • 提供了其他一些实用的 API,如获取当前登录用户、验证访问令牌的有效性等。

这份代码展示了一个复杂的 FastAPI 应用的典型结构,包括数据库操作、路由管理、身份验证和第三方服务集成等多个方面。

 superAGI主界面:

 

以下是应用程序中定义的一些主要路由和控制器的功能及其对应的 URL 前缀:

1. 用户管理 1:

  • 功能:管理用户的创建、查询、更新和删除。
  • URL 前缀:/users

2. 工具管理 2:

  • 功能:管理工具的添加、查询、更新和删除。
  • URL 前缀:/tools

3. 组织管理 3:

  • 功能:管理组织的创建、查询、更新和删除。
  • URL 前缀:/organisations

4. 项目管理 4:

  • 功能:管理项目的创建、查询、更新和删除。
  • URL 前缀:/projects

5. 预算管理 5:

  • 功能:管理预算的创建、查询、更新和删除。
  • URL 前缀:/budgets

6. 代理管理 6:

  • 功能:管理代理的创建、查询、更新和删除。
  • URL 前缀:/agents

7. 代理执行 7:

  • 功能:管理代理执行的创建、查询、更新和删除。
  • URL 前缀:/agentexecutions

8. 代理执行反馈 8:

  • 功能:管理代理执行过程中的反馈信息。
  • URL 前缀:/agentexecutionfeeds

9. 代理执行权限 9:

  • 功能:管理代理执行的权限设置。
  • URL 前缀:/agentexecutionpermissions

10. 资源管理 10:

  • 功能:管理资源的添加、查询、更新和删除。
  • URL 前缀:/resources

11. 配置管理 11:

  • 功能:管理配置的添加、查询、更新和删除。
  • URL 前缀:/configs

12. 工具包管理 12:

  • 功能:管理工具包的添加、查询、更新和删除。
  • URL 前缀:/toolkits

13. 工具配置 13:

  • 功能:管理工具的配置信息。
  • URL 前缀:/tool_configs

14. 代理模板 14:

  • 功能:管理代理模板的创建、查询、更新和删除。
  • URL 前缀:/agent_templates

15. 代理工作流 15:

  • 功能:管理代理工作流的创建、查询、更新和删除。
  • URL 前缀:/agent_workflows

16. Twitter OAuth 16:

  • 功能:处理 Twitter OAuth 认证流程。
  • URL 前缀:/twitter

17. Google OAuth 17:

  • 功能:处理 Google OAuth 认证流程。
  • URL 前缀:/google

 

接下来分析核心功能。首先分析 /tools 路由(工具管理,对应 superagi/controllers/tool.py)的代码实现:

from datetime import datetime

from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel

from superagi.helper.auth import check_auth, get_user_organisation
from superagi.models.organisation import Organisation
from superagi.models.tool import Tool
from superagi.models.toolkit import Toolkit

# Router for the tool endpoints; presumably the tool_router that main.py mounts under "/tools".
router = APIRouter()


class ToolOut(BaseModel):
    # Response schema for a single tool, populated from a SQLAlchemy Tool row.
    # NOTE(review): the Tool model's `description` and `toolkit_id` columns are not exposed here.
    # NOTE(review): created_at/updated_at are not declared on Tool itself; presumably they
    # come from DBBaseModel — confirm.
    id: int
    name: str
    folder_name: str
    class_name: str
    file_name: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # pydantic v1-style option that allows reading values from ORM object attributes.
        orm_mode = True


class ToolIn(BaseModel):
    # Request body for creating (/add) or updating (/update/{tool_id}) a tool.
    # NOTE(review): `description` and `toolkit_id` cannot be set through this schema.
    name: str
    folder_name: str
    class_name: str
    file_name: str

    class Config:
        # Only matters when building from ORM objects; harmless for request bodies.
        orm_mode = True

# CRUD Operations
@router.post("/add", response_model=ToolOut, status_code=201)
def create_tool(
        tool: ToolIn,
        Authorize: AuthJWT = Depends(check_auth),
):
    """
    Insert a new tool row built from the request body.

    Args:
        tool (ToolIn): Name, folder name, class name and file name of the tool.

    Returns:
        Tool: The newly committed row, serialised through ``ToolOut``.

    Note:
        Errors are not translated here; a failed commit propagates to the framework.
        ``description`` and ``toolkit_id`` are left unset, so the new tool belongs to
        no toolkit.
    """
    copied_fields = ("name", "folder_name", "class_name", "file_name")
    new_tool = Tool(**{field: getattr(tool, field) for field in copied_fields})
    session = db.session
    session.add(new_tool)
    session.commit()
    return new_tool


@router.get("/get/{tool_id}", response_model=ToolOut)
def get_tool(
        tool_id: int,
        Authorize: AuthJWT = Depends(check_auth),
):
    """
    Fetch one tool by primary key.

    Args:
        tool_id (int): ID of the tool to return.

    Returns:
        Tool: The matching row, serialised through ``ToolOut``.

    Raises:
        HTTPException (status_code=404): When no tool has that ID.

    NOTE(review): the lookup is not scoped to the caller's organisation.
    """
    match = db.session.query(Tool).filter(Tool.id == tool_id).first()
    if not match:
        raise HTTPException(status_code=404, detail="Tool not found")
    return match


@router.get("/list")
def get_tools(
        organisation: Organisation = Depends(get_user_organisation)):
    """Get all tools that belong to the caller's organisation.

    Tools are reached through the organisation's toolkits. A single joined query
    replaces one query per toolkit. Results are ordered by toolkit id, then tool id,
    so tools stay grouped by toolkit.

    Returns:
        list[Tool]: The organisation's tools (empty when it has no toolkits).
    """
    return (
        db.session.query(Tool)
        .join(Toolkit, Tool.toolkit_id == Toolkit.id)
        .filter(Toolkit.organisation_id == organisation.id)
        .order_by(Toolkit.id, Tool.id)
        .all()
    )


@router.put("/update/{tool_id}", response_model=ToolOut)
def update_tool(
        tool_id: int,
        tool: ToolIn,
        Authorize: AuthJWT = Depends(check_auth),
):
    """
    Overwrite the editable fields of an existing tool.

    Args:
        tool_id (int): ID of the tool to change.
        tool (ToolIn): New name, folder name, class name and file name.

    Returns:
        Tool: The updated row, serialised through ``ToolOut``.

    Raises:
        HTTPException (status_code=404): When no tool has that ID.

    NOTE(review): the update is not scoped to the caller's organisation.
    """
    stored = db.session.query(Tool).filter(Tool.id == tool_id).first()
    if not stored:
        raise HTTPException(status_code=404, detail="Tool not found")

    # Copy each editable field from the request body onto the stored row.
    for field in ("name", "folder_name", "class_name", "file_name"):
        setattr(stored, field, getattr(tool, field))

    session = db.session
    session.add(stored)
    session.commit()
    return stored

  

代码中定义了一个 FastAPI 路由器 router(在 main.py 中以 tool_router 的名字导入,并挂载到 /tools 前缀下),用于工具(Tool)的添加、查询、列表和更新;代码中没有删除工具的接口。其中 /list 接口按当前用户所属组织的工具包列出工具,下面不单独展开。以下是其余各路由的关键实现和要点:

1. 添加工具 (/add):

  • 方法: POST
  • 功能: 创建一个新的工具。
  • 输入: 接收一个 ToolIn 类型的对象,包含工具的名称、文件夹名、类名和文件名。
  • 处理: 创建一个 Tool 实例并将其添加到数据库中。
  • 返回: 返回创建的工具的详细信息,使用 ToolOut 模型序列化。
  • 异常: 文档字符串中声明创建失败时返回 HTTP 400,但代码并没有显式捕获数据库异常;提交失败时异常会直接向上抛出,由框架返回 500 错误。
   @router.post("/add"response_model=ToolOut, status_code=201)
   def create_tool(tool: ToolIn, Authorize: AuthJWT = Depends(check_auth)):
       db_tool = Tool(
           name=tool.name,
           folder_name=tool.folder_name,
           class_name=tool.class_name,
           file_name=tool.file_name,
       )
       db.session.add(db_tool)
       db.session.commit()
       return db_tool

2. 查询工具 (/get/{tool_id}):

  • 方法: GET
  • 功能: 根据工具 ID 获取工具的详细信息。
  • 输入: 工具的 ID。
  • 处理: 查询数据库中的对应工具。
  • 返回: 返回工具的详细信息,使用 ToolOut 模型序列化。
  • 异常: 如果工具不存在,会抛出 HTTP 404 错误。
   @router.get("/get/{tool_id}"response_model=ToolOut)
   def get_tool(tool_idintAuthorize: AuthJWT = Depends(check_auth)):
       db_tool = db.session.query(Tool).filter(Tool.id == tool_id).first()
       if not db_tool:
           raise HTTPException(status_code=404detail="Tool not found")
       return db_tool
 

3. 更新工具 (/update/{tool_id}):

  • 方法: PUT
  • 功能: 更新指定 ID 的工具。
  • 输入: 工具的 ID 和新的工具信息(ToolIn 类型)。
  • 处理: 查找并更新工具的信息。
  • 返回: 返回更新后的工具信息,使用 ToolOut 模型序列化。
  • 异常: 如果工具不存在,会抛出 HTTP 404 错误。
   @router.put("/update/{tool_id}"response_model=ToolOut)
   def update_tool(tool_idinttool: ToolIn, Authorize: AuthJWT = Depends(check_auth)):
       db_tool = db.session.query(Tool).filter(Tool.id == tool_id).first()
       if not db_tool:
           raise HTTPException(status_code=404detail="Tool not found")
       db_tool.name = tool.name
       db_tool.folder_name = tool.folder_name
       db_tool.class_name = tool.class_name
       db_tool.file_name = tool.file_name
       db.session.add(db_tool)
       db.session.commit()
       return db_tool
 

4. 删除工具:

  • 代码中没有直接提供删除工具的路由,但通常这会通过一个 DELETE 请求实现,类似于更新和查询,只不过是从数据库中移除记录。

这些路由提供了创建、读取(单个查询与列表)和更新功能,但没有删除接口;如需完整的 CRUD,可以增加一个 DELETE 路由(模型层已有 Tool.delete_tool 可供参考)。

 

from sqlalchemy import Column, Integer, String

from superagi.models.base_model import DBBaseModel


# from pydantic import BaseModel

class Tool(DBBaseModel):
    """
    Model representing a tool.

    Attributes:
        id (Integer): The primary key of the tool.
        name (String): The name of the tool; ``add_or_update`` looks it up together with ``toolkit_id``.
        description (String): Description of the tool.
        folder_name (String): The folder name of the tool.
        class_name (String): The class name of the tool.
        file_name (String): The file name of the tool.
        toolkit_id (Integer): Id of the toolkit the tool belongs to.
    """

    __tablename__ = 'tools'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)
    description = Column(String)
    folder_name = Column(String)
    class_name = Column(String)
    file_name = Column(String)
    toolkit_id = Column(Integer)

    def __repr__(self):
        """
        Returns a string representation of the Tool object.

        Returns:
            str: String representation of the Tool object.
        """
        # String fields are quoted and all fields are comma-separated.
        return (f"Tool(id={self.id}, name='{self.name}', description='{self.description}', "
                f"folder_name='{self.folder_name}', file_name='{self.file_name}', "
                f"class_name='{self.class_name}', toolkit_id={self.toolkit_id})")

    def to_dict(self):
        """
        Convert the Tool instance to a dictionary.

        Returns:
            dict: A dictionary representation of the Tool instance.
        """
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "folder_name": self.folder_name,
            "class_name": self.class_name,
            "file_name": self.file_name,
            "toolkit_id": self.toolkit_id
        }

    @staticmethod
    def add_or_update(session, tool_name: str, description: str, folder_name: str, class_name: str, file_name: str,
                      toolkit_id: int):
        """Create a tool, or update the tool with the same name inside the same toolkit.

        Args:
            session: SQLAlchemy session; the change is committed before returning.
            tool_name: Name of the tool (matched together with ``toolkit_id``).
            description, folder_name, class_name, file_name: Values to store.
            toolkit_id: Id of the owning toolkit.

        Returns:
            Tool: The created or updated row.
        """
        tool = session.query(Tool).filter_by(name=tool_name,
                                             toolkit_id=toolkit_id).first()
        if tool is not None:
            # Update the attributes of the existing tool record
            tool.folder_name = folder_name
            tool.class_name = class_name
            tool.file_name = file_name
            tool.description = description
        else:
            # Create a new tool record
            tool = Tool(name=tool_name, description=description, folder_name=folder_name, class_name=class_name,
                        file_name=file_name,
                        toolkit_id=toolkit_id)
            session.add(tool)

        # commit() flushes pending changes itself; a flush afterwards has nothing to do.
        session.commit()
        return tool

    @staticmethod
    def delete_tool(session, tool_name, toolkit_id=None):
        """Delete the first tool named ``tool_name`` and commit; do nothing if none matches.

        Tools are matched by name together with toolkit elsewhere (``add_or_update``),
        so the same name may exist in several toolkits. Pass ``toolkit_id`` to restrict
        the deletion to one toolkit; when it is omitted, the first match in any
        toolkit is deleted.
        """
        query = session.query(Tool).filter(Tool.name == tool_name)
        if toolkit_id is not None:
            query = query.filter(Tool.toolkit_id == toolkit_id)
        tool = query.first()
        if tool:
            session.delete(tool)
            session.commit()

    @classmethod
    def convert_tool_names_to_ids(cls, db, tool_names):
        """
        Converts a list of tool names to their corresponding IDs.

        Args:
            db: Object exposing the database session as ``db.session``.
            tool_names (list): List of tool names.

        Returns:
            list: IDs of all tools whose name is in ``tool_names``, in database order
            (not input order). A name present in several toolkits yields several IDs.
        """
        if not tool_names:
            return []
        tools = db.session.query(Tool).filter(Tool.name.in_(tool_names)).all()
        return [tool.id for tool in tools]

    @classmethod
    def convert_tool_ids_to_names(cls, db, tool_ids):
        """
        Converts a list of tool IDs to their corresponding names.

        Args:
            db: Object exposing the database session as ``db.session``.
            tool_ids (list): List of tool IDs.

        Returns:
            list: Names (as str) of the tools whose ID is in ``tool_ids``, in database order.
        """
        if not tool_ids:
            return []
        tools = db.session.query(Tool).filter(Tool.id.in_(tool_ids)).all()
        return [str(tool.name) for tool in tools]

    @classmethod
    def get_invalid_tools(cls, tool_ids, session):
        """Return the ids in ``tool_ids`` that match no tool, preserving input order and duplicates."""
        # Session.get is the SQLAlchemy 1.4+ replacement for the legacy Query.get.
        return [tool_id for tool_id in tool_ids if session.get(Tool, tool_id) is None]

    @classmethod
    def get_toolkit_tools(cls, session, toolkit_id : int):
        """Return every tool whose ``toolkit_id`` equals ``toolkit_id``."""
        return session.query(Tool).filter(Tool.toolkit_id == toolkit_id).all()

  

 superAGI支持的工具:

 

工具示例:

 

继续分析下Tool类的实现。

Tool 类是一个 SQLAlchemy 模型,用于表示数据库中的 tools 表。以下是该类的关键实现原理和要点:

1. 属性定义:

  • 类中定义了多个属性,每个属性都映射到数据库表的一个列。
  • 使用 Column 类定义列的类型(如 Integer, String)和特性(如 primary_key, autoincrement)。
   id = Column(Integer, primary_key=Trueautoincrement=True)
   name = Column(String)
   description = Column(String)
   folder_name = Column(String)
   class_name = Column(String)
   file_name = Column(String)
   toolkit_id = Column(Integer)
 

2. 表名定义:

  • __tablename__ 属性用于指定这个模型映射到数据库中的表名。
   __tablename__ = 'tools'  # database table this model maps to

3. 字符串表示 (__repr__ 方法):

  • 重写 __repr__ 方法以提供类的实例的友好字符串表示,便于调试和日志记录。
   def __repr__(self):
       """Readable representation listing all seven column values, for logs and debugging."""
       return f"Tool(id={self.id}, name='{self.name}', description='{self.description}', folder_name='{self.folder_name}', class_name='{self.class_name}', file_name='{self.file_name}', toolkit_id={self.toolkit_id})"
 

4. 字典转换 (to_dict 方法):

  • 提供 to_dict 方法将模型实例转换为字典,这在处理 HTTP 响应时非常有用。
   def to_dict(self):
       return {
           "id"self.id,
           "name"self.name,
           "description"self.description,
           "folder_name"self.folder_name,
           "class_name"self.class_name,
           "file_name"self.file_name,
           "toolkit_id"self.toolkit_id
       }
 

5. 添加或更新工具 (add_or_update 静态方法):

  • 这个方法首先尝试查询数据库中是否存在具有相同名称和 toolkit_id 的工具。
  • 如果存在,则更新该工具的信息;如果不存在,则创建一个新的工具记录并添加到数据库。
   @staticmethod
   def add_or_update(sessiontool_namedescriptionfolder_nameclass_namefile_nametoolkit_id):
       tool = session.query(Tool).filter_by(name=tool_name, toolkit_id=toolkit_id).first()
       if tool is not None:
           tool.folder_name = folder_name
           tool.class_name = class_name
           tool.file_name = file_name
           tool.description = description
       else:
           tool = Tool(name=tool_name, description=description, folder_name=folder_name, class_name=class_name, file_name=file_name, toolkit_id=toolkit_id)
           session.add(tool)
       session.commit()
       session.flush()
       return tool
 

6. 删除工具 (delete_tool 静态方法):

  • 提供一个方法来删除指定名称的工具。
   @staticmethod
   def delete_tool(sessiontool_name):
       tool = session.query(Tool).filter(Tool.name == tool_name).first()
       if tool:
           session.delete(tool)
           session.commit()
           session.flush()

这些实现细节展示了 Tool 类如何与数据库交互,包括如何创建、更新、查询和删除记录。这是构建 RESTful API 的基础,允许通过 HTTP 请求管理数据库记录。

 

tool的具体应用:

class AgentWorkflowSeed:
    """
    Seeds the predefined agent workflows shown here: "Sales Engagement Workflow" and
    "Recruitment Workflow".

    Each builder creates (or looks up) the workflow, creates its steps, wires the
    transitions between the steps, and commits.
    """

    @classmethod
    def build_sales_workflow(cls, session):
        """
        Create or look up the "Sales Engagement Workflow", build its steps, and commit.

        Step graph wired below:
            step2 ListFile (TRIGGER) -> step3 ReadFile -> step4 TASK_QUEUE
            step4 -> COMPLETE once the queue is exhausted, otherwise -> step5 GoogleSearch
            step5 -> step6 WAIT_FOR_PERMISSION; "YES" -> step7 SearxSearch, "NO" -> back to step5
            step7 -> step8 SendEmail -> step9 wait 2*60 s -> step10 ReadEmail -> step11 SendEmail
            step11 -> back to step4 (next queued item)

        Args:
            session: SQLAlchemy session used for every lookup/insert.
        """
        agent_workflow = AgentWorkflow.find_or_create_by_name(session, "Sales Engagement Workflow",
                                                              "Sales Engagement Workflow")
        # Step ids have the form "<workflow id>_stepN". The find_or_create_* helpers presumably
        # reuse an existing step with the same id, so re-seeding should not duplicate steps.
        # TODO: confirm this against AgentWorkflowStep.
        #
        # step1 (ApolloSearchTool lead search as TRIGGER) is disabled; step2 is the trigger instead.
        # step1 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
        #                                                             str(agent_workflow.id) + "_step1",
        #                                                             ApolloSearchTool().name,
        #                                                             "Search for leads based on the given goals",
        #                                                             step_type="TRIGGER")
        #
        step2 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step2",
                                                                    ListFileTool().name,
                                                                    "list the files",
                                                                    step_type="TRIGGER")

        step3 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step3",
                                                                    ReadFileTool().name,
                                                                    "Read the leads from the file")

        # task queue ends when the elements gets over
        # (the completion_prompt asks the model to emit a JSON array that becomes the queue)
        step4 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step4",
                                                                    "TASK_QUEUE",
                                                                    "Break the above response array of items",
                                                                    completion_prompt="Get array of items from the above response. Array should suitable utilization of JSON.parse().")

        step5 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step5",
                                                                    GoogleSearchTool().name,
                                                                    "Search about the company in which the lead is working")

        # Permission gate: its "YES"/"NO" transitions are wired below.
        step6 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step6",
                                                                    "WAIT_FOR_PERMISSION",
                                                                    "Email will be based on this content. Do you want send the email?")

        step7 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step7",
                                                                    SearxSearchTool().name,
                                                                    "Search about the company given in the high-end goal only")

        step8 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step8",
                                                                    SendEmailTool().name,
                                                                    "Customize the Email according to the company information in the mail")

        # Wait step; the last argument is 2*60 (presumably seconds, matching "Wait for 2 minutes").
        step9 = AgentWorkflowStep.find_or_create_wait_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step9",
                                                                    "Wait for 2 minutes",
                                                                    2*60)

        # NOTE(review): this instruction hard-codes one specific sender address into seeded data.
        # Confirm it is intended rather than a leftover from development.
        step10 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step10",
                                                                     ReadEmailTool().name,
                                                                     "Read the email from adarshdeepmurari@gmail.com")

        step11 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step11",
                                                                    SendEmailTool().name,
                                                                    "Customize the Email according to the company information in the mail")

        # Transitions. -1 with "COMPLETE" appears to mean "no next step / workflow finished".
        # AgentWorkflowStep.add_next_workflow_step(session, step1.id, step2.id)
        AgentWorkflowStep.add_next_workflow_step(session, step2.id, step3.id)
        AgentWorkflowStep.add_next_workflow_step(session, step3.id, step4.id)
        AgentWorkflowStep.add_next_workflow_step(session, step4.id, -1, "COMPLETE")
        AgentWorkflowStep.add_next_workflow_step(session, step4.id, step5.id)
        AgentWorkflowStep.add_next_workflow_step(session, step5.id, step6.id)
        AgentWorkflowStep.add_next_workflow_step(session, step6.id, step7.id, "YES")
        AgentWorkflowStep.add_next_workflow_step(session, step6.id, step5.id, "NO")
        AgentWorkflowStep.add_next_workflow_step(session, step7.id, step8.id)
        AgentWorkflowStep.add_next_workflow_step(session, step8.id, step9.id)
        AgentWorkflowStep.add_next_workflow_step(session, step9.id, step10.id)
        AgentWorkflowStep.add_next_workflow_step(session, step10.id, step11.id)
        # Loop back to the task queue so the next queued item is processed.
        AgentWorkflowStep.add_next_workflow_step(session, step11.id, step4.id)
        session.commit()

    @classmethod
    def build_recruitment_workflow(cls, session):
        """
        Create or look up the "Recruitment Workflow", build its steps, and commit.

        Step graph wired below:
            step1 ListFile (TRIGGER) -> step2 TASK_QUEUE (job_description file skipped)
            step2 -> COMPLETE once the queue is exhausted, otherwise -> step3 ReadFile (resume check)
            step3 "YES" -> step4 acceptance email, "NO" -> step5 rejection email
            step4 / step5 -> back to step2 (next resume)

        Args:
            session: SQLAlchemy session used for every lookup/insert.
        """
        agent_workflow = AgentWorkflow.find_or_create_by_name(session, "Recruitment Workflow",
                                                              "Recruitment Workflow")
        step1 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step1",
                                                                    ListFileTool().name,
                                                                    "List the files from the resource manager",
                                                                    step_type="TRIGGER")

        # task queue ends when the elements gets over
        step2 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step2",
                                                                    "TASK_QUEUE",
                                                                    "Break the above response array of items",
                                                                    completion_prompt="Get array of items from the above response. Array should suitable utilization of JSON.parse(). Skip job_description file from list.")

        # The second instruction string is passed positionally. It is presumably the step's
        # output instruction, which drives the YES/NO branch below.
        # TODO: confirm the positional parameter against AgentWorkflowStep's signature.
        step3 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step3",
                                                                    ReadFileTool().name,
                                                                    "Read the resume from above input",
                                                                    "Check if the resume matches High-Level GOAL")

        step4 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step4",
                                                                    SendEmailTool().name,
                                                                    "Write a custom acceptance Email to the candidates")

        step5 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
                                                                    str(agent_workflow.id) + "_step5",
                                                                    SendEmailTool().name,
                                                                    "Write a custom Reject Email to the candidates")

        AgentWorkflowStep.add_next_workflow_step(session, step1.id, step2.id)
        AgentWorkflowStep.add_next_workflow_step(session, step2.id, step3.id)
        AgentWorkflowStep.add_next_workflow_step(session, step2.id, -1, "COMPLETE")
        AgentWorkflowStep.add_next_workflow_step(session, step3.id, step4.id, "YES")
        AgentWorkflowStep.add_next_workflow_step(session, step3.id, step5.id, "NO")
        # Both email branches return to the queue for the next resume.
        AgentWorkflowStep.add_next_workflow_step(session, step4.id, step2.id)
        AgentWorkflowStep.add_next_workflow_step(session, step5.id, step2.id)
        session.commit()

 

AgentWorkflowSeed 类通过定义一系列步骤以及步骤之间的转移逻辑,实现了多种类型的工作流(上文只摘录了销售和招聘两个)。这些工作流可以应用于不同的业务场景,如销售、招聘和编码等。每个工作流都是把工具和操作关联到各个步骤上构建出来的,因此整个系统具有很高的灵活性和可配置性。

 

另外,多个工具组合在一起形成工具包(Toolkit),其模型定义如下:

import json

import requests
from sqlalchemy import Column, Integer, String, Boolean

from superagi.models.base_model import DBBaseModel
from superagi.models.tool import Tool

# Base URL of the SuperAGI marketplace API, prefixed to the /toolkits/marketplace/... paths
# requested by Toolkit.fetch_marketplace_list and Toolkit.fetch_marketplace_detail.
marketplace_url = "https://app.superagi.com/api"
# Alternative for a marketplace served on localhost:8001 (swap with the line above to use it).
# marketplace_url = "http://localhost:8001"


class Toolkit(DBBaseModel):
    """
        ToolKit - Used to group tools together
        Attributes:
            id(int) : id of the tool kit
            name(str) : name of the tool kit
            description(str) : description of the tool kit
            show_toolkit(boolean) : indicates whether the tool kit should be shown based on the count of tools in the toolkit
            organisation_id(int) : id of the organisation the tool kit belongs to
            tool_code_link(str) : stores Github link for toolkit
    """
    __tablename__ = 'toolkits'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    description = Column(String)
    show_toolkit = Column(Boolean)
    organisation_id = Column(Integer)
    tool_code_link = Column(String)

    def __repr__(self):
        """Return a debug string with the toolkit's main columns."""
        return (f"ToolKit(id={self.id}, name='{self.name}', description='{self.description}', "
                f"show_toolkit={self.show_toolkit}, "
                f"organisation_id={self.organisation_id})")

    def to_dict(self):
        """Return the toolkit as a plain dict (tool_code_link is not included)."""
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'show_toolkit': self.show_toolkit,
            'organisation_id': self.organisation_id
        }

    def to_json(self):
        """Serialize ``to_dict()`` to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_data):
        """
        Build a Toolkit (not added to any session) from a JSON string.

        ``tool_code_link`` is optional in the payload because ``to_json()`` does not emit it.
        All other keys are required.
        """
        data = json.loads(json_data)
        return cls(
            id=data['id'],
            name=data['name'],
            description=data['description'],
            show_toolkit=data['show_toolkit'],
            organisation_id=data['organisation_id'],
            tool_code_link=data.get('tool_code_link'),
        )

    @staticmethod
    def add_or_update(session, name, description, show_toolkit, organisation_id, tool_code_link):
        """
        Create the toolkit, or update it if one with the same name exists in the organisation.

        Args:
            session: SQLAlchemy session used for the lookup and the commit.
            name (str): Toolkit name (lookup key together with organisation_id).
            description (str): Toolkit description.
            show_toolkit (bool): Visibility flag.
            organisation_id (int): Owning organisation id.
            tool_code_link (str): Link to the toolkit's code.

        Returns:
            Toolkit: The created or updated toolkit, after the session has been committed.
        """
        toolkit = session.query(Toolkit).filter(Toolkit.name == name,
                                                Toolkit.organisation_id == organisation_id).first()

        if toolkit:
            # Update the existing toolkit
            toolkit.name = name
            toolkit.description = description
            toolkit.show_toolkit = show_toolkit
            toolkit.organisation_id = organisation_id
            toolkit.tool_code_link = tool_code_link
        else:
            toolkit = Toolkit(
                name=name,
                description=description,
                show_toolkit=show_toolkit,
                organisation_id=organisation_id,
                tool_code_link=tool_code_link
            )
            session.add(toolkit)

        session.commit()
        session.flush()
        return toolkit

    @classmethod
    def fetch_marketplace_list(cls, page):
        """
        Fetch one page of the marketplace toolkit listing.

        Returns:
            The decoded JSON body on HTTP 200, otherwise an empty list.
        """
        headers = {'Content-Type': 'application/json'}
        response = requests.get(
            marketplace_url + f"/toolkits/marketplace/list/{str(page)}",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        return []

    @classmethod
    def fetch_marketplace_detail(cls, search_str, toolkit_name):
        """
        Fetch a single toolkit's details from the marketplace.

        Both values are percent-encoded as whole URL path segments. Characters such as
        "/", "?", "#" or "%" therefore cannot escape their segment. Spaces still become "%20".

        Returns:
            The decoded JSON body on HTTP 200, otherwise None.
        """
        from urllib.parse import quote

        headers = {'Content-Type': 'application/json'}
        search_str = quote(search_str, safe='')
        toolkit_name = quote(toolkit_name, safe='')
        response = requests.get(
            marketplace_url + f"/toolkits/marketplace/{search_str}/{toolkit_name}",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        return None

    @staticmethod
    def get_toolkit_from_name(session, toolkit_name, organisation):
        """Return the organisation's toolkit with the given name, or None if there is none."""
        return session.query(Toolkit).filter_by(name=toolkit_name, organisation_id=organisation.id).first()

    @classmethod
    def get_toolkit_installed_details(cls, session, marketplace_toolkits, organisation):
        """
        Mark each marketplace toolkit dict with ``is_installed``.

        A toolkit counts as installed when the organisation already has a toolkit with the
        same name. The dicts are modified in place, and the same list is returned.
        """
        installed_toolkits = session.query(Toolkit).filter(Toolkit.organisation_id == organisation.id).all()
        # Build the name set once instead of rebuilding a list for every marketplace entry.
        installed_names = {installed_toolkit.name for installed_toolkit in installed_toolkits}
        for toolkit in marketplace_toolkits:
            toolkit["is_installed"] = toolkit['name'] in installed_names
        return marketplace_toolkits

    @classmethod
    def fetch_tool_ids_from_toolkit(cls, session, toolkit_ids):
        """
        Return the ids of all tools belonging to the given toolkits.

        Ids are ordered by the toolkit order in ``toolkit_ids``. None or an empty
        input yields an empty list.
        """
        agent_toolkit_tools = []
        for toolkit_id in toolkit_ids or []:
            toolkit_tools = session.query(Tool).filter(Tool.toolkit_id == toolkit_id).all()
            # The rows were just loaded, so their ids can be used directly (no per-tool re-query).
            agent_toolkit_tools.extend(tool.id for tool in toolkit_tools)
        return agent_toolkit_tools

    @classmethod
    def get_tool_and_toolkit_arr(cls, session, organisation_id: int, agent_config_tools_arr: list):
        """
        Resolve the toolkit/tool names of an agent config into tool ids.

        Args:
            session: SQLAlchemy session.
            organisation_id (int): Organisation whose toolkits are searched.
            agent_config_tools_arr (list): Dicts with a "name" (toolkit name) and an optional
                "tools" list of tool names. When "tools" is missing or empty, every tool of
                the toolkit is selected.

        Returns:
            list: Unique tool ids (order not guaranteed).

        Raises:
            Exception: If a named toolkit or tool does not exist.
        """
        tool_ids = set()
        for tool_obj in agent_config_tools_arr:
            toolkit = session.query(Toolkit).filter(Toolkit.name == tool_obj["name"].strip(),
                                                    Toolkit.organisation_id == organisation_id).first()
            if toolkit is None:
                raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.")
            if tool_obj.get("tools"):
                for tool_name_str in tool_obj["tools"]:
                    tool_db_obj = session.query(Tool).filter(Tool.name == tool_name_str.strip(),
                                                             Tool.toolkit_id == toolkit.id).first()
                    if tool_db_obj is None:
                        raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.")
                    tool_ids.add(tool_db_obj.id)
            else:
                tool_ids.update(tool_db_obj.id for tool_db_obj in Tool.get_toolkit_tools(session, toolkit.id))
        return list(tool_ids)

  

Toolkit 类在应用程序中扮演着管理工具包(toolkits)的核心角色。具体来说,它的作用包括:

1. 数据模型定义:

  • Toolkit 类定义了工具包的数据结构,包括工具包的名称、描述、是否显示、所属组织的 ID 和工具代码链接等属性。这些属性映射到数据库中的 toolkits 表的列。

2. 数据库交互:

  • 该类通过 SQLAlchemy 提供的 ORM 功能读写数据库中的工具包数据。例如,可以创建新的工具包、更新现有工具包的信息、按名称和组织查询工具包,以及查询某些工具包下的工具 ID。注意:上面摘录的代码中并没有删除工具包的方法。

3. 业务逻辑封装:

  • 类中的方法如 add_or_update 封装了业务逻辑,例如检查同一组织下是否已存在同名的工具包,如果存在则更新,不存在则创建新的工具包。这样的封装使得业务逻辑集中管理,便于维护和修改。

4. 数据格式转换:

  • 提供了 to_dict 和 to_json 方法,允许将工具包对象转换为字典或 JSON 格式,这对于生成 HTTP 响应非常有用。同时,from_json 方法允许从 JSON 格式数据创建工具包对象,便于处理来自前端的数据。

5. 接口支持:

  • 通过定义如 fetch_marketplace_list 和 fetch_marketplace_detail 等类方法,Toolkit 类还支持与外部系统(如市场平台)的接口交互,获取市场上的工具包列表或详细信息。

总之,Toolkit 类是应用程序中管理工具包的关键组件,它不仅负责数据的持久化和检索,还处理与工具包相关的业务逻辑,提供数据的格式化输出,以及支持与外部系统的交互。这使得 Toolkit 成为构建和维护工具包功能的基础。

其中,marketplace的工具:

 一个 toolkit 里包含多个工具,如下:

 

根据 toolkit 的功能,可以使用它的模板来创建一个 agent:

按照模板创建agent:

 

toolkit_router的功能和tools功能差不多,都是数据库CRUD操作。
 

补充一下 toolconfig 的作用:它针对特定的 toolkit 进行配置(键值对设置)。下面的接口只会更新已经存在的配置项,并且每个值在写入前都会经过 encrypt_data 加密,而不只是敏感数据:

@router.post("/add/{toolkit_name}", status_code=201)
def update_tool_config(toolkit_name: str, configs: list, organisation: Organisation = Depends(get_user_organisation)):
    """
    Update tool configurations for a specific tool kit.

    Only configuration keys that already exist for the toolkit are updated. Entries with
    a missing key or a None value, and keys that do not exist, are skipped. Every stored
    value is encrypted with encrypt_data; FILE-type values are JSON-encoded first. All
    updates are committed together, and on failure the session is rolled back.

    Args:
        toolkit_name (str): The name of the tool kit.
        configs (list): A list of dictionaries containing the tool configurations.
            Each dictionary should have the following keys:
            - "key" (str): The key of the configuration.
            - "value" (str): The new value for the configuration.
        organisation (Organisation): Organisation of the current user (dependency).

    Returns:
        dict: A dictionary with the message "Tool configs updated successfully".

    Raises:
        HTTPException (status_code=404): If the specified tool kit is not found.
        HTTPException (status_code=500): If an unexpected error occurs during the update process.
    """
    try:
        toolkit = Toolkit.get_toolkit_from_name(db.session, toolkit_name, organisation)
        if toolkit is None:
            raise HTTPException(status_code=404, detail="Tool kit not found")

        for config in configs:
            key = config.get("key")
            value = config.get("value")
            if key is None or value is None:
                continue
            tool_config = db.session.query(ToolConfig).filter_by(toolkit_id=toolkit.id, key=key).first()
            if tool_config is None:
                continue
            if tool_config.key_type == ToolConfigKeyType.FILE.value:
                value = json.dumps(value)
            tool_config.value = encrypt_data(value)

        # Commit once so the update is all-or-nothing.
        db.session.commit()
        return {"message": "Tool configs updated successfully"}

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) through unchanged instead of turning them into 500s.
        raise
    except Exception as e:
        db.session.rollback()
        raise HTTPException(status_code=500, detail=str(e)) from e

 

 

接下来是与 agent 相关的操作,见 agent.py 文件:

from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from sqlalchemy import desc
import ast

from pytz import timezone
from sqlalchemy import func, or_
from superagi.models.agent import Agent
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_schedule import AgentSchedule
from superagi.models.agent_template import AgentTemplate
from superagi.models.project import Project
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.agent_execution import AgentExecution
from superagi.models.tool import Tool
from superagi.controllers.types.agent_schedule import AgentScheduleInput
from superagi.controllers.types.agent_with_config import AgentConfigInput
from superagi.controllers.types.agent_with_config_schedule import AgentConfigSchedule
from jsonmerge import merge
from datetime import datetime
import json

from superagi.models.toolkit import Toolkit
from superagi.models.knowledges import Knowledges

from sqlalchemy import func
# from superagi.types.db import AgentOut, AgentIn
from superagi.helper.auth import check_auth
from superagi.apm.event_handler import EventHandler
from superagi.models.workflows.iteration_workflow import IterationWorkflow

# FastAPI router; the agent endpoints in this module register on it via @router.get/post/put.
router = APIRouter()


class AgentOut(BaseModel):
    """
    Pydantic schema of an agent including its id and timestamps.

    Presumably the output/response shape. It is not referenced by the endpoints visible
    in this excerpt, so verify its callers.
    """
    id: int
    name: str
    project_id: int
    description: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # pydantic v1 option: lets the model be populated from ORM objects (from_orm).
        orm_mode = True


class AgentIn(BaseModel):
    """
    Pydantic schema of the user-supplied agent fields (no id or timestamps).

    Presumably the input/request shape. It is not referenced by the endpoints visible
    in this excerpt, so verify its callers.
    """
    name: str
    project_id: int
    description: str

    class Config:
        # pydantic v1 option: lets the model be populated from ORM objects (from_orm).
        orm_mode = True


@router.post("/create", status_code=201)
def create_agent_with_config(agent_with_config: AgentConfigInput,
                             Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new agent with configurations, plus its first execution ("New Run").

    Args:
        agent_with_config (AgentConfigInput): Data for creating a new agent with configurations.
            - name (str): Name of the agent.
            - project_id (int): Identifier of the associated project.
            - description (str): Description of the agent.
            - goal (List[str]): List of goals for the agent.
            - instruction: Instructions for the agent.
            - constraints (List[str]): List of constraints for the agent.
            - toolkits: Toolkit identifiers; all tools of these toolkits are appended to `tools`.
            - tools (List[int]): List of tool identifiers associated with the agent.
            - exit (str): Exit condition for the agent.
            - iteration_interval (int): Interval between iterations for the agent.
            - model (str): Model information for the agent.
            - permission_type (str): Permission type for the agent.
            - LTM_DB (str): LTM database for the agent.
            - max_iterations (int): Maximum number of iterations for the agent.
            - user_timezone (string): Timezone of the user
            - knowledge: Optional id of a Knowledges row; when set, a 'knowledge_picked' event is recorded.
        Authorize (AuthJWT): Authorization dependency.

    Returns:
        dict: Dictionary containing the created agent's ID, execution ID, name, and content type.

    Raises:
        HTTPException (status_code=404): If the associated project or any of the tools is not found.
    """

    project = db.session.query(Project).get(agent_with_config.project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # get_invalid_tools returns the requested ids that could not be found.
    invalid_tools = Tool.get_invalid_tools(agent_with_config.tools, db.session)
    if len(invalid_tools) > 0:
        raise HTTPException(status_code=404,
                            detail=f"Tool with IDs {str(invalid_tools)} does not exist. 404 Not Found.")

    # Expand the selected toolkits into tool ids. This mutates the request's tools list, so
    # both the stored agent config and the execution config below include them.
    agent_toolkit_tools = Toolkit.fetch_tool_ids_from_toolkit(session=db.session,
                                                              toolkit_ids=agent_with_config.toolkits)
    agent_with_config.tools.extend(agent_toolkit_tools)
    db_agent = Agent.create_agent_with_config(db, agent_with_config)

    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id)
    # When the trigger step delegates to an iteration workflow, also record that workflow's trigger step.
    iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session,
                                                                start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1

    # Create the first execution in CREATED status.
    execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id,
                               name="New Run", current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id)

    agent_execution_configs = {
        "goal": agent_with_config.goal,
        "instruction": agent_with_config.instruction,
        "constraints": agent_with_config.constraints,
        "toolkits": agent_with_config.toolkits,
        "exit": agent_with_config.exit,
        "tools": agent_with_config.tools,
        "iteration_interval": agent_with_config.iteration_interval,
        "model": agent_with_config.model,
        "permission_type": agent_with_config.permission_type,
        "LTM_DB": agent_with_config.LTM_DB,
        "max_iterations": agent_with_config.max_iterations,
        "user_timezone": agent_with_config.user_timezone,
        "knowledge": agent_with_config.knowledge
    }
    db.session.add(execution)
    db.session.commit()
    db.session.flush()
    AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution,
                                                                     agent_execution_configs=agent_execution_configs)

    agent = db.session.query(Agent).filter(Agent.id == db_agent.id).first()
    organisation = agent.get_agent_organisation(db.session)
    # Events are attributed to organisation id 0 when the agent has no organisation.
    organisation_id = organisation.id if organisation else 0

    EventHandler(session=db.session).create_event('run_created',
                                                  {'agent_execution_id': execution.id,
                                                   'agent_execution_name': execution.name},
                                                  db_agent.id,
                                                  organisation_id)

    if agent_with_config.knowledge:
        knowledge_row = db.session.query(Knowledges.name).filter(Knowledges.id == agent_with_config.knowledge).first()
        # The agent and its execution are already committed here. Skip the analytics event for an
        # unknown knowledge id instead of failing the request with a TypeError on None[0].
        if knowledge_row is not None:
            EventHandler(session=db.session).create_event('knowledge_picked',
                                                          {'knowledge_name': knowledge_row[0],
                                                           'agent_execution_id': execution.id},
                                                          db_agent.id,
                                                          organisation_id)

    EventHandler(session=db.session).create_event('agent_created',
                                                  {'agent_name': agent_with_config.name,
                                                   'model': agent_with_config.model},
                                                  db_agent.id,
                                                  organisation_id)

    db.session.commit()

    return {
        "id": db_agent.id,
        "execution_id": execution.id,
        "name": db_agent.name,
        "contentType": "Agents"
    }



@router.post("/schedule", status_code=201)
def create_and_schedule_agent(agent_config_schedule: AgentConfigSchedule,
                              Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new agent with configurations and scheduling.

    Args:
        agent_config_schedule (AgentConfigSchedule): Agent configuration (`agent_config`)
            together with the schedule to create for it (`schedule`).
        Authorize (AuthJWT): Authorization dependency.

    Returns:
        dict: Dictionary containing the created agent's ID, name, content type and schedule ID of the agent.

    Raises:
        HTTPException (status_code=404): If the associated project or any of the tools is not found.
        HTTPException (status_code=500): If the associated agent fails to get scheduled.
    """
    agent_config = agent_config_schedule.agent_config

    project = db.session.query(Project).get(agent_config.project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # get_invalid_tools returns the requested ids that could not be found.
    invalid_tools = Tool.get_invalid_tools(agent_config.tools, db.session)
    if len(invalid_tools) > 0:
        raise HTTPException(status_code=404,
                            detail=f"Tool with IDs {str(invalid_tools)} does not exist. 404 Not Found.")

    # Expand the selected toolkits into tool ids (mutates the request's tools list).
    agent_toolkit_tools = Toolkit.fetch_tool_ids_from_toolkit(session=db.session,
                                                              toolkit_ids=agent_config.toolkits)
    agent_config.tools.extend(agent_toolkit_tools)
    db_agent = Agent.create_agent_with_config(db, agent_config)

    # Keep the request's schedule and the persisted row in separate variables.
    schedule_input = agent_config_schedule.schedule
    agent_schedule = AgentSchedule(
        agent_id=db_agent.id,
        start_time=schedule_input.start_time,
        # The first run is due at the start time.
        next_scheduled_time=schedule_input.start_time,
        recurrence_interval=schedule_input.recurrence_interval,
        expiry_date=schedule_input.expiry_date,
        expiry_runs=schedule_input.expiry_runs,
        current_runs=0,
        status="SCHEDULED"
    )
    db.session.add(agent_schedule)
    db.session.commit()

    if agent_schedule.id is None:
        raise HTTPException(status_code=500, detail="Failed to schedule agent")

    agent = db.session.query(Agent).filter(Agent.id == db_agent.id).first()
    organisation = agent.get_agent_organisation(db.session)

    EventHandler(session=db.session).create_event('agent_created',
                                                  {'agent_name': agent_config.name,
                                                   'model': agent_config.model},
                                                  db_agent.id,
                                                  organisation.id if organisation else 0)

    db.session.commit()

    return {
        "id": db_agent.id,
        "name": db_agent.name,
        "contentType": "Agents",
        "schedule_id": agent_schedule.id
    }



@router.post("/stop/schedule", status_code=200)
def stop_schedule(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Stop the active schedule of an agent.

    The agent's schedule in "SCHEDULED" status is marked "STOPPED". The row is kept,
    not deleted.

    Args:
        agent_id (int): Identifier of the agent whose schedule should be stopped.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent has no schedule in "SCHEDULED" status.
    """
    active_schedule = (
        db.session.query(AgentSchedule)
        .filter(AgentSchedule.agent_id == agent_id, AgentSchedule.status == "SCHEDULED")
        .first()
    )
    if not active_schedule:
        raise HTTPException(status_code=404, detail="Schedule not found")

    active_schedule.status = "STOPPED"
    db.session.commit()


@router.put("/edit/schedule", status_code=200)
def edit_schedule(schedule: AgentScheduleInput,
                  Authorize: AuthJWT = Depends(check_auth)):
    """
    Replace the timing of an agent's active schedule.

    The target agent is ``schedule.agent_id``, and only its schedule in "SCHEDULED" status
    is edited. The next scheduled run is reset to the new start time.

    Args:
        schedule (AgentScheduleInput): New schedule data, including the agent_id to edit.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent has no schedule in "SCHEDULED" status.
    """
    active_schedule = db.session.query(AgentSchedule).filter(
        AgentSchedule.agent_id == schedule.agent_id,
        AgentSchedule.status == "SCHEDULED",
    ).first()
    if not active_schedule:
        raise HTTPException(status_code=404, detail="Schedule not found")

    # Assigned in this order; the next run restarts from the new start time.
    new_values = {
        "start_time": schedule.start_time,
        "next_scheduled_time": schedule.start_time,
        "recurrence_interval": schedule.recurrence_interval,
        "expiry_date": schedule.expiry_date,
        "expiry_runs": schedule.expiry_runs,
    }
    for column_name, column_value in new_values.items():
        setattr(active_schedule, column_name, column_value)

    db.session.commit()


@router.get("/get/schedule_data/{agent_id}")
def get_schedule_data(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Return an agent's active schedule, formatted in the agent's configured timezone.

    The timezone is read from the agent's "user_timezone" configuration. When that entry
    is missing or stored as the string "None", GMT is used.

    Args:
        agent_id (int): Identifier of the Agent

    Raises:
        HTTPException (status_code=404): If the agent has no schedule in "SCHEDULED" status.

    Returns:
        dict with:
            current_datetime (str): Now, formatted "%d/%m/%Y %I:%M %p".
            start_date (str): Schedule start, formatted "%d %b %Y".
            start_time (str): Schedule start, formatted "%I:%M %p".
            recurrence_interval: Recurrence interval, or None when empty.
            expiry_date (str | None): Expiry formatted "%d/%m/%Y", or None when unset.
            expiry_runs (int | None): Run limit, or None when stored as -1.
    """
    active_schedule = db.session.query(AgentSchedule).filter(
        AgentSchedule.agent_id == agent_id,
        AgentSchedule.status == "SCHEDULED",
    ).first()
    if not active_schedule:
        raise HTTPException(status_code=404, detail="Agent Schedule not found")

    tz_setting = db.session.query(AgentConfiguration).filter(
        AgentConfiguration.key == "user_timezone",
        AgentConfiguration.agent_id == agent_id,
    ).first()
    use_user_tz = bool(tz_setting) and tz_setting.value != "None"
    tzone = timezone(tz_setting.value if use_user_tz else 'GMT')

    now_text = datetime.now(tzone).strftime("%d/%m/%Y %I:%M %p")
    local_start = active_schedule.start_time.astimezone(tzone)
    expiry = active_schedule.expiry_date
    runs_limit = active_schedule.expiry_runs

    return {
        "current_datetime": now_text,
        "start_date": local_start.strftime("%d %b %Y"),
        "start_time": local_start.strftime("%I:%M %p"),
        "recurrence_interval": active_schedule.recurrence_interval or None,
        "expiry_date": expiry.astimezone(tzone).strftime("%d/%m/%Y") if expiry else None,
        "expiry_runs": None if runs_limit == -1 else runs_limit,
    }


@router.get("/get/project/{project_id}")
def get_agents_by_project_id(project_id: int,
                             Authorize: AuthJWT = Depends(check_auth)):
    """
    Get all non-deleted agents of a project.

    Each returned entry is the agent's attribute dict extended with:
        - is_running (bool): True if any execution of the agent has status "RUNNING".
        - is_scheduled (bool): True if the agent has a schedule with status "SCHEDULED".
    Running agents are listed first; otherwise the query order is preserved.

    Args:
        project_id (int): Identifier of the project.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        list: List of agents associated with the project, including their status and scheduling information.

    Raises:
        HTTPException (status_code=404): If the project is not found.
    """
    project = db.session.query(Project).get(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # `Agent.is_deleted is None` is evaluated by Python (always False) instead of being
    # rendered as SQL, which silently excluded rows whose is_deleted is NULL.
    # `.is_(None)` emits a real `IS NULL` comparison.
    agents = db.session.query(Agent).filter(
        Agent.project_id == project_id,
        or_(Agent.is_deleted == False, Agent.is_deleted.is_(None)),  # noqa: E712 - SQL expression
    ).all()

    new_agents = []
    for agent in agents:
        # Let the database answer "is there a RUNNING execution?" instead of
        # loading every execution of the agent into memory.
        is_running = db.session.query(AgentExecution).filter_by(
            agent_id=agent.id, status="RUNNING").first() is not None
        is_scheduled = db.session.query(AgentSchedule).filter_by(
            agent_id=agent.id, status="SCHEDULED").first() is not None

        new_agents.append({
            **vars(agent),
            'is_running': is_running,
            'is_scheduled': is_scheduled
        })

    # Sort once after the loop (previously re-sorted on every iteration).
    # sorted() is stable even with reverse=True, so equal keys keep query order.
    return sorted(new_agents, key=lambda entry: entry['is_running'], reverse=True)


@router.put("/delete/{agent_id}", status_code=200)
def delete_agent(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Soft-delete an existing Agent.

        - Sets the agent's is_deleted flag (the row itself is kept).
        - Sets the status of every execution of the agent to "TERMINATED".
        - Sets the agent's "SCHEDULED" schedule, if any, to "STOPPED".

    Args:
        agent_id (int): Identifier of the Agent to delete
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        None.

    Raises:
        HTTPException (status_code=404): If the Agent does not exist or is already deleted.
    """
    db_agent = db.session.query(Agent).filter(Agent.id == agent_id).first()
    # Validate first so a 404 does not pay for loading executions and schedules.
    if not db_agent or db_agent.is_deleted:
        raise HTTPException(status_code=404, detail="agent not found")

    db_agent_executions = db.session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id).all()
    db_agent_schedule = db.session.query(AgentSchedule).filter(AgentSchedule.agent_id == agent_id,
                                                               AgentSchedule.status == "SCHEDULED").first()

    # Deletion procedure
    db_agent.is_deleted = True

    # NOTE(review): every execution is terminated, not only RUNNING ones (behaviour kept);
    # confirm whether finished runs should keep their original status.
    for db_agent_execution in db_agent_executions:
        db_agent_execution.status = "TERMINATED"

    if db_agent_schedule:
        db_agent_schedule.status = "STOPPED"

    db.session.commit()

 

app.include_router(agent_router, prefix="/agents") —— 可以看到，agent 相关接口的 URL 前缀是 /agents。
 
创建agent的函数:
@router.post("/create", status_code=201)
def create_agent_with_config(agent_with_config: AgentConfigInput,
                             Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new agent with configurations, plus an initial execution in the CREATED state.

    Args:
        agent_with_config (AgentConfigInput): Data for creating a new agent with configurations.
            - name (str): Name of the agent.
            - project_id (int): Identifier of the associated project.
            - description (str): Description of the agent.
            - goal (List[str]): List of goals for the agent.
            - instruction: Instructions for the agent.
            - constraints (List[str]): List of constraints for the agent.
            - toolkits: Toolkit identifiers; every tool of these toolkits is appended to ``tools``.
            - tools (List[int]): List of tool identifiers associated with the agent.
            - exit (str): Exit condition for the agent.
            - iteration_interval (int): Interval between iterations for the agent.
            - model (str): Model information for the agent.
            - permission_type (str): Permission type for the agent.
            - LTM_DB (str): LTM database for the agent.
            - max_iterations (int): Maximum number of iterations for the agent.
            - user_timezone (string): Timezone of the user
            - knowledge: Optional id of a Knowledges row attached to the agent.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        dict: Dictionary containing the created agent's ID, execution ID, name, and content type.

    Raises:
        HTTPException (status_code=404): If the associated project or any of the tools is not found.
    """
    project = db.session.query(Project).get(agent_with_config.project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # A non-empty result lists the requested tool ids that could not be found.
    invalid_tools = Tool.get_invalid_tools(agent_with_config.tools, db.session)
    if len(invalid_tools) > 0:
        raise HTTPException(status_code=404,
                            detail=f"Tool with IDs {str(invalid_tools)} does not exist. 404 Not Found.")

    # Expand the selected toolkits into individual tool ids (extends the request's tool list in place).
    agent_toolkit_tools = Toolkit.fetch_tool_ids_from_toolkit(session=db.session,
                                                              toolkit_ids=agent_with_config.toolkits)
    agent_with_config.tools.extend(agent_toolkit_tools)
    db_agent = Agent.create_agent_with_config(db, agent_with_config)

    # The first run starts at the workflow's trigger step; iteration workflows also need
    # their own trigger step id, otherwise -1 is stored.
    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id)
    if start_step.action_type == "ITERATION_WORKFLOW":
        iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session,
                                                                    start_step.action_reference_id).id
    else:
        iteration_step_id = -1

    # The initial run is created in the CREATED state; this handler does not start it.
    execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id,
                               name="New Run", current_agent_step_id=start_step.id,
                               iteration_workflow_step_id=iteration_step_id)

    agent_execution_configs = {
        "goal": agent_with_config.goal,
        "instruction": agent_with_config.instruction,
        "constraints": agent_with_config.constraints,
        "toolkits": agent_with_config.toolkits,
        "exit": agent_with_config.exit,
        "tools": agent_with_config.tools,
        "iteration_interval": agent_with_config.iteration_interval,
        "model": agent_with_config.model,
        "permission_type": agent_with_config.permission_type,
        "LTM_DB": agent_with_config.LTM_DB,
        "max_iterations": agent_with_config.max_iterations,
        "user_timezone": agent_with_config.user_timezone,
        "knowledge": agent_with_config.knowledge
    }
    db.session.add(execution)
    db.session.commit()
    db.session.flush()
    AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution,
                                                                     agent_execution_configs=agent_execution_configs)

    agent = db.session.query(Agent).filter(Agent.id == db_agent.id).first()
    organisation = agent.get_agent_organisation(db.session)
    organisation_id = organisation.id if organisation else 0

    EventHandler(session=db.session).create_event('run_created',
                                                  {'agent_execution_id': execution.id,
                                                   'agent_execution_name': execution.name},
                                                  db_agent.id,
                                                  organisation_id)

    if agent_with_config.knowledge:
        knowledge_row = db.session.query(Knowledges.name).filter(
            Knowledges.id == agent_with_config.knowledge).first()
        # An unknown knowledge id used to raise TypeError (None[0]) after the new execution
        # had already been committed; skip the event instead.
        if knowledge_row is not None:
            EventHandler(session=db.session).create_event('knowledge_picked',
                                                          {'knowledge_name': knowledge_row[0],
                                                           'agent_execution_id': execution.id},
                                                          db_agent.id,
                                                          organisation_id)

    EventHandler(session=db.session).create_event('agent_created',
                                                  {'agent_name': agent_with_config.name,
                                                   'model': agent_with_config.model},
                                                  db_agent.id,
                                                  organisation_id)

    db.session.commit()

    return {
        "id": db_agent.id,
        "execution_id": execution.id,
        "name": db_agent.name,
        "contentType": "Agents"
    }

  

可以看到，创建 agent 时会先校验项目和工具是否存在，再把所选 toolkit 中的全部工具合并进 agent 的工具列表，然后保存 agent 配置，并创建一条状态为 CREATED、名为 "New Run" 的初始执行记录。

 

agent调度的几个函数:

@router.post("/schedule", status_code=201)
def create_and_schedule_agent(agent_config_schedule: AgentConfigSchedule,
                              Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new agent with configurations and a schedule.

    No execution is created here; the agent gets an AgentSchedule row with
    status "SCHEDULED", zero completed runs, and its next run set to the
    requested start time.

    Args:
        agent_config_schedule (AgentConfigSchedule): ``agent_config`` holds the data for
            creating the agent, ``schedule`` holds the scheduling settings.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        dict: Dictionary containing the created agent's ID, name, content type and schedule ID of the agent.

    Raises:
        HTTPException (status_code=404): If the associated project or any of the tools is not found.
        HTTPException (status_code=500): If the associated agent fails to get scheduled.
    """
    agent_config = agent_config_schedule.agent_config
    project = db.session.query(Project).get(agent_config.project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # A non-empty result lists the requested tool ids that could not be found.
    invalid_tools = Tool.get_invalid_tools(agent_config.tools, db.session)
    if len(invalid_tools) > 0:
        raise HTTPException(status_code=404,
                            detail=f"Tool with IDs {str(invalid_tools)} does not exist. 404 Not Found.")

    agent_toolkit_tools = Toolkit.fetch_tool_ids_from_toolkit(session=db.session,
                                                              toolkit_ids=agent_config.toolkits)
    agent_config.tools.extend(agent_toolkit_tools)
    db_agent = Agent.create_agent_with_config(db, agent_config)

    # Persist the requested schedule for the new agent; the first run is due at start_time.
    schedule_input = agent_config_schedule.schedule
    agent_schedule = AgentSchedule(
        agent_id=db_agent.id,
        start_time=schedule_input.start_time,
        next_scheduled_time=schedule_input.start_time,
        recurrence_interval=schedule_input.recurrence_interval,
        expiry_date=schedule_input.expiry_date,
        expiry_runs=schedule_input.expiry_runs,
        current_runs=0,
        status="SCHEDULED"
    )
    db.session.add(agent_schedule)
    db.session.commit()

    # Defensive check: the primary key is expected to be populated after the commit.
    if agent_schedule.id is None:
        raise HTTPException(status_code=500, detail="Failed to schedule agent")

    agent = db.session.query(Agent).filter(Agent.id == db_agent.id).first()
    organisation = agent.get_agent_organisation(db.session)

    EventHandler(session=db.session).create_event('agent_created',
                                                  {'agent_name': agent_config.name,
                                                   'model': agent_config.model},
                                                  db_agent.id,
                                                  organisation.id if organisation else 0)

    db.session.commit()

    return {
        "id": db_agent.id,
        "name": db_agent.name,
        "contentType": "Agents",
        "schedule_id": agent_schedule.id
    }



@router.post("/stop/schedule", status_code=200)
def stop_schedule(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Stop the active schedule of an agent.

    The agent's schedule whose status is "SCHEDULED" is switched to "STOPPED"
    and the change is committed.

    Args:
        agent_id (int): Identifier of the Agent whose schedule should stop.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent has no "SCHEDULED" schedule.
    """
    active_schedule = (
        db.session.query(AgentSchedule)
        .filter(AgentSchedule.agent_id == agent_id, AgentSchedule.status == "SCHEDULED")
        .first()
    )
    if active_schedule is None:
        raise HTTPException(status_code=404, detail="Schedule not found")

    active_schedule.status = "STOPPED"
    db.session.commit()


@router.put("/edit/schedule", status_code=200)
def edit_schedule(schedule: AgentScheduleInput,
                  Authorize: AuthJWT = Depends(check_auth)):
    """
    Overwrite the settings of an agent's active schedule.

    The schedule edited is the one belonging to ``schedule.agent_id`` whose
    status is "SCHEDULED". Setting a new start time also resets the next
    scheduled run to that start time.

    Args:
        schedule (AgentScheduleInput): New schedule data; ``schedule.agent_id``
            selects the agent whose schedule is edited.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent has no "SCHEDULED" schedule.
    """
    existing = (
        db.session.query(AgentSchedule)
        .filter(AgentSchedule.agent_id == schedule.agent_id, AgentSchedule.status == "SCHEDULED")
        .first()
    )
    if existing is None:
        raise HTTPException(status_code=404, detail="Schedule not found")

    # Copy the new values onto the stored schedule; the next run is re-anchored to the new start time.
    new_values = (
        ("start_time", schedule.start_time),
        ("next_scheduled_time", schedule.start_time),
        ("recurrence_interval", schedule.recurrence_interval),
        ("expiry_date", schedule.expiry_date),
        ("expiry_runs", schedule.expiry_runs),
    )
    for field_name, value in new_values:
        setattr(existing, field_name, value)

    db.session.commit()


@router.get("/get/schedule_data/{agent_id}")
def get_schedule_data(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Return an agent's active schedule, formatted in the user's timezone.

    The timezone is read from the agent's "user_timezone" configuration; when
    that entry is missing or holds the literal string "None", GMT is used.

    Args:
        agent_id (int): Identifier of the Agent
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent has no "SCHEDULED" schedule.

    Returns:
        dict with keys:
            current_datetime (str): Now, formatted "%d/%m/%Y %I:%M %p".
            start_date (str): Schedule start date, formatted "%d %b %Y".
            start_time (str): Schedule start time, formatted "%I:%M %p".
            recurrence_interval: Interval between recurring runs, or None when unset/empty.
            expiry_date (str | None): Expiry date formatted "%d/%m/%Y", or None when unset.
            expiry_runs: Run limit, or None when stored as -1.
    """
    schedule_row = (
        db.session.query(AgentSchedule)
        .filter(AgentSchedule.agent_id == agent_id, AgentSchedule.status == "SCHEDULED")
        .first()
    )
    if not schedule_row:
        raise HTTPException(status_code=404, detail="Agent Schedule not found")

    tz_config = (
        db.session.query(AgentConfiguration)
        .filter(AgentConfiguration.key == "user_timezone", AgentConfiguration.agent_id == agent_id)
        .first()
    )
    if tz_config and tz_config.value != "None":
        tz_name = tz_config.value
    else:
        tz_name = 'GMT'
    tzone = timezone(tz_name)

    now_text = datetime.now(tzone).strftime("%d/%m/%Y %I:%M %p")
    local_start = schedule_row.start_time.astimezone(tzone)
    expiry = schedule_row.expiry_date

    return {
        "current_datetime": now_text,
        "start_date": local_start.strftime("%d %b %Y"),
        "start_time": local_start.strftime("%I:%M %p"),
        "recurrence_interval": schedule_row.recurrence_interval or None,
        "expiry_date": expiry.astimezone(tzone).strftime("%d/%m/%Y") if expiry else None,
        "expiry_runs": None if schedule_row.expiry_runs == -1 else schedule_row.expiry_runs
    }

  

这段代码中定义了几个与代理(Agent)调度相关的路由处理函数,涵盖了创建、停止、编辑和获取代理调度数据的功能。以下是每个函数的详细功能分析:

1. 创建并调度代理(create_and_schedule_agent):

  • 功能:创建一个新的代理并为其设置调度。
  • 参数:
  • agent_config_schedule: 包含代理配置和调度信息的对象。
  • Authorize: 用于JWT认证。
  • 处理流程:
  • 验证项目存在性。
  • 验证工具的有效性。
  • 合并工具包中的工具。
  • 创建代理配置。
  • 创建新的代理调度。
  • 提交数据库事务。
  • 返回代理ID、名称、内容类型和调度ID。
  • 异常处理:
  • 如果项目不存在或工具无效,抛出404错误。
  • 如果调度失败,抛出500错误。

2. 停止代理调度(stop_schedule):

  • 功能:停止给定代理的调度。
  • 参数:
  • agent_id: 代理的标识符。
  • Authorize: 用于JWT认证。
  • 处理流程:
  • 查询并验证指定代理的调度状态。
  • 更改调度状态为“STOPPED”。
  • 提交数据库事务。
  • 异常处理:
  • 如果调度未找到,抛出404错误。

3. 编辑代理调度(edit_schedule):

  • 功能:编辑给定代理的调度信息。
  • 参数:
  • schedule: 包含新的调度数据的对象。
  • Authorize: 用于JWT认证。
  • 处理流程:
  • 查询并验证指定代理的调度状态。
  • 更新调度数据。
  • 提交数据库事务。
  • 异常处理:
  • 如果调度未找到,抛出404错误。

4. 获取代理调度数据(get_schedule_data):

  • 功能:获取指定代理的调度数据。
  • 参数:
  • agent_id: 代理的标识符。
  • Authorize: 用于JWT认证。
  • 处理流程:
  • 查询并验证指定代理的调度状态。
  • 根据用户时区调整时间显示。
  • 返回当前日期时间、开始日期时间、重复间隔、到期日期和到期运行次数。
  • 异常处理:
  • 如果调度未找到,抛出404错误。

这些函数共同支持代理的生命周期管理,包括创建、调度、停止和编辑调度,以及获取调度相关数据,确保代理的操作按预定计划执行。每个函数都具备适当的异常处理机制,确保在出现问题时能够提供清晰的错误信息。

 比如我的项目里有2个agent:

 

 

 

agent调度:

 

 

对应代码:

@router.get("/get/schedule_data/{agent_id}")
def get_schedule_data(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Get the scheduling data for a given agent.

    Args:
        agent_id (int): Identifier of the Agent
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent schedule is not found.

    Returns:
        current_datetime (DateTime): Current Date and Time.
        recurrence_interval (String): Time interval for recurring schedule run.
        expiry_date (DateTime): The date and time when the agent is scheduled to stop runs.
        expiry_runs (Integer): The number of runs before the agent expires.
    """
    # NOTE: excerpt only - the implementation is omitted here, so as written this
    # function has no statements after the docstring and returns None.

  

这些调度相关接口本质上都只是对数据库(AgentSchedule、AgentConfiguration 等表)的增删改查操作。

 

创建 agent 执行任务(run)的关键函数位于 agent_execution.py:

@router.post("/add", response_model=AgentExecutionOut, status_code=201)
def create_agent_execution(agent_execution: AgentExecutionIn,
                           Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new agent execution/run, mark it RUNNING and hand it to ``execute_agent.delay``.

    The run's configuration is the agent's stored configuration, except that
    "goal" and "instruction" are taken from the request.

    Args:
        agent_execution (AgentExecutionIn): The agent execution data.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        AgentExecution: The created agent execution.

    Raises:
        HTTPException (Status Code=404): If the agent is not found or is marked deleted.
    """
    agent = db.session.query(Agent).filter(Agent.id == agent_execution.agent_id, Agent.is_deleted == False).first()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")

    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id)
    if start_step.action_type == "ITERATION_WORKFLOW":
        iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session,
                                                                    start_step.action_reference_id).id
    else:
        iteration_step_id = -1

    db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(),
                                        agent_id=agent_execution.agent_id, name=agent_execution.name, num_of_calls=0,
                                        num_of_tokens=0,
                                        current_agent_step_id=start_step.id,
                                        iteration_workflow_step_id=iteration_step_id)

    # Request values win for goal/instruction; every other key is copied from the agent's config.
    agent_execution_configs = {
        "goal": agent_execution.goal,
        "instruction": agent_execution.instruction
    }

    agent_configs = db.session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_execution.agent_id).all()
    keys_to_exclude = ["goal", "instruction"]
    for agent_config in agent_configs:
        if agent_config.key in keys_to_exclude:
            continue
        if agent_config.key == "toolkits":
            # NOTE(review): parses a stored value such as "{1,2}" into [1, 2]; a JSON-style
            # "[1, 2]" would make int() fail here - confirm the storage format.
            if agent_config.value:
                agent_execution_configs[agent_config.key] = [
                    int(item) for item in agent_config.value.strip('{}').split(',')
                    if item.strip() and item != '[]'
                ]
            else:
                agent_execution_configs[agent_config.key] = []
        elif agent_config.key == "constraints":
            agent_execution_configs[agent_config.key] = agent_config.value if agent_config.value else []
        else:
            agent_execution_configs[agent_config.key] = agent_config.value

    db.session.add(db_agent_execution)
    db.session.commit()
    db.session.flush()

    # Update status from CREATED to RUNNING
    db_agent_execution.status = "RUNNING"
    db.session.commit()

    AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution,
                                                                     agent_execution_configs=agent_execution_configs)

    organisation = agent.get_agent_organisation(db.session)
    organisation_id = organisation.id if organisation else 0
    agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(
        session=db.session, key='knowledge', agent_id=agent_execution.agent_id)

    EventHandler(session=db.session).create_event('run_created',
                                                  {'agent_execution_id': db_agent_execution.id,
                                                   'agent_execution_name': db_agent_execution.name},
                                                  agent_execution.agent_id,
                                                  organisation_id)
    if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
        knowledge = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value))
        # get_knowledge_from_id may return None (presumably for an unknown id); guard before reading .name.
        knowledge_name = knowledge.name if knowledge is not None else None
        if knowledge_name is not None:
            EventHandler(session=db.session).create_event('knowledge_picked',
                                                          {'knowledge_name': knowledge_name,
                                                           'agent_execution_id': db_agent_execution.id},
                                                          agent_execution.agent_id,
                                                          organisation_id)
    # The events above already fall back to 0 when the agent has no organisation; do the same
    # here instead of raising AttributeError on organisation.id.
    Models.api_key_from_configurations(session=db.session, organisation_id=organisation_id)
    if db_agent_execution.status == "RUNNING":
        execute_agent.delay(db_agent_execution.id, datetime.now())

    return db_agent_execution

 

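可以看到，该接口会创建一条执行记录，复制 agent 已保存的配置（其中 goal 和 instruction 以请求中的值为准），将状态置为 RUNNING，再通过 execute_agent.delay(...) 提交执行。结合前文，真正执行的是 agent 配置中（包括所选 toolkit 中）的工具。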

 

 

agent执行:

 执行agent的过程中:

 

示例返回:

{
    "status": "RUNNING",
    "feeds": [
        {
            "role": "system",
            "feed": "You are SuperAGI an AI assistant to solve complex problems. Your decisions must always be made independently without seeking user assistance.\nPlay to your strengths as an LLM and pursue simple strategies with no legal complications.\nIf you have completed all your tasks or reached end state, make sure to use the \"finish\" tool.\n\nGOALS:\n1. Get the details on the product from the [filename]\n2. Write a creative product description for the product, with sale-worthy keywords, and describe benefits of this product.\n\n\n\n\nCONSTRAINTS:\n1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.\n2. Ensure the tool and args are as per current plan and reasoning\n3. Exclusively use the tools listed under \"TOOLS\"\n4. REMEMBER to format your response as JSON, using double quotes (\"\") around keys and string values, and commas (,) to separate items in arrays and objects. IMPORTANTLY, to use a JSON object as a string in another JSON object, you need to escape the double quotes.\n\n\n",
            "updated_at": "2024-06-04T08:12:39.040402",
            "time_difference": {
                "years": 0,
                "months": 0,
                "days": 0,
                "hours": 0,
                "minutes": 0
            }
        },
        {
            "role": "user",
            "feed": "Determine which next tool to use, and respond using the format specified above:",
            "updated_at": "2024-06-04T08:12:39.052769",
            "time_difference": {
                "years": 0,
                "months": 0,
                "days": 0,
                "hours": 0,
                "minutes": 0
            }
        },
        {
            "role": "assistant",
            "feed": "```json\n{\n    \"thoughts\": {\n        \"text\": \"I need to get the details on the product from the specified file to proceed with writing a creative product description.\",\n        \"reasoning\": \"The first step is to obtain the necessary information about the product, which is likely stored in the file mentioned in the goal.\",\n        \"plan\": \"- Read the file to get product details\\n- Analyze the extracted information\\n- Write a creative product description with sale-worthy keywords\\n- Highlight the benefits of the product\",\n        \"criticism\": \"I need to ensure that I extract all relevant details from the file to create a comprehensive and appealing product description.\",\n        \"speak\": \"I will start by reading the file to get the details on the product.\"\n    },\n    \"tool\": {\n        \"name\": \"Read File\",\n        \"args\": {\n            \"file_name\": \"filename\"\n        }\n    }\n}\n```",
            "updated_at": "2024-06-04T08:12:46.389796",
            "time_difference": {
                "years": 0,
                "months": 0,
                "days": 0,
                "hours": 0,
                "minutes": 0
            }
        },
        {
            "role": "system",
            "feed": "Tool Read File returned: Error1: File 'filename' not found., FileNotFoundError, args: {'file_name': 'filename'}",
            "updated_at": "2024-06-04T08:12:46.389798",
            "time_difference": {
                "years": 0,
                "months": 0,
                "days": 0,
                "hours": 0,
                "minutes": 0
            }
        },
        {
            "role": "assistant",
            "feed": "```json\n{\n    \"thoughts\": {\n        \"text\": \"The specified file 'filename' could not be found. I need to determine the correct file name or location to proceed.\",\n        \"reasoning\": \"Without the correct file name, I cannot access the product details needed to create the description.\",\n        \"plan\": \"- Verify if there's an alternative method to identify the file name\\n- If no alternative, seek clarification on the correct file name or location\",\n        \"criticism\": \"I need to ensure the file name is accurate before attempting to read it.\",\n        \"speak\": \"I need to verify the correct file name or location to access the product details.\"\n    },\n    \"tool\": {\n        \"name\": \"ThinkingTool\",\n        \"args\": {\n            \"task_description\": \"Determine if there's a way to identify the correct file name or location to read the product details.\"\n        }\n    }\n}\n```",
            "updated_at": "2024-06-04T08:12:56.107737",
            "time_difference": {
                "years": 0,
                "months": 0,
                "days": 0,
                "hours": 0,
                "minutes": 0
            }
        },
        {
            "role": "system",
            "feed": "Tool ThinkingTool returned: Error generating text: float() argument must be a string or a real number, not 'dict'",
            "updated_at": "2024-06-04T08:12:56.107738",
            "time_difference": {
                "years": 0,
                "months": 0,
                "days": 0,
                "hours": 0,
                "minutes": 0
            }
        }
    ],
    "permissions": [],
    "waiting_period": null,
    "errors": ""
}

 

最后，我们来看 SuperAGI 最核心的 ReAct 实现。代码目录如下:

superagi/agent (main)
$ ls
__init__.py
__pycache__/
agent_iteration_step_handler.py
agent_message_builder.py
agent_prompt_builder.py
agent_prompt_template.py
agent_tool_step_handler.py
agent_workflow_step_wait_handler.py
common_types.py
output_handler.py
output_parser.py
prompts/    ==》存放各种 prompt，详见另一篇文章。
queue_step_handler.py
task_queue.py
tool_builder.py
tool_executor.py
types/
workflow_seed.py

 

AI代理的实现,主要涉及到任务队列、工具构建、消息构建、输出处理等多个方面。以下是核心实现算法和流程的分析:

核心类和方法

1. AgentIterationStepHandler:

  • execute_step: 执行迭代工作流步骤。首先获取代理配置和执行实例,然后根据当前步骤的状态决定是否继续执行。如果有必要,它会构建工具,生成提示,发送消息,并处理响应。

2. AgentLlmMessageBuilder:

  • build_agent_messages: 构建用于LLM(大型语言模型)的代理消息。这包括处理历史信息,生成提示,并可能包括完成提示。

3. AgentPromptBuilder:

  • replace_main_variables 和 replace_task_based_variables: 这些方法用于构建和修改提示,将动态内容(如任务、工具、目标等)插入到提示模板中。

4. ToolOutputHandler:

  • handle: 处理来自工具的输出。这可能涉及解析输出、应用业务逻辑,并最终生成适合用户的响应。

流程

1. 初始化:

  • 代理通过 AgentIterationStepHandler 初始化,设置会话、模型、代理ID等。

2. 执行步骤:

  • execute_step 方法是核心,它根据当前的执行状态和配置决定如何进一步处理。这可能包括等待权限、处理输入和输出、以及与LLM交互。

3. 消息构建:

  • 使用 AgentLlmMessageBuilder 来构建发送给LLM的消息。这涉及到处理历史数据、构建提示和管理令牌限制。

4. 提示构建:

  • AgentPromptBuilder 用于动态构建和调整提示,根据当前的任务和代理配置插入必要的变量。

5. 工具处理:

  • ToolOutputHandler 处理工具的输出,将其转换为用户可以理解的格式,并处理任何必要的后续步骤。

6. 错误处理和日志:

  • 在整个执行过程中,错误处理和日志记录是必不可少的,以确保问题可以被追踪并且代理行为符合预期。

 

由于整体遵循 ReAct 思路，下面展开说说细节：目标的任务分解、调度和执行主要通过以下几个关键组件实现:

1. 工作流和任务队列管理

AgentWorkflowStep 和 TaskQueue

  • AgentWorkflowStep 类定义了工作流中的各个步骤。每个步骤可以是一个工具的执行、一个等待事件、或者是一个任务队列处理步骤。
  • TaskQueue 类管理任务队列,支持添加任务、获取任务、完成任务等功能。它使用Redis作为后端存储,以支持任务的持久化和状态管理。

2. 工具执行和输出处理

ToolExecutor 和 ToolOutputHandler

  • ToolExecutor 类负责执行具体的工具。它通过动态加载工具模块,并调用工具的执行方法,传入必要的参数。
  • ToolOutputHandler 类处理工具执行后的输出。它可能会根据输出进行解析、执行后续的业务逻辑处理,或者生成适合用户的响应。

3. 工作流执行

AgentWaitStepHandler 和 QueueStepHandler

  • AgentWaitStepHandler 类处理等待步骤。它会检查是否满足继续执行的条件,如等待时间的完成或外部事件的触发。
  • QueueStepHandler 类处理队列步骤。它会从任务队列中获取任务,执行任务,并根据任务的结果决定下一步的操作。

4. 工作流调度

  • 工作流的调度是通过在每个步骤执行完毕后,根据当前步骤的输出或状态决定下一步执行哪个步骤来实现的。
  • 每个步骤可以配置下一步的跳转逻辑,支持基于条件的跳转。例如,一个步骤可以根据执行结果是成功还是失败,决定跳转到不同的步骤。

5. 示例:任务队列处理

class QueueStepHandler:
    """Handles a workflow step whose action works through a task queue (excerpt)."""

    def execute_step(self):
        """
        Run one pass of the queue step for the current agent execution.

        A queue with no status, or one that has completed, is re-initialised;
        an initialised queue is filled via ``_add_to_queue`` and moved to
        PROCESSING. If no tasks remain the queue is marked COMPLETE, otherwise
        ``_consume_from_queue`` is called.

        Returns:
            str: "COMPLETE" when the queue has no tasks left, otherwise "default".
        """
        execution = AgentExecution.get_agent_execution_from_id(self.session, self.agent_execution_id)
        workflow_step = AgentWorkflowStep.find_by_id(self.session, execution.current_agent_step_id)
        step_tool = AgentWorkflowStepTool.find_by_id(self.session, workflow_step.action_reference_id)
        task_queue = self._build_task_queue(step_tool)

        # Read the status once and track our own transitions locally instead of
        # re-querying the queue backend for a value we just wrote.
        status = task_queue.get_status()
        if not status or status == QueueStatus.COMPLETE.value:
            status = QueueStatus.INITIATED.value
            task_queue.set_status(status)

        if status == QueueStatus.INITIATED.value:
            self._add_to_queue(task_queue, step_tool)
            execution.current_feed_group_id = "DEFAULT"
            task_queue.set_status(QueueStatus.PROCESSING.value)

        if not task_queue.get_tasks():
            task_queue.set_status(QueueStatus.COMPLETE.value)
            return "COMPLETE"
        self._consume_from_queue(task_queue)
        return "default"
 
 

在这个例子中,QueueStepHandler 类负责处理队列步骤。它首先检查任务队列的状态,然后根据状态添加任务到队列或从队列中消费任务。这个过程涉及到任务的获取、执行和状态更新。

总结

整个代码实现了一个复杂的AI代理工作流系统,通过类和方法的组合使用,实现了任务的分解、调度和执行。这些实现支持动态的工作流配置,能够根据不同的业务需求调整执行逻辑。 

 

注意:在agent代码片段中,并没有使用向量数据库。代码主要涉及到任务队列的管理、工作流的处理、以及与Redis数据库的交互来管理任务状态和队列。以下是一些关键点:

数据库和存储技术Redis

  • 在 TaskQueue 类中,使用了Redis来管理任务队列。Redis是一个键值存储系统,通常用于缓存和消息队列,但它不是一个向量数据库。
  • 代码中使用Redis来存储任务、完成的任务以及任务状态,这些操作都是通过标准的Redis命令实现的,如 lpush, lpop, lrange 等。

 

下面再分析一下向量嵌入(superagi/vector_embeddings)相关的代码:

import warnings
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Tuple

class VectorEmbeddings(ABC):
    """Interface for objects that format chunk embeddings for a vector DB.

    The concrete implementations (Pinecone, Qdrant, Weaviate) receive their
    ids, embeddings and metadata in ``__init__``. They override this method
    without any argument, so ``final_chunks`` is optional here to keep the
    abstract signature compatible with those overrides.
    """

    @abstractmethod
    def get_vector_embeddings_from_chunks(
        self,
        final_chunks: Any = None
    ):
        """Return embeddings for vector DBs built from the final chunks.

        Args:
            final_chunks: Optional chunk data. Implementations that were given
                their data at construction time may ignore it.
        """


from typing import Any
from superagi.vector_embeddings.base import VectorEmbeddings

class Pinecone(VectorEmbeddings):
    """Packs chunk embeddings as a list of (id, embedding, metadata) tuples."""

    def __init__(self, uuid, embeds, metadata):
        """
        Args:
            uuid: Sequence of vector ids.
            embeds: Sequence of embedding vectors, parallel to ``uuid``.
            metadata: Sequence of metadata dicts, parallel to ``uuid``.
        """
        self.uuid = uuid
        self.embeds = embeds
        self.metadata = metadata

    def get_vector_embeddings_from_chunks(self):
        """Return ``{'vectors': [(id, embedding, metadata), ...]}``.

        The three sequences are zipped element-wise. If their lengths
        differ, the result is truncated to the shortest one.
        """
        return {'vectors': list(zip(self.uuid, self.embeds, self.metadata))}


from typing import Any
from superagi.vector_embeddings.base import VectorEmbeddings

class Qdrant(VectorEmbeddings):
    """Packs chunk embeddings as separate id, payload and vector sequences."""

    def __init__(self, uuid, embeds, metadata):
        # Parallel sequences supplied by the caller.
        self.uuid, self.embeds, self.metadata = uuid, embeds, metadata

    def get_vector_embeddings_from_chunks(self):
        """Return a dict with the keys 'ids', 'payload' and 'vectors'.

        The metadata is exposed as 'payload'. The stored sequences are
        returned as-is, without copying.
        """
        return dict(ids=self.uuid, payload=self.metadata, vectors=self.embeds)
import pinecone
from typing import Optional
from pinecone import UnauthorizedException
from superagi.vector_embeddings.pinecone import Pinecone
from superagi.vector_embeddings.qdrant import Qdrant
from superagi.vector_embeddings.weaviate import Weaviate
from superagi.types.vector_store_types import VectorStoreType

class VectorEmbeddingFactory:
    """Builds the vector-DB specific embedding wrapper for a set of chunks."""

    @classmethod
    def build_vector_storage(cls, vector_store: VectorStoreType, chunk_json: Optional[dict] = None):
        """
        Get the vector embeddings from final chunks.

        Args:
            vector_store: The vector store name or type. It is normalised
                through ``VectorStoreType.get_vector_store_type``.
            chunk_json: Optional mapping whose values are chunk dicts with the
                keys 'id', 'embeds', 'text', 'chunk' and 'knowledge_name'.
                Only the values are used; the mapping keys are ignored.

        Returns:
            A Pinecone, Qdrant or Weaviate embedding object holding parallel
            lists of ids, embeddings and metadata, or None when the store
            type is not one of those three.
        """
        vector_store = VectorStoreType.get_vector_store_type(vector_store)

        uuid = []
        embeds = []
        metadata = []
        if chunk_json is not None:
            # Values are visited in the mapping's insertion order, so the three
            # lists stay aligned index-by-index.
            for chunk in chunk_json.values():
                uuid.append(chunk["id"])
                embeds.append(chunk["embeds"])
                metadata.append({
                    'text': chunk['text'],
                    'chunk': chunk['chunk'],
                    'knowledge_name': chunk['knowledge_name']
                })

        if vector_store == VectorStoreType.PINECONE:
            return Pinecone(uuid, embeds, metadata)
        if vector_store == VectorStoreType.QDRANT:
            return Qdrant(uuid, embeds, metadata)
        if vector_store == VectorStoreType.WEAVIATE:
            return Weaviate(uuid, embeds, metadata)
        # Any other store type has no embedding wrapper here. Keep the
        # original fall-through result rather than raising.
        return None


from typing import Any
from superagi.vector_embeddings.base import VectorEmbeddings

class Weaviate(VectorEmbeddings):
    """Packs chunk embeddings as ids, data objects and vectors."""

    def __init__(self, uuid, embeds, metadata):
        # Parallel sequences supplied by the caller.
        self.uuid, self.embeds, self.metadata = uuid, embeds, metadata

    def get_vector_embeddings_from_chunks(self):
        """Return ``{'ids': ..., 'data_object': ..., 'vectors': ...}``.

        The metadata is exposed under 'data_object'. The stored sequences
        are returned without copying.
        """
        payload = {}
        payload['ids'] = self.uuid
        payload['data_object'] = self.metadata
        payload['vectors'] = self.embeds
        return payload

  

 

功能分析:上述代码定义了一个处理向量嵌入的框架,用于适配不同的向量数据库。这些文件包含以下几个关键部分:

1. VectorEmbeddings 类:这是一个抽象基类(ABC),定义了抽象方法 get_vector_embeddings_from_chunks,用于从数据块中获取向量嵌入,需要由子类实现。注意:基类在该方法上声明了 final_chunks 参数,而下面的子类实现都不带这个参数,它们的数据是通过构造函数传入的。

2. Pinecone、Qdrant、Weaviate 类:这些类都继承自 VectorEmbeddings,并实现了 get_vector_embeddings_from_chunks 方法。它们在构造时接收 uuid、嵌入向量和元数据,再按各自向量数据库需要的格式组织数据:Pinecone 返回 {'vectors': [(id, 向量, 元数据), ...]};Qdrant 返回包含 ids、payload、vectors 的字典;Weaviate 返回包含 ids、data_object、vectors 的字典。

3. VectorEmbeddingFactory 类:提供类方法 build_vector_storage,根据向量存储类型和可选的数据块 JSON(chunk_json)构建相应的向量嵌入对象。它先从 chunk_json 的每个值中提取 id、embeds,以及由 text、chunk、knowledge_name 组成的元数据,再根据存储类型创建 Pinecone、Qdrant 或 Weaviate 实例。其他存储类型不会命中任何分支,此时返回 None。

总的来说,这组文件提供了一个框架,可以根据不同的向量数据库类型,把数据块的嵌入向量整理成对应数据库所需的格式。

它主要供知识搜索使用。

用在知识库中:

 

 

vector_embedding 是否用到了 VectorStore?

确实有关联:VectorEmbeddingFactory 依靠 VectorStoreType(向量存储类型)来决定构建哪种嵌入对象。

 

 

再来看 vector_store 模块的功能。它定义了一个名为 VectorStore 的抽象基类,以及多个实现该基类的具体存储类,用于处理和存储向量数据。这些类主要用于文本及元数据的向量化存储和检索,支持多种后端存储系统,如 ChromaDB、Pinecone、Qdrant、Redis 和 Weaviate 等。以下是其中定义的主要组件和功能:

1. VectorStore (抽象基类):

  • 定义了一些抽象方法,如 add_texts, get_matching_text, get_index_stats, add_embeddings_to_vector_db, 和 delete_embeddings_from_vector_db,这些方法需要在子类中具体实现。

2. ChromaDB:

  • 实现了 VectorStore,提供了使用 ChromaDB 作为后端的向量存储功能。
  • 支持创建集合、添加文本、获取匹配文本、获取索引统计信息等功能。

3. Pinecone:

  • 实现了 VectorStore,使用 Pinecone 服务进行向量存储。
  • 提供了添加文本、获取匹配文本、获取索引统计信息、添加和删除向量等功能。

4. Qdrant:

  • 实现了 VectorStore,使用 Qdrant 服务进行向量存储。
  • 支持添加文本、获取匹配文本、获取索引统计信息、添加和删除向量等功能。

5. Redis:

  • 实现了 VectorStore,使用 Redis 服务进行向量存储。
  • 提供了添加文本、获取匹配文本、创建索引、添加和删除向量等功能。

6. Weaviate:

  • 实现了 VectorStore,使用 Weaviate 服务进行向量存储。
  • 支持添加文本、获取匹配文本、获取索引统计信息、添加和删除向量等功能。

此外,文件中还包含了一些辅助函数和类,如 create_weaviate_client 用于创建 Weaviate 客户端,以及 VectorFactory 类用于根据配置创建不同类型的向量存储实例。

 

vector_store 真正被使用的地方在 superagi/controllers/vector_dbs.py 中。

 

在 superagi/controllers/vector_dbs.py 中,vector_store 模块用于连接和操作不同的向量数据库,主要通过 VectorFactory 类来创建和管理各类向量存储实例。具体使用情况如下:

1. 连接向量数据库:

  • 在连接 Pinecone、Qdrant 和 Weaviate 向量数据库的路由处理函数中,VectorFactory.build_vector_storage 方法被用来创建对应的向量存储实例。这些实例用于进一步的操作,如获取索引统计信息和添加向量索引。

例如,连接 Pinecone 数据库的代码片段如下:

   vector_db_storage = VectorFactory.build_vector_storage("pinecone", collection, **db_creds)
   db_connect_for_index = vector_db_storage.get_index_stats()

2. 更新向量数据库:

  • 在更新向量数据库的路由处理函数中,同样使用 VectorFactory.build_vector_storage 方法来创建向量存储实例。这些实例用于获取索引统计信息,以便在更新过程中添加新的向量索引。

例如,更新向量数据库的代码片段如下:

   vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, index, **db_creds)
   vector_db_index_stats = vector_db_storage.get_index_stats()

这些代码片段展示了如何通过 VectorFactory 类来动态创建和管理不同类型的向量存储实例,以便执行特定的数据库操作,如连接、更新和获取索引统计信息。这种设计使得向量数据库的管理更加灵活和模块化。

 

 

 

 

 

posted @ 2024-06-04 16:50  bonelee  阅读(39)  评论(1)  编辑  收藏  举报