Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor create(..) call to LLMs to not require AgentState #1307

Merged
merged 7 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,9 @@ def _get_ai_reply(
"""Get response from LLM API"""
try:
response = create(
agent_state=self.agent_state,
# agent_state=self.agent_state,
llm_config=self.agent_state.llm_config,
user_id=self.agent_state.user_id,
messages=message_sequence,
functions=self.functions,
functions_python=self.functions_python,
Expand Down
1 change: 1 addition & 0 deletions memgpt/functions/function_sets/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def message_chatgpt(self, message: str):
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE),
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)),
]
# TODO: this will error without an LLMConfig
response = create(
model=MESSAGE_CHATGPT_FUNCTION_MODEL,
messages=message_sequence,
Expand Down
66 changes: 34 additions & 32 deletions memgpt/llm_api/llm_api_tools.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import random
import time
import uuid
from typing import List, Optional, Union

import requests

from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import AgentState, Message
from memgpt.data_types import Message
from memgpt.llm_api.anthropic import anthropic_chat_completions_request
from memgpt.llm_api.azure_openai import (
MODEL_TO_AZURE_ENGINE,
Expand All @@ -29,6 +30,7 @@
cast_message_to_subtype,
)
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.pydantic_models import LLMConfigModel
from memgpt.streaming_interface import (
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
Expand Down Expand Up @@ -135,8 +137,10 @@ def wrapper(*args, **kwargs):

@retry_with_exponential_backoff
def create(
agent_state: AgentState,
# agent_state: AgentState,
llm_config: LLMConfigModel,
messages: List[Message],
user_id: uuid.UUID = None, # optional UUID to associate request with
functions: list = None,
functions_python: list = None,
function_call: str = "auto",
Expand All @@ -152,7 +156,7 @@ def create(
"""Return response to chat completion with backoff"""
from memgpt.utils import printd

printd(f"Using model {agent_state.llm_config.model_endpoint_type}, endpoint: {agent_state.llm_config.model_endpoint}")
printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}")

# TODO eventually refactor so that credentials are passed through
credentials = MemGPTCredentials.load()
Expand All @@ -162,26 +166,26 @@ def create(
function_call = None

# openai
if agent_state.llm_config.model_endpoint_type == "openai":
if llm_config.model_endpoint_type == "openai":
# TODO do the same for Azure?
if credentials.openai_key is None and agent_state.llm_config.model_endpoint == "https://api.openai.com/v1":
if credentials.openai_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise ValueError(f"OpenAI key is missing from MemGPT config file")
if use_tool_naming:
data = ChatCompletionRequest(
model=agent_state.llm_config.model,
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
tool_choice=function_call,
user=str(agent_state.user_id),
user=str(user_id),
)
else:
data = ChatCompletionRequest(
model=agent_state.llm_config.model,
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
functions=functions,
function_call=function_call,
user=str(agent_state.user_id),
user=str(user_id),
)

if stream:
Expand All @@ -190,25 +194,23 @@ def create(
stream_inferface, AgentRefreshStreamingInterface
), type(stream_inferface)
return openai_chat_completions_process_stream(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
stream_inferface=stream_inferface,
)
else:
data.stream = False
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
)

# azure
elif agent_state.llm_config.model_endpoint_type == "azure":
elif llm_config.model_endpoint_type == "azure":
azure_deployment = (
credentials.azure_deployment
if credentials.azure_deployment is not None
else MODEL_TO_AZURE_ENGINE[agent_state.llm_config.model]
credentials.azure_deployment if credentials.azure_deployment is not None else MODEL_TO_AZURE_ENGINE[llm_config.model]
)
if use_tool_naming:
data = dict(
Expand All @@ -217,7 +219,7 @@ def create(
messages=messages,
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
tool_choice=function_call,
user=str(agent_state.user_id),
user=str(user_id),
)
else:
data = dict(
Expand All @@ -226,7 +228,7 @@ def create(
messages=messages,
functions=functions,
function_call=function_call,
user=str(agent_state.user_id),
user=str(user_id),
)
return azure_openai_chat_completions_request(
resource_name=credentials.azure_endpoint,
Expand All @@ -236,7 +238,7 @@ def create(
data=data,
)

elif agent_state.llm_config.model_endpoint_type == "google_ai":
elif llm_config.model_endpoint_type == "google_ai":
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Google AI API requests")

Expand All @@ -254,7 +256,7 @@ def create(
return google_ai_chat_completions_request(
inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg,
service_endpoint=credentials.google_ai_service_endpoint,
model=agent_state.llm_config.model,
model=llm_config.model,
api_key=credentials.google_ai_key,
# see structure of payload here: https://ai.google.dev/docs/function_calling
data=dict(
Expand All @@ -263,7 +265,7 @@ def create(
),
)

elif agent_state.llm_config.model_endpoint_type == "anthropic":
elif llm_config.model_endpoint_type == "anthropic":
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Anthropic API requests")

Expand All @@ -274,20 +276,20 @@ def create(
tools = None

return anthropic_chat_completions_request(
url=agent_state.llm_config.model_endpoint,
url=llm_config.model_endpoint,
api_key=credentials.anthropic_key,
data=ChatCompletionRequest(
model=agent_state.llm_config.model,
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
# tool_choice=function_call,
# user=str(agent_state.user_id),
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
max_tokens=1024, # TODO make dynamic
),
)

elif agent_state.llm_config.model_endpoint_type == "cohere":
elif llm_config.model_endpoint_type == "cohere":
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Cohere API requests")

Expand All @@ -298,15 +300,15 @@ def create(
tools = None

return cohere_chat_completions_request(
# url=agent_state.llm_config.model_endpoint,
# url=llm_config.model_endpoint,
url="https://api.cohere.ai/v1", # TODO
api_key=os.getenv("COHERE_API_KEY"), # TODO remove
chat_completion_request=ChatCompletionRequest(
model="command-r-plus", # TODO
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
tool_choice=function_call,
# user=str(agent_state.user_id),
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
# max_tokens=1024, # TODO make dynamic
),
Expand All @@ -315,16 +317,16 @@ def create(
# local model
else:
return get_chat_completion(
model=agent_state.llm_config.model,
model=llm_config.model,
messages=messages,
functions=functions,
functions_python=functions_python,
function_call=function_call,
context_window=agent_state.llm_config.context_window,
endpoint=agent_state.llm_config.model_endpoint,
endpoint_type=agent_state.llm_config.model_endpoint_type,
wrapper=agent_state.llm_config.model_wrapper,
user=str(agent_state.user_id),
context_window=llm_config.context_window,
endpoint=llm_config.model_endpoint,
endpoint_type=llm_config.model_endpoint_type,
wrapper=llm_config.model_wrapper,
user=str(user_id),
# hint
first_message=first_message,
# auth-related
Expand Down
3 changes: 2 additions & 1 deletion memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def summarize_messages(
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=summary_input))

response = create(
agent_state=agent_state,
llm_config=agent_state.llm_config,
user_id=agent_state.user_id,
messages=message_sequence,
)

Expand Down
Loading