feat: use background tasks for processing uploaded files to REST API (#…
sarahwooders authored Apr 20, 2024
1 parent 04645e6 commit f74a32d
Showing 11 changed files with 230 additions and 79 deletions.
27 changes: 23 additions & 4 deletions memgpt/client/client.py
@@ -3,9 +3,10 @@
from requests.exceptions import RequestException
import uuid
from typing import Dict, List, Union, Optional, Tuple
import time

from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel, SourceModel
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel, SourceModel, JobModel, JobStatus
from memgpt.cli.cli import QuickstartChoice
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
from memgpt.config import MemGPTConfig
@@ -436,18 +437,36 @@ def delete_source(self, source_id: uuid.UUID):
response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete source: {response.text}"

def load_file_into_source(self, filename: str, source_id: uuid.UUID):
def get_job_status(self, job_id: uuid.UUID):
response = requests.get(f"{self.base_url}/api/sources/status/{str(job_id)}", headers=self.headers)
return JobModel(**response.json())

def load_file_into_source(self, filename: str, source_id: uuid.UUID, blocking=True):
"""Load {filename} and insert into source"""
files = {"file": open(filename, "rb")}

# create job
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
return UploadFileToSourceResponse(**response.json())
if response.status_code != 200:
raise ValueError(f"Failed to upload file to source: {response.text}")

job = JobModel(**response.json())
if blocking:
# wait until job is completed
while True:
job = self.get_job_status(job.id)
if job.status == JobStatus.completed:
break
elif job.status == JobStatus.failed:
raise ValueError(f"Job failed: {job.metadata}")
time.sleep(1)
return job

def create_source(self, name: str) -> Source:
"""Create a new source"""
payload = {"name": name}
response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
response_json = response.json()
print("CREATE SOURCE", response_json, response.text)
response_obj = SourceModel(**response_json)
return Source(
id=uuid.UUID(response_obj.id),
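
For reference, a minimal usage sketch of the updated client method (the client class name RESTClient and its constructor arguments are assumptions not shown in this diff; the method names and the one-second polling behavior follow the code above):

import uuid
from memgpt.client.client import RESTClient  # class name assumed; not shown in this diff

client = RESTClient(base_url="http://localhost:8283", token="<api-key>")  # constructor args assumed
source = client.create_source(name="my-docs")

# Blocking (default): polls get_job_status() once per second until the job completes or fails.
job = client.load_file_into_source(filename="notes.pdf", source_id=source.id)
print(job.status, job.metadata_)  # e.g. num_passages / num_documents on success

# Non-blocking: returns the created job immediately; poll manually.
job = client.load_file_into_source(filename="notes.pdf", source_id=source.id, blocking=False)
job = client.get_job_status(job.id)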
16 changes: 3 additions & 13 deletions memgpt/config.py
@@ -341,10 +341,9 @@ def __init__(
# functions
functions=None, # schema definitions ONLY (linked at runtime)
):
if name is None:
self.name = f"agent_{self.generate_agent_id()}"
else:
self.name = name

assert name, f"Agent name must be provided"
self.name = name

config = MemGPTConfig.load() # get default values
self.persona = config.persona if persona is None else persona
@@ -397,15 +396,6 @@ def __init__(
os.path.join(MEMGPT_DIR, "agents", self.name, "config.json") if agent_config_path is None else agent_config_path
)

def generate_agent_id(self, length=6):
## random character based
# characters = string.ascii_lowercase + string.digits
# return ''.join(random.choices(characters, k=length))

# count based
agent_count = len(utils.list_agent_config_files())
return str(agent_count + 1)

def attach_data_source(self, data_source: str):
# TODO: add warning that only once source can be attached
# i.e. previous source will be overriden
2 changes: 1 addition & 1 deletion memgpt/data_sources/connectors.py
@@ -97,7 +97,7 @@ def load_data(

passages.append(passage)
embedding_to_document_name[hashable_embedding] = document_name
if len(passages) >= embedding_config.embedding_chunk_size:
if len(passages) >= 100:
# insert passages into passage store
passage_store.insert_many(passages)

33 changes: 31 additions & 2 deletions memgpt/metadata.py
@@ -8,12 +8,12 @@

from memgpt.settings import settings
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
from memgpt.utils import enforce_types, printd
from memgpt.utils import enforce_types, printd, get_utc_time
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.functions.functions import load_all_function_sets

from memgpt.models.pydantic_models import PersonaModel, HumanModel, ToolModel
from memgpt.models.pydantic_models import PersonaModel, HumanModel, ToolModel, JobModel, JobStatus

from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
@@ -334,6 +334,7 @@ def __init__(self, config: MemGPTConfig):
HumanModel.__table__,
PersonaModel.__table__,
ToolModel.__table__,
JobModel.__table__,
],
)
self.session_maker = sessionmaker(bind=self.engine)
@@ -754,3 +755,31 @@ def delete_preset(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
session.commit()

# job related functions
def create_job(self, job: JobModel):
with self.session_maker() as session:
session.add(job)
session.commit()
session.expunge_all()

def update_job_status(self, job_id: uuid.UUID, status: JobStatus):
with self.session_maker() as session:
session.query(JobModel).filter(JobModel.id == job_id).update({"status": status})
if status == JobStatus.completed:
session.query(JobModel).filter(JobModel.id == job_id).update({"completed_at": get_utc_time()})
session.commit()

def update_job(self, job: JobModel):
with self.session_maker() as session:
session.add(job)
session.commit()
session.refresh(job)

def get_job(self, job_id: uuid.UUID) -> Optional[JobModel]:
with self.session_maker() as session:
results = session.query(JobModel).filter(JobModel.id == job_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]
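
To illustrate how these helpers fit together, here is a rough sketch of the job lifecycle (illustrative only; it assumes the metadata store class is the MetadataStore referenced as server.ms in the REST routes):

import uuid
from memgpt.config import MemGPTConfig
from memgpt.metadata import MetadataStore  # class name assumed from the `server.ms` usage below
from memgpt.models.pydantic_models import JobModel, JobStatus

ms = MetadataStore(MemGPTConfig.load())
user_id = uuid.uuid4()  # placeholder user id

# create -> running -> completed, mirroring what the background task does
job = JobModel(user_id=user_id, metadata_={"type": "embedding", "filename": "notes.pdf"})
ms.create_job(job)

job.status = JobStatus.running
ms.update_job(job)

job.status = JobStatus.completed
job.metadata_["num_passages"] = 10
ms.update_job(job)

# callers can fetch the latest state by id
assert ms.get_job(job.id).status == JobStatus.completed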
21 changes: 20 additions & 1 deletion memgpt/models/pydantic_models.py
@@ -1,11 +1,13 @@
from typing import List, Optional, Dict, Literal, Type
from pydantic import BaseModel, Field, Json, ConfigDict
from enum import StrEnum
import uuid
import base64
import numpy as np
from datetime import datetime
from sqlmodel import Field, SQLModel
from sqlalchemy import JSON, Column, BINARY, TypeDecorator
from sqlalchemy_utils import ChoiceType
from sqlalchemy import JSON, Column, BINARY, TypeDecorator, String

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.utils import get_human_text, get_persona_text, printd, get_utc_time
@@ -132,6 +134,23 @@ class SourceModel(SQLModel, table=True):
metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.")


class JobStatus(StrEnum):
created = "created"
running = "running"
completed = "completed"
failed = "failed"


class JobModel(SQLModel, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the job.", primary_key=True)
# status: str = Field(default="created", description="The status of the job.")
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.", sa_column=Column(ChoiceType(JobStatus)))
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the job.")
metadata_: Optional[dict] = Field({}, sa_column=Column(JSON), description="The metadata of the job.")


class PassageModel(BaseModel):
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.")
agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.")
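
A quick sketch of the new model's defaults (a standalone illustration of the fields defined above):

import uuid
from memgpt.models.pydantic_models import JobModel, JobStatus

job = JobModel(user_id=uuid.uuid4())
assert job.status == JobStatus.created   # default status
assert job.completed_at is None          # only set once the job finishes
print(job.id, job.created_at, job.metadata_)  # auto-generated id, UTC creation time, empty metadata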
1 change: 0 additions & 1 deletion memgpt/server/rest_api/admin/users.py
@@ -136,7 +136,6 @@ def get_api_keys(
"""
Get a list of all API keys for a user
"""
print("GET USERS", user_id)
try:
tokens = server.ms.get_all_api_keys_for_user(user_id=user_id)
processed_tokens = [t.token for t in tokens]
2 changes: 0 additions & 2 deletions memgpt/server/rest_api/server.py
@@ -39,13 +39,11 @@
cd memgpt/server/rest_api
poetry run uvicorn server:app --reload
"""

config = MemGPTConfig.load()
for memory_type in ("archival", "recall", "metadata"):
setattr(config, f"{memory_type}_storage_uri", settings.pg_uri)
config.save()


interface: QueuingInterface = QueuingInterface()
server: SyncServer = SyncServer(default_interface=interface)

112 changes: 89 additions & 23 deletions memgpt/server/rest_api/sources/index.py
@@ -5,11 +5,12 @@
from functools import partial
from typing import List, Optional

from fastapi import APIRouter, Body, Depends, Query, HTTPException, status, UploadFile
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status, UploadFile, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse

from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel
from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel, JobModel, JobStatus
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@@ -57,14 +58,54 @@ class GetSourceDocumentsResponse(BaseModel):
documents: List[DocumentModel] = Field(..., description="List of documents from the source.")


def load_file_to_source(server: SyncServer, user_id: uuid.UUID, source: Source, job_id: uuid.UUID, file: UploadFile):
# update job status
job = server.ms.get_job(job_id=job_id)
job.status = JobStatus.running
server.ms.update_job(job)

try:
# write the file to a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
file_path = os.path.join(tmpdirname, file.filename)
with open(file_path, "wb") as buffer:
buffer.write(file.file.read())

# read the file
connector = DirectoryConnector(input_files=[file_path])

# TODO: pre-compute total number of passages?

# load the data into the source via the connector
num_passages, num_documents = server.load_data(user_id=user_id, source_name=source.name, connector=connector)
except Exception as e:
# job failed with error
error = str(e)
print(error)
job.status = JobStatus.failed
job.metadata_["error"] = error
server.ms.update_job(job)
# TODO: delete any associated passages/documents?
return 0, 0

# update job status
job.status = JobStatus.completed
job.metadata_["num_passages"] = num_passages
job.metadata_["num_documents"] = num_documents
print("job completed", job.metadata_, job.id)
server.ms.update_job(job)


def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
async def list_sources(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all data sources created by a user."""
"""
List all data sources created by a user.
"""
# Clear the interface
interface.clear()

@@ -81,7 +122,9 @@ async def create_source(
request: CreateSourceRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Create a new data source."""
"""
Create a new data source.
"""
interface.clear()
try:
# TODO: don't use Source and just use SourceModel once pydantic migration is complete
@@ -104,7 +147,9 @@ async def delete_source(
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Delete a data source."""
"""
Delete a data source.
"""
interface.clear()
try:
server.delete_source(source_id=source_id, user_id=user_id)
@@ -120,7 +165,9 @@ async def attach_source_to_agent(
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Attach a data source to an existing agent."""
"""
Attach a data source to an existing agent.
"""
interface.clear()
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
@@ -141,41 +188,58 @@ async def detach_source_from_agent(
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Detach a data source from an existing agent."""
"""
Detach a data source from an existing agent.
"""
server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)

@router.post("/sources/{source_id}/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
@router.get("/sources/status/{job_id}", tags=["sources"], response_model=JobModel)
async def get_job_status(
job_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get the status of a job.
"""
job = server.ms.get_job(job_id=job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job with id={job_id} not found.")
return job

@router.post("/sources/{source_id}/upload", tags=["sources"], response_model=JobModel)
async def upload_file_to_source(
# file: UploadFile = UploadFile(..., description="The file to upload."),
file: UploadFile,
source_id: uuid.UUID,
background_tasks: BackgroundTasks,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Upload a file to a data source."""
"""
Upload a file to a data source.
"""
interface.clear()
source = server.ms.get_source(source_id=source_id, user_id=user_id)

# write the file to a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
file_path = os.path.join(tmpdirname, file.filename)
with open(file_path, "wb") as buffer:
buffer.write(file.file.read())

# read the file
connector = DirectoryConnector(input_files=[file_path])
# create job
job = JobModel(user_id=user_id, metadata_={"type": "embedding", "filename": file.filename, "source_id": str(source_id)})
job_id = job.id
server.ms.create_job(job)

# load the data into the source via the connector
passage_count, document_count = server.load_data(user_id=user_id, source_name=source.name, connector=connector)
# create background task
background_tasks.add_task(load_file_to_source, server, user_id, source, job_id, file)

# TODO: actually return added passages/documents
return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)
# return job information
job = server.ms.get_job(job_id=job_id)
return job

@router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
async def list_passages(
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all passages associated with a data source."""
"""
List all passages associated with a data source.
"""
passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
return GetSourcePassagesResponse(passages=passages)

Expand All @@ -184,7 +248,9 @@ async def list_documents(
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all documents associated with a data source."""
"""
List all documents associated with a data source.
"""
documents = server.list_data_source_documents(user_id=user_id, source_id=source_id)
return GetSourceDocumentsResponse(documents=documents)

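
Putting the new endpoints together, a hedged sketch of driving the upload-then-poll flow over raw HTTP (the base URL, port, and Authorization header format are assumptions; the paths match the routes above):

import time
import requests

base_url = "http://localhost:8283"                 # assumed host/port
headers = {"Authorization": "Bearer <api-key>"}    # assumed auth header format
source_id = "<source-uuid>"

# Kick off the background ingestion job.
with open("notes.pdf", "rb") as f:
    resp = requests.post(f"{base_url}/api/sources/{source_id}/upload", files={"file": f}, headers=headers)
resp.raise_for_status()
job = resp.json()

# Poll the status endpoint until the background task finishes.
while job["status"] not in ("completed", "failed"):
    time.sleep(1)
    job = requests.get(f"{base_url}/api/sources/status/{job['id']}", headers=headers).json()
print(job["status"], job.get("metadata_"))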