Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: Update REST API routes GET information for agents/humans/personas and store humans/personas in DB #1074

Merged
merged 8 commits into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from memgpt.data_types import User, LLMConfig, EmbeddingConfig, Source
from memgpt.metadata import MetadataStore
from memgpt.server.utils import shorten_key_middle
from memgpt.models.pydantic_models import HumanModel, PersonaModel

app = typer.Typer()

Expand Down Expand Up @@ -761,20 +762,15 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
"""List all humans"""
table = PrettyTable()
table.field_names = ["Name", "Text"]
for human_file in utils.list_human_files():
text = open(human_file, "r").read()
name = os.path.basename(human_file).replace("txt", "")
table.add_row([name, text])
for human in ms.list_humans(user_id=user_id):
table.add_row([human.name, human.text])
print(table)
elif arg == ListChoice.personas:
"""List all personas"""
table = PrettyTable()
table.field_names = ["Name", "Text"]
for persona_file in utils.list_persona_files():
print(persona_file)
text = open(persona_file, "r").read()
name = os.path.basename(persona_file).replace(".txt", "")
table.add_row([name, text])
for persona in ms.list_personas(user_id=user_id):
table.add_row([persona.name, persona.text])
print(table)
elif arg == ListChoice.sources:
"""List all data sources"""
Expand Down Expand Up @@ -826,24 +822,16 @@ def add(
filename: Annotated[Optional[str], typer.Option("-f", help="Specify filename")] = None,
):
"""Add a person/human"""

config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
if option == "persona":
directory = os.path.join(MEMGPT_DIR, "personas")
ms.add_persona(PersonaModel(name=name, text=text, user_id=user_id))
elif option == "human":
directory = os.path.join(MEMGPT_DIR, "humans")
ms.add_human(HumanModel(name=name, text=text, user_id=user_id))
else:
raise ValueError(f"Unknown kind {option}")

if filename:
assert text is None, f"Cannot provide both filename and text"
# copy file to directory
shutil.copyfile(filename, os.path.join(directory, name))
if text:
assert filename is None, f"Cannot provide both filename and text"
# write text to file
with open(os.path.join(directory, name), "w", encoding="utf-8") as f:
f.write(text)


@app.command()
def delete(option: str, name: str):
Expand Down Expand Up @@ -886,6 +874,10 @@ def delete(option: str, name: str):
# metadata
ms.delete_agent(agent_id=agent.id)

elif option == "human":
ms.delete_human(name=name, user_id=user_id)
elif option == "persona":
ms.delete_persona(name=name, user_id=user_id)
else:
raise ValueError(f"Option {option} not implemented")

Expand Down
59 changes: 59 additions & 0 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig

from memgpt.models.pydantic_models import PersonaModel, HumanModel

from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
Expand Down Expand Up @@ -318,6 +320,8 @@ def __init__(self, config: MemGPTConfig):
TokenModel.__table__,
PresetModel.__table__,
PresetSourceMapping.__table__,
HumanModel.__table__,
PersonaModel.__table__,
],
)
self.session_maker = sessionmaker(bind=self.engine)
Expand Down Expand Up @@ -599,3 +603,58 @@ def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID):
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
).delete()
session.commit()

@enforce_types
def add_human(self, human: HumanModel):
with self.session_maker() as session:
session.add(human)
session.commit()

@enforce_types
def add_persona(self, persona: PersonaModel):
with self.session_maker() as session:
session.add(persona)
session.commit()

@enforce_types
def get_human(self, name: str, user_id: uuid.UUID) -> str:
with self.session_maker() as session:
results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]

@enforce_types
def get_persona(self, name: str, user_id: uuid.UUID) -> str:
with self.session_maker() as session:
results = session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]

@enforce_types
def list_personas(self, user_id: uuid.UUID) -> List[PersonaModel]:
with self.session_maker() as session:
results = session.query(PersonaModel).filter(PersonaModel.user_id == user_id).all()
return results

@enforce_types
def list_humans(self, user_id: uuid.UUID) -> List[HumanModel]:
with self.session_maker() as session:
# if user_id matches provided user_id or if user_id is None
results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
return results

@enforce_types
def delete_human(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).delete()
session.commit()

