From 8ff0de5f6c26c8778e24e57d6b7f449856357f81 Mon Sep 17 00:00:00 2001 From: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com> Date: Tue, 28 Mar 2023 11:59:02 -0700 Subject: [PATCH] feat: experimental service account iam endpoint flow for id token (#1258) * feat: experimental service account iam endpoint flow for id token * update * update * update test * address comment --- google/oauth2/_client.py | 42 ++++++++++++++++ google/oauth2/service_account.py | 74 +++++++++++++++++++++++++--- tests/oauth2/test__client.py | 44 +++++++++++++++++ tests/oauth2/test_service_account.py | 29 +++++++++++ 4 files changed, 183 insertions(+), 6 deletions(-) diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 428993646..74e769fa1 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -40,6 +40,10 @@ _JSON_CONTENT_TYPE = "application/json" _JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" _REFRESH_GRANT_TYPE = "refresh_token" +_IAM_IDTOKEN_ENDPOINT = ( + "https://iamcredentials.googleapis.com/v1/" + + "projects/-/serviceAccounts/{}:generateIdToken" +) def _handle_error_response(response_data, retryable_error): @@ -313,6 +317,44 @@ def jwt_grant(request, token_uri, assertion, can_retry=True): return access_token, expiry, response_data +def call_iam_generate_id_token_endpoint(request, signer_email, audience, access_token): + """Call iam.generateIdToken endpoint to get ID token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + signer_email (str): The signer email used to form the IAM + generateIdToken endpoint. + audience (str): The audience for the ID token. + access_token (str): The access token used to call the IAM endpoint. + + Returns: + Tuple[str, datetime]: The ID token and expiration. + """ + body = {"audience": audience, "includeEmail": "true"} + + response_data = _token_endpoint_request( + request, + _IAM_IDTOKEN_ENDPOINT.format(signer_email), + body, + access_token=access_token, + use_json=True, + ) + + try: + id_token = response_data["token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No ID token in response.", response_data, retryable=False + ) + six.raise_from(new_exc, caught_exc) + + payload = jwt.decode(id_token, verify=False) + expiry = datetime.datetime.utcfromtimestamp(payload["exp"]) + + return id_token, expiry + + def id_token_jwt_grant(request, token_uri, assertion, can_retry=True): """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but requests an OpenID Connect ID Token instead of an access token. diff --git a/google/oauth2/service_account.py b/google/oauth2/service_account.py index 0989750db..618ab538b 100644 --- a/google/oauth2/service_account.py +++ b/google/oauth2/service_account.py @@ -554,6 +554,7 @@ def __init__( self._token_uri = token_uri self._target_audience = target_audience self._quota_project_id = quota_project_id + self._use_iam_endpoint = False if additional_claims is not None: self._additional_claims = additional_claims @@ -639,6 +640,31 @@ def with_target_audience(self, target_audience): quota_project_id=self.quota_project_id, ) + def _with_use_iam_endpoint(self, use_iam_endpoint): + """Create a copy of these credentials with the use_iam_endpoint value. + + Args: + use_iam_endpoint (bool): If True, IAM generateIdToken endpoint will + be used instead of the token_uri. Note that + iam.serviceAccountTokenCreator role is required to use the IAM + endpoint. The default value is False. This feature is currently + experimental and subject to change without notice. + + Returns: + google.auth.service_account.IDTokenCredentials: A new credentials + instance. + """ + cred = self.__class__( + self._signer, + service_account_email=self._service_account_email, + token_uri=self._token_uri, + target_audience=self._target_audience, + additional_claims=self._additional_claims.copy(), + quota_project_id=self.quota_project_id, + ) + cred._use_iam_endpoint = use_iam_endpoint + return cred + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): return self.__class__( @@ -692,14 +718,50 @@ def _make_authorization_grant_assertion(self): return token + def _refresh_with_iam_endpoint(self, request): + """Use IAM generateIdToken endpoint to obtain an ID token. + + It works as follows: + + 1. First we create a self signed jwt with + https://www.googleapis.com/auth/iam being the scope. + + 2. Next we use the self signed jwt as the access token, and make a POST + request to IAM generateIdToken endpoint. The request body is: + { + "audience": self._target_audience, + "includeEmail": "true" + } + TODO: add "set_azp_to_email": "true" once it's ready from server side. + https://github.com/googleapis/google-auth-library-python/issues/1263 + + If the request is succesfully, it will return {"token":"the ID token"}, + and we can extract the ID token and compute its expiry. + """ + jwt_credentials = jwt.Credentials.from_signing_credentials( + self, + None, + additional_claims={"scope": "https://www.googleapis.com/auth/iam"}, + ) + jwt_credentials.refresh(request) + self.token, self.expiry = _client.call_iam_generate_id_token_endpoint( + request, + self.signer_email, + self._target_audience, + jwt_credentials.token.decode(), + ) + @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): - assertion = self._make_authorization_grant_assertion() - access_token, expiry, _ = _client.id_token_jwt_grant( - request, self._token_uri, assertion - ) - self.token = access_token - self.expiry = expiry + if self._use_iam_endpoint: + self._refresh_with_iam_endpoint(request) + else: + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = _client.id_token_jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry @property def service_account_email(self): diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index ff3096057..4997d2401 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -305,6 +305,50 @@ def test_jwt_grant_no_access_token(): assert not excinfo.value.retryable +def test_call_iam_generate_id_token_endpoint(): + now = _helpers.utcnow() + id_token_expiry = _helpers.datetime_to_secs(now) + id_token = jwt.encode(SIGNER, {"exp": id_token_expiry}).decode("utf-8") + request = make_request({"token": id_token}) + + token, expiry = _client.call_iam_generate_id_token_endpoint( + request, "fake_email", "fake_audience", "fake_access_token" + ) + + assert ( + request.call_args[1]["url"] + == "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/fake_email:generateIdToken" + ) + assert request.call_args[1]["headers"]["Content-Type"] == "application/json" + assert ( + request.call_args[1]["headers"]["Authorization"] == "Bearer fake_access_token" + ) + response_body = json.loads(request.call_args[1]["body"]) + assert response_body["audience"] == "fake_audience" + assert response_body["includeEmail"] == "true" + + # Check result + assert token == id_token + # JWT does not store microseconds + now = now.replace(microsecond=0) + assert expiry == now + + +def test_call_iam_generate_id_token_endpoint_no_id_token(): + request = make_request( + { + # No access token. + "error": "no token" + } + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client.call_iam_generate_id_token_endpoint( + request, "fake_email", "fake_audience", "fake_access_token" + ) + assert excinfo.match("No ID token in response") + + def test_id_token_jwt_grant(): now = _helpers.utcnow() id_token_expiry = _helpers.datetime_to_secs(now) diff --git a/tests/oauth2/test_service_account.py b/tests/oauth2/test_service_account.py index ed281fcfa..741027973 100644 --- a/tests/oauth2/test_service_account.py +++ b/tests/oauth2/test_service_account.py @@ -428,6 +428,7 @@ def test_from_service_account_info(self): assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"] assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"] assert credentials._target_audience == self.TARGET_AUDIENCE + assert not credentials._use_iam_endpoint def test_from_service_account_file(self): info = SERVICE_ACCOUNT_INFO.copy() @@ -440,6 +441,7 @@ def test_from_service_account_file(self): assert credentials._signer.key_id == info["private_key_id"] assert credentials._token_uri == info["token_uri"] assert credentials._target_audience == self.TARGET_AUDIENCE + assert not credentials._use_iam_endpoint def test_default_state(self): credentials = self.make_credentials() @@ -466,6 +468,11 @@ def test_with_target_audience(self): new_credentials = credentials.with_target_audience("https://new.example.com") assert new_credentials._target_audience == "https://new.example.com" + def test__with_use_iam_endpoint(self): + credentials = self.make_credentials() + new_credentials = credentials._with_use_iam_endpoint(True) + assert new_credentials._use_iam_endpoint + def test_with_quota_project(self): credentials = self.make_credentials() new_credentials = credentials.with_quota_project("project-foo") @@ -517,6 +524,28 @@ def test_refresh_success(self, id_token_jwt_grant): # expired) assert credentials.valid + @mock.patch( + "google.oauth2._client.call_iam_generate_id_token_endpoint", autospec=True + ) + def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint): + credentials = self.make_credentials() + credentials._use_iam_endpoint = True + token = "id_token" + call_iam_generate_id_token_endpoint.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + ) + request = mock.Mock() + credentials.refresh(request) + req, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[ + 0 + ] + assert req == request + assert signer_email == "service-account@example.com" + assert target_audience == "https://example.com" + decoded_access_token = jwt.decode(access_token, verify=False) + assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam" + @mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True) def test_before_request_refreshes(self, id_token_jwt_grant): credentials = self.make_credentials()