From eca57b3928f56d0911da77fa9f20a315bde9b4e2 Mon Sep 17 00:00:00 2001 From: Dan Herlihy Date: Fri, 14 Feb 2025 13:17:17 -0500 Subject: [PATCH] Update default proxy construction to include SSL context --- httpx/_transports/default.py | 13 +++++++++++-- tests/client/test_proxies.py | 8 ++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index d5aa05ff23..9a237c4ad0 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -118,6 +118,15 @@ def map_httpcore_exceptions() -> typing.Iterator[None]: raise mapped_exc(message) from exc +def create_proxy(proxy: ProxyTypes | None, ssl_context: ssl.SSLContext) -> Proxy | None: + if isinstance(proxy, (str, URL)): + proxy_url = proxy if isinstance(proxy, URL) else URL(proxy) + if proxy_url.scheme == "https": + return Proxy(url=proxy_url, ssl_context=ssl_context) + return Proxy(url=proxy_url) + return proxy + + class ResponseStream(SyncByteStream): def __init__(self, httpcore_stream: typing.Iterable[bytes]) -> None: self._httpcore_stream = httpcore_stream @@ -149,8 +158,8 @@ def __init__( ) -> None: import httpcore - proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + proxy = create_proxy(proxy, ssl_context) if proxy is None: self._pool = httpcore.ConnectionPool( @@ -293,8 +302,8 @@ def __init__( ) -> None: import httpcore - proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + proxy = create_proxy(proxy, ssl_context) if proxy is None: self._pool = httpcore.AsyncConnectionPool( diff --git a/tests/client/test_proxies.py b/tests/client/test_proxies.py index 3e4090dcec..24db3a12b5 100644 --- a/tests/client/test_proxies.py +++ b/tests/client/test_proxies.py @@ -1,3 +1,5 @@ +import ssl + import httpcore import pytest @@ -263,3 +265,9 @@ def test_proxy_with_mounts(): transport = client._transport_for_url(httpx.URL("http://example.com")) assert transport == proxy_transport + + +def test_proxy_with_ssl_context(): + ssl_context = ssl.create_default_context() + proxy_transport = httpx.HTTPTransport(proxy="https://127.0.0.1", verify=ssl_context) + assert proxy_transport._pool._proxy_ssl_context == ssl_context # type: ignore[attr-defined]