Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix bugs for retrieving archival memory via REST API + tests #1122

Merged
merged 3 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions memgpt/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ class PassageModel(Base):

metadata_ = Column(MutableJson)

# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True), server_default=func.now())

def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

Expand All @@ -165,6 +168,7 @@ def to_record(self):
data_source=self.data_source,
agent_id=self.agent_id,
metadata_=self.metadata_,
created_at=self.created_at,
)

"""Create database model for table_name"""
Expand Down Expand Up @@ -317,7 +321,7 @@ def get_all_cursor(
# get records
db_record_chunk = query.limit(limit).all()
if not db_record_chunk:
return None
return (None, [])
records = [record.to_record() for record in db_record_chunk]
next_cursor = db_record_chunk[-1].id
assert isinstance(next_cursor, uuid.UUID)
Expand Down Expand Up @@ -471,7 +475,6 @@ def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=F
upsert_stmt = stmt.on_conflict_do_update(
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
)
print(upsert_stmt)
conn.execute(upsert_stmt)
else:
conn.execute(stmt)
Expand Down Expand Up @@ -549,7 +552,6 @@ def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=F
upsert_stmt = stmt.on_conflict_do_update(
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
)
print(upsert_stmt)
conn.execute(upsert_stmt)
else:
conn.execute(stmt)
Expand Down
31 changes: 31 additions & 0 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
"""Detach a source from an agent"""
raise NotImplementedError

    def get_agent_archival_memory(
        self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
    ):
        """Paginated get for the archival memory for an agent.

        Abstract method — concrete clients (REST / local) must override it.

        Args:
            agent_id: ID of the agent whose archival memory is fetched.
            before: cursor — only return records before this passage ID (exclusive).
            after: cursor — only return records after this passage ID (exclusive).
            limit: maximum number of records to return per page.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class RESTClient(AbstractClient):
def __init__(
Expand Down Expand Up @@ -229,6 +235,19 @@ def detach_source(self, source_name: str, agent_id: uuid.UUID):
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
return response.json()

def get_agent_archival_memory(
self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
):
"""Paginated get for the archival memory for an agent"""
params = {"limit": limit}
if before:
params["before"] = str(before)
if after:
params["after"] = str(after)
response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/archival", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to get archival memory: {response.text}"
return response.json()["archival_memory"]


class LocalClient(AbstractClient):
def __init__(
Expand Down Expand Up @@ -351,3 +370,15 @@ def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):

def delete_agent(self, agent_id: uuid.UUID):
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)

def get_agent_archival_memory(
self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
):
_, archival_json_records = self.server.get_agent_archival_cursor(
user_id=self.user_id,
agent_id=agent_id,
after=after,
before=before,
limit=limit,
)
return archival_json_records
3 changes: 3 additions & 0 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def __init__(
doc_id: Optional[uuid.UUID] = None,
id: Optional[uuid.UUID] = None,
metadata_: Optional[dict] = {},
created_at: Optional[datetime] = None,
):
if id is None:
# by default, generate ID as a hash of the text (avoid duplicates)
Expand Down Expand Up @@ -335,6 +336,8 @@ def __init__(
self.embedding_dim = embedding_dim
self.embedding_model = embedding_model

self.created_at = created_at if created_at is not None else datetime.now()

if self.embedding is not None:
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
Expand Down
9 changes: 5 additions & 4 deletions memgpt/server/rest_api/agents/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class UpdateAgentMemoryResponse(BaseModel):

class ArchivalMemoryObject(BaseModel):
    # TODO move to models/pydantic_models, or inherit from data_types Record
    # Removed the stale duplicate `id: str` annotation (diff artifact); the
    # effective field is the uuid.UUID one, since the later annotation wins.
    id: uuid.UUID = Field(..., description="Unique identifier for the memory object inside the archival memory store.")
    contents: str = Field(..., description="The memory contents.")


Expand Down Expand Up @@ -107,7 +107,8 @@ def get_agent_archival_memory_all(
interface.clear()
archival_memories = server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)
print("archival_memories:", archival_memories)
return GetAgentArchivalMemoryResponse(archival_memory=archival_memories)
archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["text"]) for passage in archival_memories]
return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)

@router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
def get_agent_archival_memory(
Expand All @@ -131,8 +132,8 @@ def get_agent_archival_memory(
before=before,
limit=limit,
)
print(archival_json_records)
return GetAgentArchivalMemoryResponse(archival_json_records)
archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["text"]) for passage in archival_json_records]
return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)

@router.post("/agents/{agent_id}/archival", tags=["agents"], response_model=InsertAgentArchivalMemoryResponse)
def insert_agent_archival_memory(
Expand Down
76 changes: 75 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@

from memgpt import Admin, create_client
from memgpt.constants import DEFAULT_PRESET
from dotenv import load_dotenv

from tests.config import TestMGPTConfig

from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import EmbeddingConfig, LLMConfig
from .utils import wipe_config, wipe_memgpt_home


import pytest
import uuid
Expand All @@ -29,6 +37,63 @@ def run_server():
import uvicorn
from memgpt.server.rest_api.server import app

load_dotenv()

# Use os.getenv with a fallback to os.environ.get
db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL")
assert db_url, "Missing PGVECTOR_TEST_DB_URL"

if os.getenv("OPENAI_API_KEY"):
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
default_embedding_config=EmbeddingConfig(
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
),
# llms
default_llm_config=LLMConfig(
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
model="gpt-4",
),
)
credentials = MemGPTCredentials(
openai_key=os.getenv("OPENAI_API_KEY"),
)
else: # hosted
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
default_embedding_config=EmbeddingConfig(
embedding_endpoint_type="hugging-face",
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_model="BAAI/bge-large-en-v1.5",
embedding_dim=1024,
),
# llms
default_llm_config=LLMConfig(
model_endpoint_type="vllm",
model_endpoint="https://api.memgpt.ai",
model="ehartford/dolphin-2.5-mixtral-8x7b",
),
)
credentials = MemGPTCredentials()

config.save()
credentials.save()

uvicorn.run(app, host="localhost", port=8283, log_level="info")


Expand Down Expand Up @@ -124,16 +189,25 @@ def test_sources(client, agent):
print("listed sources", sources)
assert len(sources) == 1

# check agent archival memory size
archival_memories = client.get_agent_archival_memory(agent_id=agent.id)
print(archival_memories)
assert len(archival_memories) == 0

# load a file into a source
filename = "CONTRIBUTING.md"
num_passages = 20
response = client.load_file_into_source(filename, source.id)
print(response)

# attach a source
# TODO: make sure things run in the right order
client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)

# TODO: list archival memory
# list archival memory
archival_memories = client.get_agent_archival_memory(agent_id=agent.id)
print(archival_memories)
assert len(archival_memories) == num_passages

# detach the source
# TODO: add when implemented
Expand Down
Loading