From bbd0c0908d505513499989aa6072b7801d7d80b7 Mon Sep 17 00:00:00 2001
From: cpacker
Date: Sat, 27 Apr 2024 16:35:23 -0700
Subject: [PATCH 1/7] added assert

---
 memgpt/server/server.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index f27ec51414..a4131ca893 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -209,7 +209,6 @@ def __init__(
 
         # Initialize the connection to the DB
         self.config = MemGPTConfig.load()
-        print(f"server :: loading configuration from '{self.config.config_path}'")
         assert self.config.persona is not None, "Persona must be set in the config"
         assert self.config.human is not None, "Human must be set in the config"
 
@@ -261,7 +260,7 @@ def __init__(
             embedding_model=self.config.default_embedding_config.embedding_model,
             embedding_chunk_size=self.config.default_embedding_config.embedding_chunk_size,
         )
-        assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)
+        assert self.server_embedding_config.embedding_model is not None, self.server_embedding_config
 
         # Initialize the metadata store
         self.ms = MetadataStore(self.config)

From cb2fa6d2acd7e2da9203d86e878d16d289790958 Mon Sep 17 00:00:00 2001
From: cpacker
Date: Sat, 27 Apr 2024 16:44:36 -0700
Subject: [PATCH 2/7] update dump

---
 configs/server_config.yaml | 13 ++++++-------
 memgpt/server/server.py    |  2 +-
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/configs/server_config.yaml b/configs/server_config.yaml
index dd86c4e291..96b8cef67a 100644
--- a/configs/server_config.yaml
+++ b/configs/server_config.yaml
@@ -7,7 +7,6 @@ human = basic
 model = gpt-4
 model_endpoint = https://api.openai.com/v1
 model_endpoint_type = openai
-model_wrapper = null
 context_window = 8192
 
 [embedding]
@@ -19,18 +18,18 @@ embedding_chunk_size = 300
 
 [archival_storage]
 type = postgres
-path = /root/.memgpt/chroma
-uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
+path = /Users/loaner/.memgpt/chroma
+uri = postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
 
 [recall_storage]
 type = postgres
-path = /root/.memgpt
-uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
+path = /Users/loaner/.memgpt
+uri = postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
 
 [metadata_storage]
 type = postgres
-path = /root/.memgpt
-uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
+path = /Users/loaner/.memgpt
+uri = postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
 
 [version]
 memgpt_version = 0.3.12

diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index a4131ca893..6ac4a1b76f 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -260,7 +260,7 @@ def __init__(
             embedding_model=self.config.default_embedding_config.embedding_model,
             embedding_chunk_size=self.config.default_embedding_config.embedding_chunk_size,
         )
-        assert self.server_embedding_config.embedding_model is not None, self.server_embedding_config
+        assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)
 
         # Initialize the metadata store
         self.ms = MetadataStore(self.config)

From c6ed10f971bacf482f3c9379099b858ec1c3dbb3 Mon Sep 17 00:00:00 2001
From: cpacker
Date: Sat, 27 Apr 2024 16:51:40 -0700
Subject: [PATCH 3/7] added config path dump