@enforce_types
def delete_persona(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
session.commit()
43 changes: 42 additions & 1 deletion memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from enum import Enum
from pydantic import BaseModel, Field, Json
import uuid
from datetime import datetime
from sqlmodel import Field, SQLModel

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.utils import get_human_text, get_persona_text, printd


class LLMConfigModel(BaseModel):
Expand All @@ -20,14 +25,50 @@ class EmbeddingConfigModel(BaseModel):
embedding_chunk_size: Optional[int] = 300


class PresetModel(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")


class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")
description: str = Field(None, description="The description of the agent.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")

# timestamps
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")

# preset information
preset: str = Field(..., description="The preset used by the agent.")
persona: str = Field(..., description="The persona used by the agent.")
human: str = Field(..., description="The human used by the agent.")
functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")

# llm information
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.")

# agent state
state: Optional[Dict] = Field(None, description="The state of the agent.")
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")


class HumanModel(SQLModel, table=True):
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
name: str = Field(..., description="The name of the human.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.")


class PersonaModel(SQLModel, table=True):
text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.")
name: str = Field(..., description="The name of the persona.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.")
28 changes: 26 additions & 2 deletions memgpt/presets/presets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import List
import os
from memgpt.data_types import AgentState, Preset
from memgpt.interface import AgentInterface
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
from memgpt.utils import get_human_text, get_persona_text, printd
from memgpt.utils import get_human_text, get_persona_text, printd, list_human_files, list_persona_files
from memgpt.prompts import gpt_system
from memgpt.functions.functions import load_all_function_sets
from memgpt.metadata import MetadataStore
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.models.pydantic_models import HumanModel, PersonaModel

import uuid

Expand All @@ -15,15 +17,37 @@
preset_options = list(available_presets.keys())


def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
for persona_file in list_persona_files():
text = open(persona_file, "r").read()
name = os.path.basename(persona_file).replace(".txt", "")
if ms.get_persona(user_id=user_id, name=name) is not None:
printd(f"Persona '{name}' already exists for user '{user_id}'")
continue
persona = PersonaModel(name=name, text=text, user_id=user_id)
ms.add_persona(persona)
for human_file in list_human_files():
text = open(human_file, "r").read()
name = os.path.basename(human_file).replace(".txt", "")
if ms.get_human(user_id=user_id, name=name) is not None:
printd(f"Human '{name}' already exists for user '{user_id}'")
continue
human = HumanModel(name=name, text=text, user_id=user_id)
ms.add_human(human)


def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
"""Add the default presets to the metadata store"""
# make sure humans/personas added
add_default_humans_and_personas(user_id=user_id, ms=ms)

# add default presets
for preset_name in preset_options:
preset_config = available_presets[preset_name]
preset_system_prompt = preset_config["system_prompt"]
preset_function_set_names = preset_config["functions"]
functions_schema = generate_functions_json(preset_function_set_names)

print("PRESET", preset_name, user_id)
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
continue
Expand Down
44 changes: 34 additions & 10 deletions memgpt/server/rest_api/agents/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional

from memgpt.models.pydantic_models import AgentStateModel
from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


class AgentConfigRequest(BaseModel):
class GetAgentRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")


Expand All @@ -23,9 +24,11 @@ class AgentRenameRequest(BaseModel):
agent_name: str = Field(..., description="New name for the agent.")


class AgentConfigResponse(BaseModel):
class GetAgentResponse(BaseModel):
# config: dict = Field(..., description="The agent configuration object.")
agent_state: AgentStateModel = Field(..., description="The state of the agent.")
sources: List[str] = Field(..., description="The list of data sources associated with the agent.")
last_run_at: Optional[int] = Field(None, description="The unix timestamp of when the agent was last run.")


def validate_agent_name(name: str) -> str:
Expand All @@ -48,7 +51,7 @@ def validate_agent_name(name: str) -> str:
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/agents/config", tags=["agents"], response_model=AgentConfigResponse)
@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
Expand All @@ -58,15 +61,36 @@ def get_agent_config(

This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
request = AgentConfigRequest(agent_id=agent_id)
request = GetAgentRequest(agent_id=agent_id)

agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
attached_sources = server.list_attached_sources(agent_id=agent_id)

interface.clear()
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
return AgentConfigResponse(agent_state=agent_state)

@router.patch("/agents/rename", tags=["agents"], response_model=AgentConfigResponse)
# return GetAgentResponse(agent_state=agent_state)
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

return GetAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=agent_state.llm_config,
embedding_config=agent_state.embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
),
last_run_at=None, # TODO
sources=attached_sources,
)

@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
Expand All @@ -87,7 +111,7 @@ def update_agent_name(
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return AgentConfigResponse(agent_state=agent_state)
return GetAgentResponse(agent_state=agent_state)

@router.delete("/agents", tags=["agents"])
def delete_agent(
Expand All @@ -97,7 +121,7 @@ def delete_agent(
"""
Delete an agent.
"""
request = AgentConfigRequest(agent_id=agent_id)
request = GetAgentRequest(agent_id=agent_id)

agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

Expand Down
1 change: 1 addition & 0 deletions memgpt/server/rest_api/agents/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def create_agent(
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
)
)
# return CreateAgentResponse(
Expand Down
Loading
Loading