diff --git a/CHANGELOG.md b/CHANGELOG.md index a5fd574..67a100a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.22.0] - 2024-03-02 +### Changed +- Requires [`httpx`](https://www.python-httpx.org)==0.27.\* +- `httpx_auth.JsonTokenFileCache` and `httpx_auth.TokenMemoryCache` `get_token` method does not handle kwargs anymore, the `on_missing_token` callable does not expect any arguments anymore. + ## [0.21.0] - 2024-02-19 ### Added - Publicly expose `httpx_auth.SupportMultiAuth`, allowing multiple authentication support for every `httpx` authentication class that exists. @@ -245,7 +250,8 @@ Note that a few changes were made: ### Added - Placeholder for port of requests_auth to httpx -[Unreleased]: https://github.com/Colin-b/httpx_auth/compare/v0.21.0...HEAD +[Unreleased]: https://github.com/Colin-b/httpx_auth/compare/v0.22.0...HEAD +[0.22.0]: https://github.com/Colin-b/httpx_auth/compare/v0.21.0...v0.22.0 [0.21.0]: https://github.com/Colin-b/httpx_auth/compare/v0.20.0...v0.21.0 [0.20.0]: https://github.com/Colin-b/httpx_auth/compare/v0.19.0...v0.20.0 [0.19.0]: https://github.com/Colin-b/httpx_auth/compare/v0.18.0...v0.19.0 diff --git a/httpx_auth/_oauth2/authentication_responses_server.py b/httpx_auth/_oauth2/authentication_responses_server.py index 27d2ffb..fef2c80 100644 --- a/httpx_auth/_oauth2/authentication_responses_server.py +++ b/httpx_auth/_oauth2/authentication_responses_server.py @@ -166,7 +166,7 @@ def handle_timeout(self) -> None: raise TimeoutOccurred(self.timeout) -def request_new_grant(grant_details: GrantDetails) -> (str, str): +def request_new_grant(grant_details: GrantDetails) -> tuple[str, str]: """ Ask for a new OAuth2 grant. :return: A tuple (state, grant) diff --git a/httpx_auth/_oauth2/authorization_code.py b/httpx_auth/_oauth2/authorization_code.py index f9df1ab..334a44c 100644 --- a/httpx_auth/_oauth2/authorization_code.py +++ b/httpx_auth/_oauth2/authorization_code.py @@ -1,5 +1,5 @@ from hashlib import sha512 -from typing import Generator, Iterable, Union +from typing import Iterable, Union import httpx @@ -8,14 +8,14 @@ from httpx_auth._oauth2.browser import BrowserAuth from httpx_auth._oauth2.common import ( request_new_grant_with_post, - OAuth2, + OAuth2BaseAuth, _add_parameters, _pop_parameter, _get_query_parameter, ) -class OAuth2AuthorizationCode(httpx.Auth, SupportMultiAuth, BrowserAuth): +class OAuth2AuthorizationCode(OAuth2BaseAuth, SupportMultiAuth, BrowserAuth): """ Authorization Code Grant @@ -71,13 +71,11 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): BrowserAuth.__init__(self, kwargs) - self.header_name = kwargs.pop("header_name", None) or "Authorization" - self.header_value = kwargs.pop("header_value", None) or "Bearer {token}" - if "{token}" not in self.header_value: - raise Exception("header_value parameter must contains {token}.") + header_name = kwargs.pop("header_name", None) or "Authorization" + header_value = kwargs.pop("header_value", None) or "Bearer {token}" self.token_field_name = kwargs.pop("token_field_name", None) or "access_token" - self.early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) + early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) username = kwargs.pop("username", None) password = kwargs.pop("password", None) @@ -99,11 +97,11 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): authorization_url_without_nonce, nonce = _pop_parameter( authorization_url_without_nonce, "nonce" ) - self.state = sha512( + state = sha512( authorization_url_without_nonce.encode("unicode_escape") ).hexdigest() custom_code_parameters = { - "state": self.state, + "state": state, "redirect_uri": self.redirect_uri, } if nonce: @@ -129,17 +127,14 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): self.refresh_data = {"grant_type": "refresh_token"} self.refresh_data.update(kwargs) - def auth_flow( - self, request: httpx.Request - ) -> Generator[httpx.Request, httpx.Response, None]: - token = OAuth2.token_cache.get_token( - self.state, - early_expiry=self.early_expiry, - on_missing_token=self.request_new_token, - on_expired_token=self.refresh_token, + OAuth2BaseAuth.__init__( + self, + state, + early_expiry, + header_name, + header_value, + self.refresh_token, ) - request.headers[self.header_name] = self.header_value.format(token=token) - yield request def request_new_token(self) -> tuple: # Request code diff --git a/httpx_auth/_oauth2/authorization_code_pkce.py b/httpx_auth/_oauth2/authorization_code_pkce.py index 0216c66..314f736 100644 --- a/httpx_auth/_oauth2/authorization_code_pkce.py +++ b/httpx_auth/_oauth2/authorization_code_pkce.py @@ -1,7 +1,6 @@ import base64 import os from hashlib import sha256, sha512 -from typing import Generator import httpx @@ -10,13 +9,13 @@ from httpx_auth._oauth2.browser import BrowserAuth from httpx_auth._oauth2.common import ( request_new_grant_with_post, - OAuth2, + OAuth2BaseAuth, _add_parameters, _pop_parameter, ) -class OAuth2AuthorizationCodePKCE(httpx.Auth, SupportMultiAuth, BrowserAuth): +class OAuth2AuthorizationCodePKCE(OAuth2BaseAuth, SupportMultiAuth, BrowserAuth): """ Proof Key for Code Exchange @@ -72,13 +71,11 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): self.client = kwargs.pop("client", None) - self.header_name = kwargs.pop("header_name", None) or "Authorization" - self.header_value = kwargs.pop("header_value", None) or "Bearer {token}" - if "{token}" not in self.header_value: - raise Exception("header_value parameter must contains {token}.") + header_name = kwargs.pop("header_name", None) or "Authorization" + header_value = kwargs.pop("header_value", None) or "Bearer {token}" self.token_field_name = kwargs.pop("token_field_name", None) or "access_token" - self.early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) + early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) # As described in https://tools.ietf.org/html/rfc6749#section-4.1.2 code_field_name = kwargs.pop("code_field_name", "code") @@ -98,11 +95,11 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): authorization_url_without_nonce, nonce = _pop_parameter( authorization_url_without_nonce, "nonce" ) - self.state = sha512( + state = sha512( authorization_url_without_nonce.encode("unicode_escape") ).hexdigest() custom_code_parameters = { - "state": self.state, + "state": state, "redirect_uri": self.redirect_uri, } if nonce: @@ -139,17 +136,9 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): self.refresh_data = {"grant_type": "refresh_token"} self.refresh_data.update(kwargs) - def auth_flow( - self, request: httpx.Request - ) -> Generator[httpx.Request, httpx.Response, None]: - token = OAuth2.token_cache.get_token( - self.state, - early_expiry=self.early_expiry, - on_missing_token=self.request_new_token, - on_expired_token=self.refresh_token, + OAuth2BaseAuth.__init__( + self, state, early_expiry, header_name, header_value, self.refresh_token ) - request.headers[self.header_name] = self.header_value.format(token=token) - yield request def request_new_token(self) -> tuple: # Request code diff --git a/httpx_auth/_oauth2/client_credentials.py b/httpx_auth/_oauth2/client_credentials.py index caa1ab0..487b5fd 100644 --- a/httpx_auth/_oauth2/client_credentials.py +++ b/httpx_auth/_oauth2/client_credentials.py @@ -1,16 +1,16 @@ from hashlib import sha512 -from typing import Generator, Union, Iterable +from typing import Union, Iterable import httpx from httpx_auth._authentication import SupportMultiAuth from httpx_auth._oauth2.common import ( - OAuth2, + OAuth2BaseAuth, request_new_grant_with_post, _add_parameters, ) -class OAuth2ClientCredentials(httpx.Auth, SupportMultiAuth): +class OAuth2ClientCredentials(OAuth2BaseAuth, SupportMultiAuth): """ Client Credentials Grant @@ -49,13 +49,11 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs) if not self.client_secret: raise Exception("client_secret is mandatory.") - self.header_name = kwargs.pop("header_name", None) or "Authorization" - self.header_value = kwargs.pop("header_value", None) or "Bearer {token}" - if "{token}" not in self.header_value: - raise Exception("header_value parameter must contains {token}.") + header_name = kwargs.pop("header_name", None) or "Authorization" + header_value = kwargs.pop("header_value", None) or "Bearer {token}" self.token_field_name = kwargs.pop("token_field_name", None) or "access_token" - self.early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) + early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) # Time is expressed in seconds self.timeout = int(kwargs.pop("timeout", None) or 60) @@ -70,18 +68,14 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs) self.data.update(kwargs) all_parameters_in_url = _add_parameters(self.token_url, self.data) - self.state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest() - - def auth_flow( - self, request: httpx.Request - ) -> Generator[httpx.Request, httpx.Response, None]: - token = OAuth2.token_cache.get_token( - self.state, - early_expiry=self.early_expiry, - on_missing_token=self.request_new_token, + state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest() + + super().__init__( + state, + early_expiry, + header_name, + header_value, ) - request.headers[self.header_name] = self.header_value.format(token=token) - yield request def request_new_token(self) -> tuple: client = self.client or httpx.Client() diff --git a/httpx_auth/_oauth2/common.py b/httpx_auth/_oauth2/common.py index a94cde3..7670c1f 100644 --- a/httpx_auth/_oauth2/common.py +++ b/httpx_auth/_oauth2/common.py @@ -1,4 +1,5 @@ -from typing import Optional +import abc +from typing import Callable, Generator, Optional, Union from urllib.parse import parse_qs, urlsplit, urlunsplit, urlencode import httpx @@ -86,3 +87,41 @@ def request_new_grant_with_post( class OAuth2: token_cache = TokenMemoryCache() display = DisplaySettings() + + +class OAuth2BaseAuth(abc.ABC, httpx.Auth): + def __init__( + self, + state: str, + early_expiry: float, + header_name: str, + header_value: str, + refresh_token: Optional[Callable] = None, + ) -> None: + if "{token}" not in header_value: + raise Exception("header_value parameter must contains {token}.") + + self.state = state + self.early_expiry = early_expiry + self.header_name = header_name + self.header_value = header_value + self.refresh_token = refresh_token + + def auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + token = OAuth2.token_cache.get_token( + self.state, + early_expiry=self.early_expiry, + on_missing_token=self.request_new_token, + on_expired_token=self.refresh_token, + ) + self._update_user_request(request, token) + yield request + + @abc.abstractmethod + def request_new_token(self) -> Union[tuple[str, str], tuple[str, str, int]]: + pass # pragma: no cover + + def _update_user_request(self, request: httpx.Request, token: str) -> None: + request.headers[self.header_name] = self.header_value.format(token=token) diff --git a/httpx_auth/_oauth2/implicit.py b/httpx_auth/_oauth2/implicit.py index 6ec67b1..3ebc986 100644 --- a/httpx_auth/_oauth2/implicit.py +++ b/httpx_auth/_oauth2/implicit.py @@ -1,6 +1,5 @@ import uuid from hashlib import sha512 -from typing import Generator import httpx @@ -8,14 +7,14 @@ from httpx_auth._oauth2 import authentication_responses_server from httpx_auth._oauth2.browser import BrowserAuth from httpx_auth._oauth2.common import ( - OAuth2, + OAuth2BaseAuth, _add_parameters, _pop_parameter, _get_query_parameter, ) -class OAuth2Implicit(httpx.Auth, SupportMultiAuth, BrowserAuth): +class OAuth2Implicit(OAuth2BaseAuth, SupportMultiAuth, BrowserAuth): """ Implicit Grant @@ -62,10 +61,8 @@ def __init__(self, authorization_url: str, **kwargs): BrowserAuth.__init__(self, kwargs) - self.header_name = kwargs.pop("header_name", None) or "Authorization" - self.header_value = kwargs.pop("header_value", None) or "Bearer {token}" - if "{token}" not in self.header_value: - raise Exception("header_value parameter must contains {token}.") + header_name = kwargs.pop("header_name", None) or "Authorization" + header_value = kwargs.pop("header_value", None) or "Bearer {token}" response_type = _get_query_parameter(self.authorization_url, "response_type") if response_type: @@ -82,7 +79,7 @@ def __init__(self, authorization_url: str, **kwargs): "id_token" if "id_token" == response_type else "access_token" ) - self.early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) + early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) authorization_url_without_nonce = _add_parameters( self.authorization_url, kwargs @@ -90,10 +87,10 @@ def __init__(self, authorization_url: str, **kwargs): authorization_url_without_nonce, nonce = _pop_parameter( authorization_url_without_nonce, "nonce" ) - self.state = sha512( + state = sha512( authorization_url_without_nonce.encode("unicode_escape") ).hexdigest() - custom_parameters = {"state": self.state, "redirect_uri": self.redirect_uri} + custom_parameters = {"state": state, "redirect_uri": self.redirect_uri} if nonce: custom_parameters["nonce"] = nonce grant_url = _add_parameters(authorization_url_without_nonce, custom_parameters) @@ -104,17 +101,16 @@ def __init__(self, authorization_url: str, **kwargs): self.redirect_uri_port, ) - def auth_flow( - self, request: httpx.Request - ) -> Generator[httpx.Request, httpx.Response, None]: - token = OAuth2.token_cache.get_token( - self.state, - early_expiry=self.early_expiry, - on_missing_token=authentication_responses_server.request_new_grant, - grant_details=self.grant_details, + OAuth2BaseAuth.__init__( + self, + state, + early_expiry, + header_name, + header_value, ) - request.headers[self.header_name] = self.header_value.format(token=token) - yield request + + def request_new_token(self) -> tuple[str, str]: + return authentication_responses_server.request_new_grant(self.grant_details) class AzureActiveDirectoryImplicit(OAuth2Implicit): diff --git a/httpx_auth/_oauth2/resource_owner_password.py b/httpx_auth/_oauth2/resource_owner_password.py index 7c38419..2be25bf 100644 --- a/httpx_auth/_oauth2/resource_owner_password.py +++ b/httpx_auth/_oauth2/resource_owner_password.py @@ -1,16 +1,15 @@ from hashlib import sha512 -from typing import Generator import httpx from httpx_auth._authentication import SupportMultiAuth from httpx_auth._oauth2.common import ( - OAuth2, + OAuth2BaseAuth, request_new_grant_with_post, _add_parameters, ) -class OAuth2ResourceOwnerPasswordCredentials(httpx.Auth, SupportMultiAuth): +class OAuth2ResourceOwnerPasswordCredentials(OAuth2BaseAuth, SupportMultiAuth): """ Resource Owner Password Credentials Grant @@ -42,6 +41,7 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs): Use it to provide a custom proxying rule for instance. :param kwargs: all additional authorization parameters that should be put as body parameters in the token URL. """ + self.token_url = token_url if not self.token_url: raise Exception("Token URL is mandatory.") @@ -52,13 +52,11 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs): if not self.password: raise Exception("Password is mandatory.") - self.header_name = kwargs.pop("header_name", None) or "Authorization" - self.header_value = kwargs.pop("header_value", None) or "Bearer {token}" - if "{token}" not in self.header_value: - raise Exception("header_value parameter must contains {token}.") + header_name = kwargs.pop("header_name", None) or "Authorization" + header_value = kwargs.pop("header_value", None) or "Bearer {token}" self.token_field_name = kwargs.pop("token_field_name", None) or "access_token" - self.early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) + early_expiry = float(kwargs.pop("early_expiry", None) or 30.0) # Time is expressed in seconds self.timeout = int(kwargs.pop("timeout", None) or 60) @@ -83,19 +81,16 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs): self.refresh_data.update(kwargs) all_parameters_in_url = _add_parameters(self.token_url, self.data) - self.state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest() - - def auth_flow( - self, request: httpx.Request - ) -> Generator[httpx.Request, httpx.Response, None]: - token = OAuth2.token_cache.get_token( - self.state, - early_expiry=self.early_expiry, - on_missing_token=self.request_new_token, - on_expired_token=self.refresh_token, + state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest() + + OAuth2BaseAuth.__init__( + self, + state, + early_expiry, + header_name, + header_value, + self.refresh_token, ) - request.headers[self.header_name] = self.header_value.format(token=token) - yield request def request_new_token(self) -> tuple: client = self.client or httpx.Client() diff --git a/httpx_auth/_oauth2/tokens.py b/httpx_auth/_oauth2/tokens.py index 41c7d44..0e8085a 100644 --- a/httpx_auth/_oauth2/tokens.py +++ b/httpx_auth/_oauth2/tokens.py @@ -114,7 +114,6 @@ def get_token( early_expiry: float = 30.0, on_missing_token=None, on_expired_token=None, - **on_missing_token_kwargs, ) -> str: """ Return the bearer token. @@ -126,7 +125,6 @@ def get_token( expired 30 seconds before real expiry by default. :param on_missing_token: function to call when token is expired or missing (returning token and expiry tuple) :param on_expired_token: function to call to refresh the token when it is expired - :param on_missing_token_kwargs: arguments of the on_missing_token function (key-value arguments) :return: the token :raise AuthenticationFailed: in case token cannot be retrieved. """ @@ -171,7 +169,7 @@ def get_token( logger.debug("Token cannot be found in cache.") if on_missing_token is not None: with self._forbid_concurrent_missing_token_function_call: - new_token = on_missing_token(**on_missing_token_kwargs) + new_token = on_missing_token() if len(new_token) == 2: # Bearer token state, token = new_token self._add_bearer_token(state, token) diff --git a/httpx_auth/version.py b/httpx_auth/version.py index 3ee084a..c27ed71 100644 --- a/httpx_auth/version.py +++ b/httpx_auth/version.py @@ -3,4 +3,4 @@ # Major should be incremented in case there is a breaking change. (eg: 2.5.8 -> 3.0.0) # Minor should be incremented in case there is an enhancement. (eg: 2.5.8 -> 2.6.0) # Patch should be incremented in case there is a bug fix. (eg: 2.5.8 -> 2.5.9) -__version__ = "0.21.0" +__version__ = "0.22.0" diff --git a/pyproject.toml b/pyproject.toml index d38316a..8b35b7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ classifiers=[ "Topic :: Software Development :: Build Tools", ] dependencies = [ - "httpx==0.26.*", + "httpx==0.27.*", ] dynamic = ["version"] @@ -45,7 +45,7 @@ testing = [ # Used to generate test tokens "pyjwt==2.*", # Used to mock httpx - "pytest_httpx==0.29.*", + "pytest_httpx==0.30.*", # Used to mock date and time "time-machine==2.*", # Used to check coverage