---
 memgpt/server/server.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index 6ac4a1b76f..f27ec51414 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -209,6 +209,7 @@ def __init__(
 
         # Initialize the connection to the DB
         self.config = MemGPTConfig.load()
+        print(f"server :: loading configuration from '{self.config.config_path}'")
         assert self.config.persona is not None, "Persona must be set in the config"
         assert self.config.human is not None, "Human must be set in the config"
 

From cb7a2c4ae7919c49c9050753a961be9b860c6a8c Mon Sep 17 00:00:00 2001
From: cpacker
Date: Sat, 27 Apr 2024 20:40:32 -0700
Subject: [PATCH 4/7] revert config diff, remove dupe configs dir

---
 configs/server_config.yaml | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/configs/server_config.yaml b/configs/server_config.yaml
index 96b8cef67a..dd86c4e291 100644
--- a/configs/server_config.yaml
+++ b/configs/server_config.yaml
@@ -7,6 +7,7 @@ human = basic
 model = gpt-4
 model_endpoint = https://api.openai.com/v1
 model_endpoint_type = openai
+model_wrapper = null
 context_window = 8192
 
 [embedding]
@@ -18,18 +19,18 @@ embedding_chunk_size = 300
 
 [archival_storage]
 type = postgres
-path = /Users/loaner/.memgpt/chroma
-uri = postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
+path = /root/.memgpt/chroma
+uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
 
 [recall_storage]
 type = postgres
-path = /Users/loaner/.memgpt
-uri = postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
+path = /root/.memgpt
+uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
 
 [metadata_storage]
 type = postgres
-path = /Users/loaner/.memgpt
-uri = postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
+path = /root/.memgpt
+uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
 
 [version]
 memgpt_version = 0.3.12

From 236d8473dbb0a6b3d5efe098ce04256a2234746a Mon Sep 17 00:00:00 2001
From: Sarah Wooders
Date: Fri, 5 Apr 2024 15:21:17 -0700
Subject: [PATCH 5/7] upgrade llama-index-embeddings-huggingface package and
 fix bug with local embeddings

---
 memgpt/cli/cli_config.py | 89 ++++++++++++++-------------------------
 1 file changed, 32 insertions(+), 57 deletions(-)

diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py
index 86d3da36df..f92ce7d92f 100644
--- a/memgpt/cli/cli_config.py
+++ b/memgpt/cli/cli_config.py
@@ -277,9 +277,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
 
         provider = "cohere"
     else:  # local models
-        # backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
-        backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
-        # assert backend_options_old == backend_options, (backend_options_old, backend_options)
+        backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
         default_model_endpoint_type = None
         if config.default_llm_config and config.default_llm_config.model_endpoint_type in backend_options:
             # set from previous config
@@ -397,12 +395,8 @@ def get_model_options(
 
     else:
         # Attempt to do OpenAI endpoint style model fetching
-        # TODO support local auth with api-key header
-        if credentials.openllm_auth_type == "bearer_token":
-            api_key = credentials.openllm_key
-        else:
-            api_key = None
-        fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=api_key, fix_url=True)
+        # TODO support local auth
+        fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=None)
         model_options = [obj["id"] for obj in fetched_model_options_response["data"]]
         # NOTE no filtering of local model options
 
@@ -559,44 +553,6 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
 
         raise KeyboardInterrupt
     else:  # local models
-
-        # ask about local auth
-        if model_endpoint_type in ["groq"]:  # TODO all llm engines under 'local' that will require api keys
-            use_local_auth = True
-            local_auth_type = "bearer_token"
-            local_auth_key = questionary.password(
-                "Enter your Groq API key:",
-            ).ask()
-            if local_auth_key is None:
-                raise KeyboardInterrupt
-            credentials.openllm_auth_type = local_auth_type
-            credentials.openllm_key = local_auth_key
-            credentials.save()
-        else:
-            use_local_auth = questionary.confirm(
-                "Is your LLM endpoint authenticated? (default no)",
-                default=False,
-            ).ask()
-            if use_local_auth is None:
-                raise KeyboardInterrupt
-            if use_local_auth:
-                local_auth_type = questionary.select(
-                    "What HTTP authentication method does your endpoint require?",
-                    choices=SUPPORTED_AUTH_TYPES,
-                    default=SUPPORTED_AUTH_TYPES[0],
-                ).ask()
-                if local_auth_type is None:
-                    raise KeyboardInterrupt
-                local_auth_key = questionary.password(
-                    "Enter your authentication key:",
-                ).ask()
-                if local_auth_key is None:
-                    raise KeyboardInterrupt
-                # credentials = MemGPTCredentials.load()
-                credentials.openllm_auth_type = local_auth_type
-                credentials.openllm_key = local_auth_key
-                credentials.save()
-
         # ollama also needs model type
         if model_endpoint_type == "ollama":
             default_model = (
@@ -617,7 +573,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
         )
 
     # vllm needs huggingface model tag
-    if model_endpoint_type in ["vllm", "groq"]:
+    if model_endpoint_type == "vllm":
         try:
             # Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
             # + probably has custom model names
@@ -672,6 +628,31 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
 
        if model_wrapper is None:
            raise KeyboardInterrupt
