Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch azure support #140

Merged
merged 2 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ If you're using Azure OpenAI, set these variables instead:
export AZURE_OPENAI_KEY = ...
export AZURE_OPENAI_ENDPOINT = ...
export AZURE_OPENAI_VERSION = ...

# set the below if you are using deployment ids
export AZURE_OPENAI_DEPLOYMENT = ...
export AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT = ...

# then use the --use_azure_openai flag
memgpt --use_azure_openai
Expand Down
50 changes: 18 additions & 32 deletions memgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@

from memgpt.config import Config
from memgpt.constants import MEMGPT_DIR
from memgpt.openai_tools import (
configure_azure_support,
check_azure_embeddings,
get_set_azure_env_vars,
)

import asyncio

app = typer.Typer()
Expand Down Expand Up @@ -187,6 +193,18 @@ async def main(
if debug:
logging.getLogger().setLevel(logging.DEBUG)

# Azure OpenAI support
if use_azure_openai:
configure_azure_support()
check_azure_embeddings()
else:
azure_vars = get_set_azure_env_vars()
if len(azure_vars) > 0:
print(
f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😈 love it 😈

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😈

)
return

if any(
(
persona,
Expand Down Expand Up @@ -285,38 +303,6 @@ async def main(
f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
)

# Azure OpenAI support
if use_azure_openai:
azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if None in [
azure_openai_key,
azure_openai_endpoint,
azure_openai_version,
azure_openai_deployment,
]:
print(
f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
)
return

import openai

openai.api_type = "azure"
openai.api_key = azure_openai_key
openai.api_base = azure_openai_endpoint
openai.api_version = azure_openai_version
# deployment gets passed into chatcompletion
else:
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
print(
f"Error: AZURE_OPENAI_DEPLOYMENT should not be set if --use_azure_openai is False"
)
return

if cfg.index:
persistence_manager = InMemoryStateManagerWithFaiss(
cfg.index, cfg.archival_database
Expand Down
80 changes: 74 additions & 6 deletions memgpt/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,26 @@ async def acompletions_with_backoff(**kwargs):

# OpenAI / Azure model
else:
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
kwargs["deployment_id"] = azure_openai_deployment
if using_azure():
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
kwargs["deployment_id"] = azure_openai_deployment
else:
kwargs["engine"] = MODEL_TO_AZURE_ENGINE[kwargs["model"]]
kwargs.pop("model")
return await openai.ChatCompletion.acreate(**kwargs)


@aretry_with_exponential_backoff
async def acreate_embedding_with_backoff(**kwargs):
    """Wrapper around Embedding.acreate w/ backoff

    When Azure is in use, the request is routed with the embeddings-specific
    deployment id from ``AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT``. The chat
    deployment id (``AZURE_OPENAI_DEPLOYMENT``) is deliberately NOT applied
    here: it names a chat model deployment, not an embeddings one.

    Args:
        **kwargs: Keyword arguments forwarded to ``openai.Embedding.acreate``.
            ``model`` is expected to be present when Azure is used without an
            embeddings deployment id.

    Returns:
        Whatever ``openai.Embedding.acreate`` returns.
    """
    if using_azure():
        embeddings_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
        if embeddings_deployment is not None:
            kwargs["deployment_id"] = embeddings_deployment
        else:
            # No deployment id: address the Azure engine by the model name.
            # NOTE(review): the source indentation was lost; dropping "model"
            # only on this branch is inferred -- confirm against the repo.
            kwargs["engine"] = kwargs.pop("model")
    return await openai.Embedding.acreate(**kwargs)


Expand All @@ -138,3 +146,63 @@ async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002")
response = await acreate_embedding_with_backoff(input=[text], model=model)
embedding = response["data"][0]["embedding"]
return embedding


# Lookup table from the model name passed in request kwargs to the Azure
# engine name; it is indexed as MODEL_TO_AZURE_ENGINE[kwargs["model"]] when
# Azure is used without an explicit deployment id.
# Note that the engine names on the right spell "3.5" as "35" (no dot), and
# that both "gpt-3.5" and "gpt-3.5-turbo" map to the same engine.
MODEL_TO_AZURE_ENGINE = {
"gpt-4": "gpt-4",
"gpt-4-32k": "gpt-4-32k",
"gpt-3.5": "gpt-35-turbo",
"gpt-3.5-turbo": "gpt-35-turbo",
"gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
}


def get_set_azure_env_vars():
    """Return the Azure OpenAI environment variables that are currently set.

    Returns:
        A list of ``(name, value)`` tuples, in the fixed order below, for
        every known Azure OpenAI variable whose value is not ``None``.
        The list is empty when none of them is set.
    """
    azure_env_names = (
        "AZURE_OPENAI_KEY",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_VERSION",
        "AZURE_OPENAI_DEPLOYMENT",
        # Plural spelling: this is the name the README documents and the
        # name the embeddings request wrapper reads.
        "AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT",
        # Singular spelling this function originally checked; still reported
        # so a stale/misspelled variable in the environment is not missed.
        "AZURE_OPENAI_EMBEDDING_DEPLOYMENT",
    )
    candidates = ((name, os.getenv(name)) for name in azure_env_names)
    return [pair for pair in candidates if pair[1] is not None]


def using_azure():
    """Tell whether at least one Azure OpenAI environment variable is set."""
    azure_vars = get_set_azure_env_vars()
    return bool(azure_vars)


def configure_azure_support():
    """Point the global ``openai`` module at Azure OpenAI.

    Reads ``AZURE_OPENAI_KEY``, ``AZURE_OPENAI_ENDPOINT`` and
    ``AZURE_OPENAI_VERSION`` from the environment. If any of them is unset,
    an error is printed and ``openai`` is left untouched.

    Returns:
        ``True`` when all three variables were present and ``openai`` was
        configured, ``False`` otherwise. The function used to return ``None``
        in both cases, which made a failed configuration indistinguishable
        from a successful one; callers that ignore the result are unaffected.
    """
    azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
    azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")

    required = (azure_openai_key, azure_openai_endpoint, azure_openai_version)
    if any(value is None for value in required):
        print(
            "Error: missing Azure OpenAI environment variables. Please see README section on Azure."
        )
        return False

    openai.api_type = "azure"
    openai.api_key = azure_openai_key
    openai.api_base = azure_openai_endpoint
    openai.api_version = azure_openai_version
    # The deployment id is not set here; it is passed per request.
    return True


def check_azure_embeddings():
    """Fail fast when a chat deployment id is set without an embeddings one.

    The embeddings deployment is read from
    ``AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT`` (plural), which is the name the
    README documents and the name the embeddings request wrapper reads.
    Checking the singular spelling instead rejected correctly configured
    environments and accepted ones whose variable was never used.

    Raises:
        ValueError: if ``AZURE_OPENAI_DEPLOYMENT`` is set but
            ``AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT`` is not.
    """
    chat_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
    embeddings_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
    if chat_deployment is not None and embeddings_deployment is None:
        raise ValueError(
            "Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure"
        )