From dfad66128c6ee7513e5565d39bc7b002055dd0d5 Mon Sep 17 00:00:00 2001 From: bojeil-google Date: Tue, 20 Jul 2021 10:43:13 -0700 Subject: [PATCH] fix: fallback to source creds expiration in downscoped tokens (#805) For downscoping CAB flow, the STS endpoint may not return the expiration field for certain source credentials. The generated downscoped token should always have the same expiration time as the source credentials. When no `expires_in` field is returned in the response, we can just get the expiration time from the source credentials. Co-authored-by: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com> --- google/auth/downscoped.py | 12 ++++++++-- tests/test_downscoped.py | 46 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/google/auth/downscoped.py b/google/auth/downscoped.py index 800f2894c..96a4e6547 100644 --- a/google/auth/downscoped.py +++ b/google/auth/downscoped.py @@ -479,8 +479,16 @@ def refresh(self, request): additional_options=self._credential_access_boundary.to_json(), ) self.token = response_data.get("access_token") - lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) - self.expiry = now + lifetime + # For downscoping CAB flow, the STS endpoint may not return the expiration + # field for some flows. The generated downscoped token should always have + # the same expiration time as the source credentials. When no expires_in + # field is returned in the response, we can just get the expiration time + # from the source credentials. + if response_data.get("expires_in"): + lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) + self.expiry = now + lifetime + else: + self.expiry = self._source_credentials.expiry @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): diff --git a/tests/test_downscoped.py b/tests/test_downscoped.py index ac60e5b00..795ec2942 100644 --- a/tests/test_downscoped.py +++ b/tests/test_downscoped.py @@ -80,10 +80,11 @@ class SourceCredentials(credentials.Credentials): - def __init__(self, raise_error=False): + def __init__(self, raise_error=False, expires_in=3600): super(SourceCredentials, self).__init__() self._counter = 0 self._raise_error = raise_error + self._expires_in = expires_in def refresh(self, request): if self._raise_error: @@ -93,7 +94,7 @@ def refresh(self, request): now = _helpers.utcnow() self._counter += 1 self.token = "ACCESS_TOKEN_{}".format(self._counter) - self.expiry = now + datetime.timedelta(seconds=3600) + self.expiry = now + datetime.timedelta(seconds=self._expires_in) def make_availability_condition(expression, title=None, description=None): @@ -539,6 +540,47 @@ def test_refresh(self, unused_utcnow): # Confirm source credentials called with the same request instance. wrapped_souce_cred_refresh.assert_called_with(request) + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_without_response_expires_in(self, unused_utcnow): + response = SUCCESS_RESPONSE.copy() + # Simulate the response is missing the expires_in field. + # The downscoped token expiration should match the source credentials + # expiration. + del response["expires_in"] + expected_expires_in = 1800 + # Simulate the source credentials generates a token with 1800 second + # expiration time. The generated downscoped token should have the same + # expiration time. + source_credentials = SourceCredentials(expires_in=expected_expires_in) + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=expected_expires_in + ) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": GRANT_TYPE, + "subject_token": "ACCESS_TOKEN_1", + "subject_token_type": SUBJECT_TOKEN_TYPE, + "requested_token_type": REQUESTED_TOKEN_TYPE, + "options": urllib.parse.quote(json.dumps(CREDENTIAL_ACCESS_BOUNDARY_JSON)), + } + request = self.make_mock_request(status=http_client.OK, data=response) + credentials = self.make_credentials(source_credentials=source_credentials) + + # Spy on calls to source credentials refresh to confirm the expected request + # instance is used. + with mock.patch.object( + source_credentials, "refresh", wraps=source_credentials.refresh + ) as wrapped_souce_cred_refresh: + credentials.refresh(request) + + self.assert_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + # Confirm source credentials called with the same request instance. + wrapped_souce_cred_refresh.assert_called_with(request) + def test_refresh_token_exchange_error(self): request = self.make_mock_request( status=http_client.BAD_REQUEST, data=ERROR_RESPONSE