+    # ask about local auth
+    use_local_auth = questionary.confirm(
+        "Is your LLM endpoint authenticated? (default no)",
+        default=False,
+    ).ask()
+    if use_local_auth is None:
+        raise KeyboardInterrupt
+    if use_local_auth:
+        local_auth_type = questionary.select(
+            "What HTTP authentication method does your endpoint require?",
+            choices=SUPPORTED_AUTH_TYPES,
+            default=SUPPORTED_AUTH_TYPES[0],
+        ).ask()
+        if local_auth_type is None:
+            raise KeyboardInterrupt
+        local_auth_key = questionary.password(
+            "Enter your authentication key:",
+        ).ask()
+        if local_auth_key is None:
+            raise KeyboardInterrupt
+        # credentials = MemGPTCredentials.load()
+        credentials.openllm_auth_type = local_auth_type
+        credentials.openllm_key = local_auth_key
+        credentials.save()
+
     # set: context_window
     if str(model) not in LLM_MAX_TOKENS:
@@ -871,6 +852,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
         embedding_endpoint = None
         embedding_model = "BAAI/bge-small-en-v1.5"
         embedding_dim = 384
+        embedding_model = "BAAI/bge-small-en-v1.5"
 
     return embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model
 
@@ -1087,7 +1069,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
     """List all data sources"""
 
     # create table
-    table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
+    table.field_names = ["Name", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
     # TODO: eventually look accross all storage connections
     # TODO: add data source stats
     # TODO: connect to agents
@@ -1100,14 +1082,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
         agent_names = [agent_state.name for agent_state in agent_states if agent_state is not None]
 
         table.add_row(
-            [
-                source.name,
-                source.description,
-                source.embedding_model,
-                source.embedding_dim,
-                utils.format_datetime(source.created_at),
-                ",".join(agent_names),
-            ]
+            [source.name, source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), ",".join(agent_names)]
         )
     print(table)

From 74c715744a830d0552210412bca23c76f5df4b20 Mon Sep 17 00:00:00 2001
From: Sarah Wooders
Date: Sat, 6 Apr 2024 13:56:31 -0700
Subject: [PATCH 6/7] fix merge

---
 memgpt/cli/cli_config.py | 89 +++++++++++++++++++++++++--------------
 1 file changed, 57 insertions(+), 32 deletions(-)

diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py
index f92ce7d92f..86d3da36df 100644
--- a/memgpt/cli/cli_config.py
+++ b/memgpt/cli/cli_config.py
@@ -277,7 +277,9 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
 
         provider = "cohere"
     else:  # local models
-        backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
+        # backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
+        backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
+        # assert backend_options_old == backend_options, (backend_options_old, backend_options)
         default_model_endpoint_type = None
         if config.default_llm_config and config.default_llm_config.model_endpoint_type in backend_options:
             # set from previous config
@@ -395,8 +397,12 @@ def get_model_options(
 
     else:
         # Attempt to do OpenAI endpoint style model fetching
-        # TODO support local auth
-        fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=None)
+        # TODO support local auth with api-key header
+        if credentials.openllm_auth_type == "bearer_token":
+            api_key = credentials.openllm_key
+        else:
+            api_key = None
+        fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=api_key, fix_url=True)
         model_options = [obj["id"] for obj in fetched_model_options_response["data"]]
         # NOTE no filtering of local model options
 
@@ -553,6 +559,44 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
 
         raise KeyboardInterrupt
     else:  # local models
+
+        # ask about local auth
+        if model_endpoint_type in ["groq"]:  # TODO all llm engines under 'local' that will require api keys
+            use_local_auth = True
+            local_auth_type = "bearer_token"
+            local_auth_key = questionary.password(
+                "Enter your Groq API key:",
+            ).ask()
+            if local_auth_key is None:
+                raise KeyboardInterrupt
+            credentials.openllm_auth_type = local_auth_type
+            credentials.openllm_key = local_auth_key
+            credentials.save()
+        else:
+            use_local_auth = questionary.confirm(
+                "Is your LLM endpoint authenticated? (default no)",
+                default=False,
+            ).ask()
+            if use_local_auth is None:
+                raise KeyboardInterrupt
+            if use_local_auth:
+                local_auth_type = questionary.select(
+                    "What HTTP authentication method does your endpoint require?",
+                    choices=SUPPORTED_AUTH_TYPES,
+                    default=SUPPORTED_AUTH_TYPES[0],
+                ).ask()
+                if local_auth_type is None:
+                    raise KeyboardInterrupt
+                local_auth_key = questionary.password(
+                    "Enter your authentication key:",
+                ).ask()
+                if local_auth_key is None:
+                    raise KeyboardInterrupt
+                # credentials = MemGPTCredentials.load()
+                credentials.openllm_auth_type = local_auth_type
+                credentials.openllm_key = local_auth_key
+                credentials.save()
+
         # ollama also needs model type
         if model_endpoint_type == "ollama":
             default_model = (
@@ -573,7 +617,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
         )
 
     # vllm needs huggingface model tag
