From 738978a3b7765638b06bb5bd19edc137b209a5f2 Mon Sep 17 00:00:00 2001 From: Shubham Naik Date: Tue, 13 Aug 2024 12:00:19 -0700 Subject: [PATCH] feat: create an admin return all agents route (#1620) --- memgpt/metadata.py | 7 +++++++ memgpt/server/rest_api/admin/agents.py | 21 +++++++++++++++++++++ memgpt/server/rest_api/auth/index.py | 7 ++++++- memgpt/server/rest_api/server.py | 5 +++++ memgpt/server/server.py | 17 +++++++++++------ 5 files changed, 50 insertions(+), 7 deletions(-) create mode 100644 memgpt/server/rest_api/admin/agents.py diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 935cf4f25b..d9d88e242d 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -623,6 +623,13 @@ def list_agents(self, user_id: uuid.UUID) -> List[AgentState]: results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all() return [r.to_record() for r in results] + @enforce_types + def list_all_agents(self) -> List[AgentState]: + with self.session_maker() as session: + results = session.query(AgentModel).all() + + return [r.to_record() for r in results] + @enforce_types def list_sources(self, user_id: uuid.UUID) -> List[Source]: with self.session_maker() as session: diff --git a/memgpt/server/rest_api/admin/agents.py b/memgpt/server/rest_api/admin/agents.py new file mode 100644 index 0000000000..673a6225d2 --- /dev/null +++ b/memgpt/server/rest_api/admin/agents.py @@ -0,0 +1,21 @@ +from fastapi import APIRouter + +from memgpt.server.rest_api.agents.index import ListAgentsResponse +from memgpt.server.rest_api.interface import QueuingInterface +from memgpt.server.server import SyncServer + +router = APIRouter() + + +def setup_agents_admin_router(server: SyncServer, interface: QueuingInterface): + @router.get("/agents", tags=["agents"], response_model=ListAgentsResponse) + def get_all_agents(): + """ + Get a list of all agents in the database + """ + interface.clear() + agents_data = server.list_agents_legacy() + + return ListAgentsResponse(**agents_data) + + return router diff --git a/memgpt/server/rest_api/auth/index.py b/memgpt/server/rest_api/auth/index.py index 0c3727a695..30d748114b 100644 --- a/memgpt/server/rest_api/auth/index.py +++ b/memgpt/server/rest_api/auth/index.py @@ -1,3 +1,4 @@ +from typing import Optional from uuid import UUID from fastapi import APIRouter @@ -13,6 +14,7 @@ class AuthResponse(BaseModel): uuid: UUID = Field(..., description="UUID of the user") + is_admin: Optional[bool] = Field(None, description="Whether the user is an admin") class AuthRequest(BaseModel): @@ -29,10 +31,13 @@ def authenticate_user(request: AuthRequest) -> AuthResponse: Currently, this is a placeholder that simply returns a UUID placeholder """ interface.clear() + + is_admin = False if request.password != password: response = server.api_key_to_user(api_key=request.password) else: + is_admin = True response = server.authenticate_user() - return AuthResponse(uuid=response) + return AuthResponse(uuid=response, is_admin=is_admin) return router diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index a34be5f464..ecbd5fabf9 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -12,6 +12,7 @@ from starlette.middleware.cors import CORSMiddleware from memgpt.server.constants import REST_DEFAULT_PORT +from memgpt.server.rest_api.admin.agents import setup_agents_admin_router from memgpt.server.rest_api.admin.tools import setup_tools_index_router from memgpt.server.rest_api.admin.users import setup_admin_router from memgpt.server.rest_api.agents.command import setup_agents_command_router @@ -69,6 +70,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security ADMIN_PREFIX = "/admin" +ADMIN_API_PREFIX = "/api/admin" API_PREFIX = "/api" OPENAI_API_PREFIX = "/v1" @@ -89,6 +91,9 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security app.include_router(setup_admin_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)]) app.include_router(setup_tools_index_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)]) +# /api/admin/agents endpoints +app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)]) + # /api/agents endpoints app.include_router(setup_agents_command_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_agents_config_router(server, interface, password), prefix=API_PREFIX) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 3104866614..836364fe67 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -883,13 +883,18 @@ def list_agents( # TODO make return type pydantic def list_agents_legacy( self, - user_id: uuid.UUID, + user_id: Optional[uuid.UUID] = None, ) -> dict: """List all available agents to a user""" - if self.ms.get_user(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - agents_states = self.ms.list_agents(user_id=user_id) + if user_id is None: + agents_states = self.ms.list_all_agents() + else: + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + + agents_states = self.ms.list_agents(user_id=user_id) + agents_states_dicts = [self._agent_state_to_config(state) for state in agents_states] # TODO add a get_message_obj_from_message_id(...) function @@ -900,7 +905,7 @@ def list_agents_legacy( for agent_state, return_dict in zip(agents_states, agents_states_dicts): # Get the agent object (loaded in memory) - memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) + memgpt_agent = self._get_or_load_agent(user_id=agent_state.user_id, agent_id=agent_state.id) # TODO remove this eventually when return type get pydanticfied # this is to add persona_name and human_name so that the columns in UI can populate @@ -918,7 +923,7 @@ def list_agents_legacy( # get tool info from agent state tools = [] for tool_name in agent_state.tools: - tool = self.ms.get_tool(tool_name, user_id) + tool = self.ms.get_tool(tool_name, agent_state.user_id) tools.append(tool) return_dict["tools"] = tools