Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fix(azure): azure_deployment use with realtime + non-deployment-based APIs #2154

Merged
merged 12 commits into from
Feb 28, 2025
67 changes: 56 additions & 11 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def __init__(self) -> None:


class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
_azure_endpoint: httpx.URL | None
_azure_deployment: str | None

@override
def _build_request(
self,
Expand All @@ -58,11 +61,29 @@ def _build_request(
) -> httpx.Request:
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and not "/deployments" in str(self.base_url):
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"

return super()._build_request(options, retries_taken=retries_taken)

@override
def _prepare_url(self, url: str) -> httpx.URL:
"""Adjust the URL if the client was configured with an Azure endpoint + deployment
and the API feature being called is **not** a deployments-based endpoint
(i.e. requires /deployments/deployment-name in the URL path).
"""
if self._azure_deployment and self._azure_endpoint and url not in _deployments_endpoints:
merge_url = httpx.URL(url)
if merge_url.is_relative_url:
merge_raw_path = (
self._azure_endpoint.raw_path.rstrip(b"/") + b"/openai/" + merge_url.raw_path.lstrip(b"/")
)
return self._azure_endpoint.copy_with(raw_path=merge_raw_path)

return merge_url

return super()._prepare_url(url)


class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
@overload
Expand Down Expand Up @@ -160,8 +181,8 @@ def __init__(
azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
Not supported with Assistants APIs.
"""
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
Expand Down Expand Up @@ -224,6 +245,8 @@ def __init__(
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None

@override
def copy(
Expand Down Expand Up @@ -307,20 +330,30 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:

return options

def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": self._api_version,
"deployment": model,
"deployment": self._azure_deployment or model,
}
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
return query, auth_headers

if self.websocket_base_url is not None:
base_url = httpx.URL(self.websocket_base_url)
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
else:
base_url = self._prepare_url("/realtime")
realtime_url = base_url.copy_with(scheme="wss")

url = realtime_url.copy_with(params={**query})
return url, auth_headers


class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
Expand Down Expand Up @@ -422,8 +455,8 @@ def __init__(
azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
Not supported with Assistants APIs.
"""
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
Expand Down Expand Up @@ -486,6 +519,8 @@ def __init__(
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None

@override
def copy(
Expand Down Expand Up @@ -571,17 +606,27 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp

return options

async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": self._api_version,
"deployment": model,
"deployment": self._azure_deployment or model,
}
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
return query, auth_headers

if self.websocket_base_url is not None:
base_url = httpx.URL(self.websocket_base_url)
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
else:
base_url = self._prepare_url("/realtime")
realtime_url = base_url.copy_with(scheme="wss")

url = realtime_url.copy_with(params={**query})
return url, auth_headers
36 changes: 18 additions & 18 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,15 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
else:
url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
log.debug("Connecting to %s", url)
if self.__websocket_connection_options:
log.debug("Connection options: %s", self.__websocket_connection_options)
Expand Down Expand Up @@ -506,15 +506,15 @@ def __enter__(self) -> RealtimeConnection:
extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
extra_query, auth_headers = self.__client._configure_realtime(self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
else:
url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
log.debug("Connecting to %s", url)
if self.__websocket_connection_options:
log.debug("Connection options: %s", self.__websocket_connection_options)
Expand Down
Loading