-    if model_endpoint_type == "vllm":
+    if model_endpoint_type in ["vllm", "groq"]:
         try:
             # Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
             # + probably has custom model names
@@ -628,31 +672,6 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
 
        if model_wrapper is None:
            raise KeyboardInterrupt
-    # ask about local auth
-    use_local_auth = questionary.confirm(
-        "Is your LLM endpoint authenticated? (default no)",
-        default=False,
-    ).ask()
-    if use_local_auth is None:
-        raise KeyboardInterrupt
-    if use_local_auth:
-        local_auth_type = questionary.select(
-            "What HTTP authentication method does your endpoint require?",
-            choices=SUPPORTED_AUTH_TYPES,
-            default=SUPPORTED_AUTH_TYPES[0],
-        ).ask()
-        if local_auth_type is None:
-            raise KeyboardInterrupt
-        local_auth_key = questionary.password(
-            "Enter your authentication key:",
-        ).ask()
-        if local_auth_key is None:
-            raise KeyboardInterrupt
-        # credentials = MemGPTCredentials.load()
-        credentials.openllm_auth_type = local_auth_type
-        credentials.openllm_key = local_auth_key
-        credentials.save()
-
     # set: context_window
     if str(model) not in LLM_MAX_TOKENS:
@@ -852,7 +871,6 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
         embedding_endpoint = None
         embedding_model = "BAAI/bge-small-en-v1.5"
         embedding_dim = 384
-        embedding_model = "BAAI/bge-small-en-v1.5"
 
     return embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model
 
@@ -1069,7 +1087,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
     """List all data sources"""
 
     # create table
-    table.field_names = ["Name", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
+    table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
     # TODO: eventually look accross all storage connections
     # TODO: add data source stats
     # TODO: connect to agents
@@ -1082,7 +1100,14 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
         agent_names = [agent_state.name for agent_state in agent_states if agent_state is not None]
 
         table.add_row(
-            [source.name, source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), ",".join(agent_names)]
+            [
+                source.name,
+                source.description,
+                source.embedding_model,
+                source.embedding_dim,
+                utils.format_datetime(source.created_at),
+                ",".join(agent_names),
+            ]
         )
     print(table)

From aac1292229fd4c4bad2e641fe47cee4a312e88de Mon Sep 17 00:00:00 2001
From: Sarah Wooders
Date: Sun, 28 Apr 2024 15:16:14 -0700
Subject: [PATCH 7/7] change args for create() call

---
 memgpt/agent.py                          |  4 +-
 memgpt/functions/function_sets/extras.py |  1 +
 memgpt/llm_api/llm_api_tools.py          | 66 ++++++++++++------------
 memgpt/memory.py                         |  3 +-
 4 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/memgpt/agent.py b/memgpt/agent.py
index b02b6cd4a7..22aeeb40cb 100644
--- a/memgpt/agent.py
+++ b/memgpt/agent.py
@@ -424,7 +424,9 @@ def _get_ai_reply(
         """Get response from LLM API"""
         try:
             response = create(
-                agent_state=self.agent_state,
+                # agent_state=self.agent_state,
+                llm_config=self.agent_state.llm_config,
+                user_id=self.agent_state.user_id,
                 messages=message_sequence,
                 functions=self.functions,
                 functions_python=self.functions_python,

diff --git a/memgpt/functions/function_sets/extras.py b/memgpt/functions/function_sets/extras.py
index 9eb90988a9..025c3e6d47 100644
--- a/memgpt/functions/function_sets/extras.py
+++ b/memgpt/functions/function_sets/extras.py
@@ -31,6 +31,7 @@ def message_chatgpt(self, message: str):
         Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE),
         Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)),
     ]
