From 9e1082366d113286bc063051fd76b4799791d943 Mon Sep 17 00:00:00 2001 From: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com> Date: Fri, 23 Apr 2021 15:27:02 -0700 Subject: [PATCH] feat: add reauth support to async user credentials (#738) --- google/oauth2/_client_async.py | 125 ++++--- google/oauth2/_credentials_async.py | 13 +- google/oauth2/_reauth_async.py | 320 ++++++++++++++++++ google/oauth2/reauth.py | 6 +- tests_async/oauth2/test__client_async.py | 67 ++-- tests_async/oauth2/test_credentials_async.py | 36 +- tests_async/oauth2/test_reauth_async.py | 328 +++++++++++++++++++ 7 files changed, 785 insertions(+), 110 deletions(-) create mode 100644 google/oauth2/_reauth_async.py create mode 100644 tests_async/oauth2/test_reauth_async.py diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py index 4817ea40e..cf5121137 100644 --- a/google/oauth2/_client_async.py +++ b/google/oauth2/_client_async.py @@ -30,53 +30,16 @@ from six.moves import http_client from six.moves import urllib -from google.auth import _helpers from google.auth import exceptions from google.auth import jwt from google.oauth2 import _client as client -def _handle_error_response(response_body): - """"Translates an error response into an exception. - - Args: - response_body (str): The decoded response data. - - Raises: - google.auth.exceptions.RefreshError - """ - try: - error_data = json.loads(response_body) - error_details = "{}: {}".format( - error_data["error"], error_data.get("error_description") - ) - # If no details could be extracted, use the response data. - except (KeyError, ValueError): - error_details = response_body - - raise exceptions.RefreshError(error_details, response_body) - - -def _parse_expiry(response_data): - """Parses the expiry field from a response into a datetime. - - Args: - response_data (Mapping): The JSON-parsed response data. - - Returns: - Optional[datetime]: The expiration or ``None`` if no expiration was - specified. - """ - expires_in = response_data.get("expires_in", None) - - if expires_in is not None: - return _helpers.utcnow() + datetime.timedelta(seconds=expires_in) - else: - return None - - -async def _token_endpoint_request(request, token_uri, body): +async def _token_endpoint_request_no_throw( + request, token_uri, body, access_token=None, use_json=False +): """Makes a request to the OAuth 2.0 authorization server's token endpoint. + This function doesn't throw on response errors. Args: request (google.auth.transport.Request): A callable used to make @@ -84,16 +47,23 @@ async def _token_endpoint_request(request, token_uri, body): token_uri (str): The OAuth 2.0 authorizations server's token endpoint URI. body (Mapping[str, str]): The parameters to send in the request body. + access_token (Optional(str)): The access token needed to make the request. + use_json (Optional(bool)): Use urlencoded format or json format for the + content type. The default value is False. Returns: - Mapping[str, str]: The JSON-decoded response data. - - Raises: - google.auth.exceptions.RefreshError: If the token endpoint returned - an error. + Tuple(bool, Mapping[str, str]): A boolean indicating if the request is + successful, and a mapping for the JSON-decoded response data. """ - body = urllib.parse.urlencode(body).encode("utf-8") - headers = {"content-type": client._URLENCODED_CONTENT_TYPE} + if use_json: + headers = {"Content-Type": client._JSON_CONTENT_TYPE} + body = json.dumps(body).encode("utf-8") + else: + headers = {"Content-Type": client._URLENCODED_CONTENT_TYPE} + body = urllib.parse.urlencode(body).encode("utf-8") + + if access_token: + headers["Authorization"] = "Bearer {}".format(access_token) retry = 0 # retry to fetch token for maximum of two times if any internal failure @@ -126,8 +96,38 @@ async def _token_endpoint_request(request, token_uri, body): ): retry += 1 continue - _handle_error_response(response_body) + return response.status == http_client.OK, response_data + return response.status == http_client.OK, response_data + + +async def _token_endpoint_request( + request, token_uri, body, access_token=None, use_json=False +): + """Makes a request to the OAuth 2.0 authorization server's token endpoint. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + body (Mapping[str, str]): The parameters to send in the request body. + access_token (Optional(str)): The access token needed to make the request. + use_json (Optional(bool)): Use urlencoded format or json format for the + content type. The default value is False. + + Returns: + Mapping[str, str]: The JSON-decoded response data. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + response_status_ok, response_data = await _token_endpoint_request_no_throw( + request, token_uri, body, access_token=access_token, use_json=use_json + ) + if not response_status_ok: + client._handle_error_response(response_data) return response_data @@ -163,7 +163,7 @@ async def jwt_grant(request, token_uri, assertion): new_exc = exceptions.RefreshError("No access token in response.", response_data) six.raise_from(new_exc, caught_exc) - expiry = _parse_expiry(response_data) + expiry = client._parse_expiry(response_data) return access_token, expiry, response_data @@ -210,7 +210,13 @@ async def id_token_jwt_grant(request, token_uri, assertion): async def refresh_grant( - request, token_uri, refresh_token, client_id, client_secret, scopes=None + request, + token_uri, + refresh_token, + client_id, + client_secret, + scopes=None, + rapt_token=None, ): """Implements the OAuth 2.0 refresh token grant. @@ -229,10 +235,11 @@ async def refresh_grant( scopes must be authorized for the refresh token. Useful if refresh token has a wild card scope (e.g. 'https://www.googleapis.com/auth/any-api'). + rapt_token (Optional(str)): The reauth Proof Token. Returns: Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The - access token, new refresh token, expiration, and additional data + access token, new or current refresh token, expiration, and additional data returned by the token endpoint. Raises: @@ -249,16 +256,8 @@ async def refresh_grant( } if scopes: body["scope"] = " ".join(scopes) + if rapt_token: + body["rapt"] = rapt_token response_data = await _token_endpoint_request(request, token_uri, body) - - try: - access_token = response_data["access_token"] - except KeyError as caught_exc: - new_exc = exceptions.RefreshError("No access token in response.", response_data) - six.raise_from(new_exc, caught_exc) - - refresh_token = response_data.get("refresh_token", refresh_token) - expiry = _parse_expiry(response_data) - - return access_token, refresh_token, expiry, response_data + return client._handle_refresh_grant_response(response_data, refresh_token) diff --git a/google/oauth2/_credentials_async.py b/google/oauth2/_credentials_async.py index eb3e97c08..b4878c543 100644 --- a/google/oauth2/_credentials_async.py +++ b/google/oauth2/_credentials_async.py @@ -34,7 +34,7 @@ from google.auth import _credentials_async as credentials from google.auth import _helpers from google.auth import exceptions -from google.oauth2 import _client_async as _client +from google.oauth2 import _reauth_async as reauth from google.oauth2 import credentials as oauth2_credentials @@ -66,23 +66,26 @@ async def refresh(self, request): refresh_token, expiry, grant_response, - ) = await _client.refresh_grant( + rapt_token, + ) = await reauth.refresh_grant( request, self._token_uri, self._refresh_token, self._client_id, self._client_secret, - self._scopes, + scopes=self._scopes, + rapt_token=self._rapt_token, ) self.token = access_token self.expiry = expiry self._refresh_token = refresh_token self._id_token = grant_response.get("id_token") + self._rapt_token = rapt_token - if self._scopes and "scopes" in grant_response: + if self._scopes and "scope" in grant_response: requested_scopes = frozenset(self._scopes) - granted_scopes = frozenset(grant_response["scopes"].split()) + granted_scopes = frozenset(grant_response["scope"].split()) scopes_requested_but_not_granted = requested_scopes - granted_scopes if scopes_requested_but_not_granted: raise exceptions.RefreshError( diff --git a/google/oauth2/_reauth_async.py b/google/oauth2/_reauth_async.py new file mode 100644 index 000000000..09e076090 --- /dev/null +++ b/google/oauth2/_reauth_async.py @@ -0,0 +1,320 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module that provides functions for handling rapt authentication. + +Reauth is a process of obtaining additional authentication (such as password, +security token, etc.) while refreshing OAuth 2.0 credentials for a user. + +Credentials that use the Reauth flow must have the reauth scope, +``https://www.googleapis.com/auth/accounts.reauth``. + +This module provides a high-level function for executing the Reauth process, +:func:`refresh_grant`, and lower-level helpers for doing the individual +steps of the reauth process. + +Those steps are: + +1. Obtaining a list of challenges from the reauth server. +2. Running through each challenge and sending the result back to the reauth + server. +3. Refreshing the access token using the returned rapt token. +""" + +import sys + +from six.moves import range + +from google.auth import exceptions +from google.oauth2 import _client +from google.oauth2 import _client_async +from google.oauth2 import challenges +from google.oauth2 import reauth + + +async def _get_challenges( + request, supported_challenge_types, access_token, requested_scopes=None +): + """Does initial request to reauth API to get the challenges. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + supported_challenge_types (Sequence[str]): list of challenge names + supported by the manager. + access_token (str): Access token with reauth scopes. + requested_scopes (Optional(Sequence[str])): Authorized scopes for the credentials. + + Returns: + dict: The response from the reauth API. + """ + body = {"supportedChallengeTypes": supported_challenge_types} + if requested_scopes: + body["oauthScopesForDomainPolicyLookup"] = requested_scopes + + return await _client_async._token_endpoint_request( + request, + reauth._REAUTH_API + ":start", + body, + access_token=access_token, + use_json=True, + ) + + +async def _send_challenge_result( + request, session_id, challenge_id, client_input, access_token +): + """Attempt to refresh access token by sending next challenge result. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + session_id (str): session id returned by the initial reauth call. + challenge_id (str): challenge id returned by the initial reauth call. + client_input: dict with a challenge-specific client input. For example: + ``{'credential': password}`` for password challenge. + access_token (str): Access token with reauth scopes. + + Returns: + dict: The response from the reauth API. + """ + body = { + "sessionId": session_id, + "challengeId": challenge_id, + "action": "RESPOND", + "proposalResponse": client_input, + } + + return await _client_async._token_endpoint_request( + request, + reauth._REAUTH_API + "/{}:continue".format(session_id), + body, + access_token=access_token, + use_json=True, + ) + + +async def _run_next_challenge(msg, request, access_token): + """Get the next challenge from msg and run it. + + Args: + msg (dict): Reauth API response body (either from the initial request to + https://reauth.googleapis.com/v2/sessions:start or from sending the + previous challenge response to + https://reauth.googleapis.com/v2/sessions/id:continue) + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + access_token (str): reauth access token + + Returns: + dict: The response from the reauth API. + + Raises: + google.auth.exceptions.ReauthError: if reauth failed. + """ + for challenge in msg["challenges"]: + if challenge["status"] != "READY": + # Skip non-activated challenges. + continue + c = challenges.AVAILABLE_CHALLENGES.get(challenge["challengeType"], None) + if not c: + raise exceptions.ReauthFailError( + "Unsupported challenge type {0}. Supported types: {1}".format( + challenge["challengeType"], + ",".join(list(challenges.AVAILABLE_CHALLENGES.keys())), + ) + ) + if not c.is_locally_eligible: + raise exceptions.ReauthFailError( + "Challenge {0} is not locally eligible".format( + challenge["challengeType"] + ) + ) + client_input = c.obtain_challenge_input(challenge) + if not client_input: + return None + return await _send_challenge_result( + request, + msg["sessionId"], + challenge["challengeId"], + client_input, + access_token, + ) + return None + + +async def _obtain_rapt(request, access_token, requested_scopes): + """Given an http request method and reauth access token, get rapt token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + access_token (str): reauth access token + requested_scopes (Sequence[str]): scopes required by the client application + + Returns: + str: The rapt token. + + Raises: + google.auth.exceptions.ReauthError: if reauth failed + """ + msg = await _get_challenges( + request, + list(challenges.AVAILABLE_CHALLENGES.keys()), + access_token, + requested_scopes, + ) + + if msg["status"] == reauth._AUTHENTICATED: + return msg["encodedProofOfReauthToken"] + + for _ in range(0, reauth.RUN_CHALLENGE_RETRY_LIMIT): + if not ( + msg["status"] == reauth._CHALLENGE_REQUIRED + or msg["status"] == reauth._CHALLENGE_PENDING + ): + raise exceptions.ReauthFailError( + "Reauthentication challenge failed due to API error: {}".format( + msg["status"] + ) + ) + + if not reauth.is_interactive(): + raise exceptions.ReauthFailError( + "Reauthentication challenge could not be answered because you are not" + " in an interactive session." + ) + + msg = await _run_next_challenge(msg, request, access_token) + + if msg["status"] == reauth._AUTHENTICATED: + return msg["encodedProofOfReauthToken"] + + # If we got here it means we didn't get authenticated. + raise exceptions.ReauthFailError("Failed to obtain rapt token.") + + +async def get_rapt_token( + request, client_id, client_secret, refresh_token, token_uri, scopes=None +): + """Given an http request method and refresh_token, get rapt token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + client_id (str): client id to get access token for reauth scope. + client_secret (str): client secret for the client_id + refresh_token (str): refresh token to refresh access token + token_uri (str): uri to refresh access token + scopes (Optional(Sequence[str])): scopes required by the client application + + Returns: + str: The rapt token. + Raises: + google.auth.exceptions.RefreshError: If reauth failed. + """ + sys.stderr.write("Reauthentication required.\n") + + # Get access token for reauth. + access_token, _, _, _ = await _client_async.refresh_grant( + request=request, + client_id=client_id, + client_secret=client_secret, + refresh_token=refresh_token, + token_uri=token_uri, + scopes=[reauth._REAUTH_SCOPE], + ) + + # Get rapt token from reauth API. + rapt_token = await _obtain_rapt(request, access_token, requested_scopes=scopes) + + return rapt_token + + +async def refresh_grant( + request, + token_uri, + refresh_token, + client_id, + client_secret, + scopes=None, + rapt_token=None, +): + """Implements the reauthentication flow. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + refresh_token (str): The refresh token to use to get a new access + token. + client_id (str): The OAuth 2.0 application's client ID. + client_secret (str): The Oauth 2.0 appliaction's client secret. + scopes (Optional(Sequence[str])): Scopes to request. If present, all + scopes must be authorized for the refresh token. Useful if refresh + token has a wild card scope (e.g. + 'https://www.googleapis.com/auth/any-api'). + rapt_token (Optional(str)): The rapt token for reauth. + + Returns: + Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The + access token, new refresh token, expiration, the additional data + returned by the token endpoint, and the rapt token. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = { + "grant_type": _client._REFRESH_GRANT_TYPE, + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + } + if scopes: + body["scope"] = " ".join(scopes) + if rapt_token: + body["rapt"] = rapt_token + + response_status_ok, response_data = await _client_async._token_endpoint_request_no_throw( + request, token_uri, body + ) + if ( + not response_status_ok + and response_data.get("error") == reauth._REAUTH_NEEDED_ERROR + and ( + response_data.get("error_subtype") + == reauth._REAUTH_NEEDED_ERROR_INVALID_RAPT + or response_data.get("error_subtype") + == reauth._REAUTH_NEEDED_ERROR_RAPT_REQUIRED + ) + ): + rapt_token = await get_rapt_token( + request, client_id, client_secret, refresh_token, token_uri, scopes=scopes + ) + body["rapt"] = rapt_token + ( + response_status_ok, + response_data, + ) = await _client_async._token_endpoint_request_no_throw( + request, token_uri, body + ) + + if not response_status_ok: + _client._handle_error_response(response_data) + refresh_response = _client._handle_refresh_grant_response( + response_data, refresh_token + ) + return refresh_response + (rapt_token,) diff --git a/google/oauth2/reauth.py b/google/oauth2/reauth.py index d539d7c9e..d914fe9a7 100644 --- a/google/oauth2/reauth.py +++ b/google/oauth2/reauth.py @@ -296,9 +296,9 @@ def refresh_grant( rapt_token (Optional(str)): The rapt token for reauth. Returns: - Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The - access token, new refresh token, expiration, and additional data - returned by the token endpoint. + Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The + access token, new refresh token, expiration, the additional data + returned by the token endpoint, and the rapt token. Raises: google.auth.exceptions.RefreshError: If the token endpoint returned diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py index 458937ac1..6e48c4590 100644 --- a/tests_async/oauth2/test__client_async.py +++ b/tests_async/oauth2/test__client_async.py @@ -29,34 +29,6 @@ from tests.oauth2 import test__client as test_client -def test__handle_error_response(): - response_data = json.dumps({"error": "help", "error_description": "I'm alive"}) - - with pytest.raises(exceptions.RefreshError) as excinfo: - _client._handle_error_response(response_data) - - assert excinfo.match(r"help: I\'m alive") - - -def test__handle_error_response_non_json(): - response_data = "Help, I'm alive" - - with pytest.raises(exceptions.RefreshError) as excinfo: - _client._handle_error_response(response_data) - - assert excinfo.match(r"Help, I\'m alive") - - -@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) -def test__parse_expiry(unused_utcnow): - result = _client._parse_expiry({"expires_in": 500}) - assert result == datetime.datetime.min + datetime.timedelta(seconds=500) - - -def test__parse_expiry_none(): - assert _client._parse_expiry({}) is None - - def make_request(response_data, status=http_client.OK): response = mock.AsyncMock(spec=["transport.Response"]) response.status = status @@ -82,7 +54,7 @@ async def test__token_endpoint_request(): request.assert_called_with( method="POST", url="http://example.com", - headers={"content-type": "application/x-www-form-urlencoded"}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, body="test=params".encode("utf-8"), ) @@ -90,6 +62,35 @@ async def test__token_endpoint_request(): assert result == {"test": "response"} +@pytest.mark.asyncio +async def test__token_endpoint_request_json(): + + request = make_request({"test": "response"}) + access_token = "access_token" + + result = await _client._token_endpoint_request( + request, + "http://example.com", + {"test": "params"}, + access_token=access_token, + use_json=True, + ) + + # Check request call + request.assert_called_with( + method="POST", + url="http://example.com", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer access_token", + }, + body=b'{"test": "params"}', + ) + + # Check result + assert result == {"test": "response"} + + @pytest.mark.asyncio async def test__token_endpoint_request_error(): request = make_request({}, status=http_client.BAD_REQUEST) @@ -218,7 +219,12 @@ async def test_refresh_grant(unused_utcnow): ) token, refresh_token, expiry, extra_data = await _client.refresh_grant( - request, "http://example.com", "refresh_token", "client_id", "client_secret" + request, + "http://example.com", + "refresh_token", + "client_id", + "client_secret", + rapt_token="rapt_token", ) # Check request call @@ -229,6 +235,7 @@ async def test_refresh_grant(unused_utcnow): "refresh_token": "refresh_token", "client_id": "client_id", "client_secret": "client_secret", + "rapt": "rapt_token", }, ) diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py index 5c883d614..99cf16f80 100644 --- a/tests_async/oauth2/test_credentials_async.py +++ b/tests_async/oauth2/test_credentials_async.py @@ -58,7 +58,7 @@ def test_default_state(self): assert credentials.client_id == self.CLIENT_ID assert credentials.client_secret == self.CLIENT_SECRET - @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, @@ -68,6 +68,7 @@ async def test_refresh_success(self, unused_utcnow, refresh_grant): token = "token" expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) grant_response = {"id_token": mock.sentinel.id_token} + rapt_token = "rapt_token" refresh_grant.return_value = ( # Access token token, @@ -77,6 +78,8 @@ async def test_refresh_success(self, unused_utcnow, refresh_grant): expiry, # Extra data grant_response, + # Rapt token + rapt_token, ) request = mock.AsyncMock(spec=["transport.Request"]) @@ -93,12 +96,14 @@ async def test_refresh_success(self, unused_utcnow, refresh_grant): self.CLIENT_ID, self.CLIENT_SECRET, None, + None, ) # Check that the credentials have the token and expiry assert creds.token == token assert creds.expiry == expiry assert creds.id_token == mock.sentinel.id_token + assert creds.rapt_token == rapt_token # Check that the credentials are valid (have a token and are not # expired) @@ -114,7 +119,7 @@ async def test_refresh_no_refresh_token(self): request.assert_not_called() - @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, @@ -127,6 +132,7 @@ async def test_credentials_with_scopes_requested_refresh_success( token = "token" expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) grant_response = {"id_token": mock.sentinel.id_token} + rapt_token = "rapt_token" refresh_grant.return_value = ( # Access token token, @@ -136,6 +142,8 @@ async def test_credentials_with_scopes_requested_refresh_success( expiry, # Extra data grant_response, + # Rapt token + rapt_token, ) request = mock.AsyncMock(spec=["transport.Request"]) @@ -146,6 +154,7 @@ async def test_credentials_with_scopes_requested_refresh_success( client_id=self.CLIENT_ID, client_secret=self.CLIENT_SECRET, scopes=scopes, + rapt_token="old_rapt_token", ) # Refresh credentials @@ -159,6 +168,7 @@ async def test_credentials_with_scopes_requested_refresh_success( self.CLIENT_ID, self.CLIENT_SECRET, scopes, + "old_rapt_token", ) # Check that the credentials have the token and expiry @@ -166,12 +176,13 @@ async def test_credentials_with_scopes_requested_refresh_success( assert creds.expiry == expiry assert creds.id_token == mock.sentinel.id_token assert creds.has_scopes(scopes) + assert creds.rapt_token == rapt_token # Check that the credentials are valid (have a token and are not # expired.) assert creds.valid - @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, @@ -183,10 +194,8 @@ async def test_credentials_with_scopes_returned_refresh_success( scopes = ["email", "profile"] token = "token" expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) - grant_response = { - "id_token": mock.sentinel.id_token, - "scopes": " ".join(scopes), - } + grant_response = {"id_token": mock.sentinel.id_token, "scope": " ".join(scopes)} + rapt_token = "rapt_token" refresh_grant.return_value = ( # Access token token, @@ -196,6 +205,8 @@ async def test_credentials_with_scopes_returned_refresh_success( expiry, # Extra data grant_response, + # Rapt token + rapt_token, ) request = mock.AsyncMock(spec=["transport.Request"]) @@ -219,6 +230,7 @@ async def test_credentials_with_scopes_returned_refresh_success( self.CLIENT_ID, self.CLIENT_SECRET, scopes, + None, ) # Check that the credentials have the token and expiry @@ -226,12 +238,13 @@ async def test_credentials_with_scopes_returned_refresh_success( assert creds.expiry == expiry assert creds.id_token == mock.sentinel.id_token assert creds.has_scopes(scopes) + assert creds.rapt_token == rapt_token # Check that the credentials are valid (have a token and are not # expired.) assert creds.valid - @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, @@ -246,8 +259,9 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) grant_response = { "id_token": mock.sentinel.id_token, - "scopes": " ".join(scopes_returned), + "scope": " ".join(scopes_returned), } + rapt_token = "rapt_token" refresh_grant.return_value = ( # Access token token, @@ -257,6 +271,8 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( expiry, # Extra data grant_response, + # Rapt token + rapt_token, ) request = mock.AsyncMock(spec=["transport.Request"]) @@ -267,6 +283,7 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( client_id=self.CLIENT_ID, client_secret=self.CLIENT_SECRET, scopes=scopes, + rapt_token=None, ) # Refresh credentials @@ -283,6 +300,7 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( self.CLIENT_ID, self.CLIENT_SECRET, scopes, + None, ) # Check that the credentials have the token and expiry diff --git a/tests_async/oauth2/test_reauth_async.py b/tests_async/oauth2/test_reauth_async.py new file mode 100644 index 000000000..f144d89f5 --- /dev/null +++ b/tests_async/oauth2/test_reauth_async.py @@ -0,0 +1,328 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import mock +import pytest + +from google.auth import exceptions +from google.oauth2 import _reauth_async +from google.oauth2 import reauth + + +MOCK_REQUEST = mock.AsyncMock(spec=["transport.Request"]) +CHALLENGES_RESPONSE_TEMPLATE = { + "status": "CHALLENGE_REQUIRED", + "sessionId": "123", + "challenges": [ + { + "status": "READY", + "challengeId": 1, + "challengeType": "PASSWORD", + "securityKey": {}, + } + ], +} +CHALLENGES_RESPONSE_AUTHENTICATED = { + "status": "AUTHENTICATED", + "sessionId": "123", + "encodedProofOfReauthToken": "new_rapt_token", +} + + +class MockChallenge(object): + def __init__(self, name, locally_eligible, challenge_input): + self.name = name + self.is_locally_eligible = locally_eligible + self.challenge_input = challenge_input + + def obtain_challenge_input(self, metadata): + return self.challenge_input + + +@pytest.mark.asyncio +async def test__get_challenges(): + with mock.patch( + "google.oauth2._client_async._token_endpoint_request" + ) as mock_token_endpoint_request: + await _reauth_async._get_challenges(MOCK_REQUEST, ["SAML"], "token") + mock_token_endpoint_request.assert_called_with( + MOCK_REQUEST, + reauth._REAUTH_API + ":start", + {"supportedChallengeTypes": ["SAML"]}, + access_token="token", + use_json=True, + ) + + +@pytest.mark.asyncio +async def test__get_challenges_with_scopes(): + with mock.patch( + "google.oauth2._client_async._token_endpoint_request" + ) as mock_token_endpoint_request: + await _reauth_async._get_challenges( + MOCK_REQUEST, ["SAML"], "token", requested_scopes=["scope"] + ) + mock_token_endpoint_request.assert_called_with( + MOCK_REQUEST, + reauth._REAUTH_API + ":start", + { + "supportedChallengeTypes": ["SAML"], + "oauthScopesForDomainPolicyLookup": ["scope"], + }, + access_token="token", + use_json=True, + ) + + +@pytest.mark.asyncio +async def test__send_challenge_result(): + with mock.patch( + "google.oauth2._client_async._token_endpoint_request" + ) as mock_token_endpoint_request: + await _reauth_async._send_challenge_result( + MOCK_REQUEST, "123", "1", {"credential": "password"}, "token" + ) + mock_token_endpoint_request.assert_called_with( + MOCK_REQUEST, + reauth._REAUTH_API + "/123:continue", + { + "sessionId": "123", + "challengeId": "1", + "action": "RESPOND", + "proposalResponse": {"credential": "password"}, + }, + access_token="token", + use_json=True, + ) + + +@pytest.mark.asyncio +async def test__run_next_challenge_not_ready(): + challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE) + challenges_response["challenges"][0]["status"] = "STATUS_UNSPECIFIED" + assert ( + await _reauth_async._run_next_challenge( + challenges_response, MOCK_REQUEST, "token" + ) + is None + ) + + +@pytest.mark.asyncio +async def test__run_next_challenge_not_supported(): + challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE) + challenges_response["challenges"][0]["challengeType"] = "CHALLENGE_TYPE_UNSPECIFIED" + with pytest.raises(exceptions.ReauthFailError) as excinfo: + await _reauth_async._run_next_challenge( + challenges_response, MOCK_REQUEST, "token" + ) + assert excinfo.match(r"Unsupported challenge type CHALLENGE_TYPE_UNSPECIFIED") + + +@pytest.mark.asyncio +async def test__run_next_challenge_not_locally_eligible(): + mock_challenge = MockChallenge("PASSWORD", False, "challenge_input") + with mock.patch( + "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge} + ): + with pytest.raises(exceptions.ReauthFailError) as excinfo: + await _reauth_async._run_next_challenge( + CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token" + ) + assert excinfo.match(r"Challenge PASSWORD is not locally eligible") + + +@pytest.mark.asyncio +async def test__run_next_challenge_no_challenge_input(): + mock_challenge = MockChallenge("PASSWORD", True, None) + with mock.patch( + "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge} + ): + assert ( + await _reauth_async._run_next_challenge( + CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token" + ) + is None + ) + + +@pytest.mark.asyncio +async def test__run_next_challenge_success(): + mock_challenge = MockChallenge("PASSWORD", True, {"credential": "password"}) + with mock.patch( + "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge} + ): + with mock.patch( + "google.oauth2._reauth_async._send_challenge_result" + ) as mock_send_challenge_result: + await _reauth_async._run_next_challenge( + CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token" + ) + mock_send_challenge_result.assert_called_with( + MOCK_REQUEST, "123", 1, {"credential": "password"}, "token" + ) + + +@pytest.mark.asyncio +async def test__obtain_rapt_authenticated(): + with mock.patch( + "google.oauth2._reauth_async._get_challenges", + return_value=CHALLENGES_RESPONSE_AUTHENTICATED, + ): + new_rapt_token = await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None) + assert new_rapt_token == "new_rapt_token" + + +@pytest.mark.asyncio +async def test__obtain_rapt_authenticated_after_run_next_challenge(): + with mock.patch( + "google.oauth2._reauth_async._get_challenges", + return_value=CHALLENGES_RESPONSE_TEMPLATE, + ): + with mock.patch( + "google.oauth2._reauth_async._run_next_challenge", + side_effect=[ + CHALLENGES_RESPONSE_TEMPLATE, + CHALLENGES_RESPONSE_AUTHENTICATED, + ], + ): + with mock.patch("google.oauth2.reauth.is_interactive", return_value=True): + assert ( + await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None) + == "new_rapt_token" + ) + + +@pytest.mark.asyncio +async def test__obtain_rapt_unsupported_status(): + challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE) + challenges_response["status"] = "STATUS_UNSPECIFIED" + with mock.patch( + "google.oauth2._reauth_async._get_challenges", return_value=challenges_response + ): + with pytest.raises(exceptions.ReauthFailError) as excinfo: + await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None) + assert excinfo.match(r"API error: STATUS_UNSPECIFIED") + + +@pytest.mark.asyncio +async def test__obtain_rapt_not_interactive(): + with mock.patch( + "google.oauth2._reauth_async._get_challenges", + return_value=CHALLENGES_RESPONSE_TEMPLATE, + ): + with mock.patch("google.oauth2.reauth.is_interactive", return_value=False): + with pytest.raises(exceptions.ReauthFailError) as excinfo: + await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None) + assert excinfo.match(r"not in an interactive session") + + +@pytest.mark.asyncio +async def test__obtain_rapt_not_authenticated(): + with mock.patch( + "google.oauth2._reauth_async._get_challenges", + return_value=CHALLENGES_RESPONSE_TEMPLATE, + ): + with mock.patch("google.oauth2.reauth.RUN_CHALLENGE_RETRY_LIMIT", 0): + with pytest.raises(exceptions.ReauthFailError) as excinfo: + await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None) + assert excinfo.match(r"Reauthentication failed") + + +@pytest.mark.asyncio +async def test_get_rapt_token(): + with mock.patch( + "google.oauth2._client_async.refresh_grant", + return_value=("token", None, None, None), + ) as mock_refresh_grant: + with mock.patch( + "google.oauth2._reauth_async._obtain_rapt", return_value="new_rapt_token" + ) as mock_obtain_rapt: + assert ( + await _reauth_async.get_rapt_token( + MOCK_REQUEST, + "client_id", + "client_secret", + "refresh_token", + "token_uri", + ) + == "new_rapt_token" + ) + mock_refresh_grant.assert_called_with( + request=MOCK_REQUEST, + client_id="client_id", + client_secret="client_secret", + refresh_token="refresh_token", + token_uri="token_uri", + scopes=[reauth._REAUTH_SCOPE], + ) + mock_obtain_rapt.assert_called_with( + MOCK_REQUEST, "token", requested_scopes=None + ) + + +@pytest.mark.asyncio +async def test_refresh_grant_failed(): + with mock.patch( + "google.oauth2._client_async._token_endpoint_request_no_throw" + ) as mock_token_request: + mock_token_request.return_value = (False, {"error": "Bad request"}) + with pytest.raises(exceptions.RefreshError) as excinfo: + await _reauth_async.refresh_grant( + MOCK_REQUEST, + "token_uri", + "refresh_token", + "client_id", + "client_secret", + scopes=["foo", "bar"], + rapt_token="rapt_token", + ) + assert excinfo.match(r"Bad request") + mock_token_request.assert_called_with( + MOCK_REQUEST, + "token_uri", + { + "grant_type": "refresh_token", + "client_id": "client_id", + "client_secret": "client_secret", + "refresh_token": "refresh_token", + "scope": "foo bar", + "rapt": "rapt_token", + }, + ) + + +@pytest.mark.asyncio +async def test_refresh_grant_success(): + with mock.patch( + "google.oauth2._client_async._token_endpoint_request_no_throw" + ) as mock_token_request: + mock_token_request.side_effect = [ + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), + (True, {"access_token": "access_token"}), + ] + with mock.patch( + "google.oauth2._reauth_async.get_rapt_token", return_value="new_rapt_token" + ): + assert await _reauth_async.refresh_grant( + MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + ) == ( + "access_token", + "refresh_token", + None, + {"access_token": "access_token"}, + "new_rapt_token", + )