+    # TODO: this will error without an LLMConfig
     response = create(
         model=MESSAGE_CHATGPT_FUNCTION_MODEL,
         messages=message_sequence,

diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py
index 220bebb1cd..59abc09e1e 100644
--- a/memgpt/llm_api/llm_api_tools.py
+++ b/memgpt/llm_api/llm_api_tools.py
@@ -1,13 +1,14 @@
 import os
 import random
 import time
+import uuid
 from typing import List, Optional, Union
 
 import requests
 
 from memgpt.constants import CLI_WARNING_PREFIX
 from memgpt.credentials import MemGPTCredentials
-from memgpt.data_types import AgentState, Message
+from memgpt.data_types import Message
 from memgpt.llm_api.anthropic import anthropic_chat_completions_request
 from memgpt.llm_api.azure_openai import (
     MODEL_TO_AZURE_ENGINE,
@@ -29,6 +30,7 @@
     cast_message_to_subtype,
 )
 from memgpt.models.chat_completion_response import ChatCompletionResponse
+from memgpt.models.pydantic_models import LLMConfigModel
 from memgpt.streaming_interface import (
     AgentChunkStreamingInterface,
     AgentRefreshStreamingInterface,
@@ -135,8 +137,10 @@ def wrapper(*args, **kwargs):
 
 @retry_with_exponential_backoff
 def create(
-    agent_state: AgentState,
+    # agent_state: AgentState,
+    llm_config: LLMConfigModel,
     messages: List[Message],
+    user_id: uuid.UUID = None,  # option UUID to associate request with
     functions: list = None,
     functions_python: list = None,
     function_call: str = "auto",
@@ -152,7 +156,7 @@ def create(
     """Return response to chat completion with backoff"""
     from memgpt.utils import printd
 
-    printd(f"Using model {agent_state.llm_config.model_endpoint_type}, endpoint: {agent_state.llm_config.model_endpoint}")
+    printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}")
 
     # TODO eventually refactor so that credentials are passed through
     credentials = MemGPTCredentials.load()
@@ -162,26 +166,26 @@ def create(
         function_call = None
 
     # openai
-    if agent_state.llm_config.model_endpoint_type == "openai":
+    if llm_config.model_endpoint_type == "openai":
         # TODO do the same for Azure?
-        if credentials.openai_key is None and agent_state.llm_config.model_endpoint == "https://api.openai.com/v1":
+        if credentials.openai_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
             # only is a problem if we are *not* using an openai proxy
             raise ValueError(f"OpenAI key is missing from MemGPT config file")
 
         if use_tool_naming:
             data = ChatCompletionRequest(
-                model=agent_state.llm_config.model,
+                model=llm_config.model,
                 messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
                 tools=[{"type": "function", "function": f} for f in functions] if functions else None,
                 tool_choice=function_call,
-                user=str(agent_state.user_id),
+                user=str(user_id),
             )
         else:
             data = ChatCompletionRequest(
-                model=agent_state.llm_config.model,
+                model=llm_config.model,
                 messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
                 functions=functions,
                 function_call=function_call,
-                user=str(agent_state.user_id),
+                user=str(user_id),
             )
 
         if stream:
@@ -190,7 +194,7 @@ def create(
                 stream_inferface, AgentRefreshStreamingInterface
             ), type(stream_inferface)
             return openai_chat_completions_process_stream(
-                url=agent_state.llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
+                url=llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
                 api_key=credentials.openai_key,
                 chat_completion_request=data,
                 stream_inferface=stream_inferface,
@@ -198,17 +202,15 @@ def create(
         else:
             data.stream = False
             return openai_chat_completions_request(
-                url=agent_state.llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
+                url=llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
                 api_key=credentials.openai_key,
                 chat_completion_request=data,
             )
 
     # azure
-    elif agent_state.llm_config.model_endpoint_type == "azure":
+    elif llm_config.model_endpoint_type == "azure":
         azure_deployment = (
-            credentials.azure_deployment
-            if credentials.azure_deployment is not None
-            else MODEL_TO_AZURE_ENGINE[agent_state.llm_config.model]
+            credentials.azure_deployment if credentials.azure_deployment is not None else MODEL_TO_AZURE_ENGINE[llm_config.model]
         )
         if use_tool_naming:
             data = dict(
@@ -217,7 +219,7 @@ def create(
                 messages=messages,
                 tools=[{"type": "function", "function": f} for f in functions] if functions else None,
                 tool_choice=function_call,
-                user=str(agent_state.user_id),
+                user=str(user_id),
             )
         else:
             data = dict(
@@ -226,7 +228,7 @@ def create(
                 messages=messages,
                 functions=functions,
                 function_call=function_call,
-                user=str(agent_state.user_id),
+                user=str(user_id),
             )
         return azure_openai_chat_completions_request(
             resource_name=credentials.azure_endpoint,
@@ -236,7 +238,7 @@ def create(
             data=data,
         )
 
-    elif agent_state.llm_config.model_endpoint_type == "google_ai":
+    elif llm_config.model_endpoint_type == "google_ai":
         if not use_tool_naming:
             raise NotImplementedError("Only tool calling supported on Google AI API requests")
 
@@ -254,7 +256,7 @@ def create(
         return google_ai_chat_completions_request(
             inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg,
             service_endpoint=credentials.google_ai_service_endpoint,
-            model=agent_state.llm_config.model,
+            model=llm_config.model,
             api_key=credentials.google_ai_key,
             # see structure of payload here: https://ai.google.dev/docs/function_calling
             data=dict(
@@ -263,7 +265,7 @@ def create(
             ),
         )
 
-    elif agent_state.llm_config.model_endpoint_type == "anthropic":
+    elif llm_config.model_endpoint_type == "anthropic":
         if not use_tool_naming:
             raise NotImplementedError("Only tool calling supported on Anthropic API requests")
 
@@ -274,20 +276,20 @@ def create(
         tools = None
 
         return anthropic_chat_completions_request(
-            url=agent_state.llm_config.model_endpoint,
+            url=llm_config.model_endpoint,
             api_key=credentials.anthropic_key,
             data=ChatCompletionRequest(
-                model=agent_state.llm_config.model,
+                model=llm_config.model,
                 messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
                 tools=[{"type": "function", "function": f} for f in functions] if functions else None,
                 # tool_choice=function_call,
-                # user=str(agent_state.user_id),
+                # user=str(user_id),
                 # NOTE: max_tokens is required for Anthropic API
                 max_tokens=1024,  # TODO make dynamic
             ),
         )
 
-    elif agent_state.llm_config.model_endpoint_type == "cohere":
+    elif llm_config.model_endpoint_type == "cohere":
         if not use_tool_naming:
             raise NotImplementedError("Only tool calling supported on Cohere API requests")
 
@@ -298,7 +300,7 @@ def create(
         tools = None
 
         return cohere_chat_completions_request(
-            # url=agent_state.llm_config.model_endpoint,
+            # url=llm_config.model_endpoint,
             url="https://api.cohere.ai/v1",  # TODO
             api_key=os.getenv("COHERE_API_KEY"),  # TODO remove
             chat_completion_request=ChatCompletionRequest(
                 messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
                 tools=[{"type": "function", "function": f} for f in functions] if functions else None,
                 tool_choice=function_call,
-                # user=str(agent_state.user_id),
+                # user=str(user_id),
                 # NOTE: max_tokens is required for Anthropic API
                 # max_tokens=1024, # TODO make dynamic
             ),
         )
 
@@ -315,16 +317,16 @@ def create(
     # local model
     else:
         return get_chat_completion(
-            model=agent_state.llm_config.model,
+            model=llm_config.model,
             messages=messages,
             functions=functions,
             functions_python=functions_python,
             function_call=function_call,
-            context_window=agent_state.llm_config.context_window,
-            endpoint=agent_state.llm_config.model_endpoint,
-            endpoint_type=agent_state.llm_config.model_endpoint_type,
-            wrapper=agent_state.llm_config.model_wrapper,
-            user=str(agent_state.user_id),
+            context_window=llm_config.context_window,
+            endpoint=llm_config.model_endpoint,
+            endpoint_type=llm_config.model_endpoint_type,
+            wrapper=llm_config.model_wrapper,
+            user=str(user_id),
             # hint
             first_message=first_message,
             # auth-related

diff --git a/memgpt/memory.py b/memgpt/memory.py
index eb2c03ca96..b07b0263c0 100644
--- a/memgpt/memory.py
+++ b/memgpt/memory.py
@@ -142,7 +142,8 @@ def summarize_messages(
     message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=summary_input))
 
     response = create(
-        agent_state=agent_state,
+        llm_config=agent_state.llm_config,
+        user_id=agent_state.user_id,
         messages=message_sequence,
     )