From 97e7700da031bfd80b63b1a3d2abc29c500936ef Mon Sep 17 00:00:00 2001 From: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com> Date: Mon, 23 Mar 2020 11:53:25 -0700 Subject: [PATCH] feat: fetch id token from GCE metadata server (#462) feat: fetch id token from GCE metadata server --- google/auth/compute_engine/credentials.py | 163 +++++++++++++++++----- system_tests/test_compute_engine.py | 12 ++ tests/compute_engine/test_credentials.py | 139 ++++++++++++++++++ 3 files changed, 279 insertions(+), 35 deletions(-) diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py index e35907abc..1927c26bd 100644 --- a/google/auth/compute_engine/credentials.py +++ b/google/auth/compute_engine/credentials.py @@ -125,18 +125,24 @@ class IDTokenCredentials(credentials.Credentials, credentials.Signing): These credentials relies on the default service account of a GCE instance. - In order for this to work, the GCE instance must have been started with + ID token can be requested from `GCE metadata server identity endpoint`_, IAM + token endpoint or other token endpoints you specify. If metadata server + identity endpoint is not used, the GCE instance must have been started with a service account that has access to the IAM Cloud API. + + .. _GCE metadata server identity endpoint: + https://cloud.google.com/compute/docs/instances/verifying-instance-identity """ def __init__( self, request, target_audience, - token_uri=_DEFAULT_TOKEN_URI, + token_uri=None, additional_claims=None, service_account_email=None, signer=None, + use_metadata_identity_endpoint=False, ): """ Args: @@ -154,29 +160,54 @@ def __init__( signer (google.auth.crypt.Signer): The signer used to sign JWTs. In case the signer is specified, the request argument will be ignored. + use_metadata_identity_endpoint (bool): Whether to use GCE metadata + identity endpoint. For backward compatibility the default value + is False. If set to True, ``token_uri``, ``additional_claims``, + ``service_account_email``, ``signer`` argument should not be set; + otherwise ValueError will be raised. + + Raises: + ValueError: + If ``use_metadata_identity_endpoint`` is set to True, and one of + ``token_uri``, ``additional_claims``, ``service_account_email``, + ``signer`` arguments is set. """ super(IDTokenCredentials, self).__init__() - if service_account_email is None: - sa_info = _metadata.get_service_account_info(request) - service_account_email = sa_info["email"] - self._service_account_email = service_account_email - - if signer is None: - signer = iam.Signer( - request=request, - credentials=Credentials(), - service_account_email=service_account_email, - ) - self._signer = signer - - self._token_uri = token_uri + self._use_metadata_identity_endpoint = use_metadata_identity_endpoint self._target_audience = target_audience - if additional_claims is not None: - self._additional_claims = additional_claims + if use_metadata_identity_endpoint: + if token_uri or additional_claims or service_account_email or signer: + raise ValueError( + "If use_metadata_identity_endpoint is set, token_uri, " + "additional_claims, service_account_email, signer arguments" + " must not be set" + ) + self._token_uri = None + self._additional_claims = None + self._signer = None + + if service_account_email is None: + sa_info = _metadata.get_service_account_info(request) + self._service_account_email = sa_info["email"] else: - self._additional_claims = {} + self._service_account_email = service_account_email + + if not use_metadata_identity_endpoint: + if signer is None: + signer = iam.Signer( + request=request, + credentials=Credentials(), + service_account_email=self._service_account_email, + ) + self._signer = signer + self._token_uri = token_uri or _DEFAULT_TOKEN_URI + + if additional_claims is not None: + self._additional_claims = additional_claims + else: + self._additional_claims = {} def with_target_audience(self, target_audience): """Create a copy of these credentials with the specified target @@ -190,14 +221,22 @@ def with_target_audience(self, target_audience): """ # since the signer is already instantiated, # the request is not needed - return self.__class__( - None, - service_account_email=self._service_account_email, - token_uri=self._token_uri, - target_audience=target_audience, - additional_claims=self._additional_claims.copy(), - signer=self.signer, - ) + if self._use_metadata_identity_endpoint: + return self.__class__( + None, + target_audience=target_audience, + use_metadata_identity_endpoint=True, + ) + else: + return self.__class__( + None, + service_account_email=self._service_account_email, + token_uri=self._token_uri, + target_audience=target_audience, + additional_claims=self._additional_claims.copy(), + signer=self.signer, + use_metadata_identity_endpoint=False, + ) def _make_authorization_grant_assertion(self): """Create the OAuth 2.0 assertion. @@ -228,22 +267,76 @@ def _make_authorization_grant_assertion(self): return token - @_helpers.copy_docstring(credentials.Credentials) + def _call_metadata_identity_endpoint(self, request): + """Request ID token from metadata identity endpoint. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the Compute Engine metadata + service can't be reached or if the instance has no credentials. + ValueError: If extracting expiry from the obtained ID token fails. + """ + try: + id_token = _metadata.get( + request, + "instance/service-accounts/default/identity?audience={}&format=full".format( + self._target_audience + ), + ) + except exceptions.TransportError as caught_exc: + new_exc = exceptions.RefreshError(caught_exc) + six.raise_from(new_exc, caught_exc) + + _, payload, _, _ = jwt._unverified_decode(id_token) + return id_token, payload["exp"] + def refresh(self, request): - assertion = self._make_authorization_grant_assertion() - access_token, expiry, _ = _client.id_token_jwt_grant( - request, self._token_uri, assertion - ) - self.token = access_token - self.expiry = expiry + """Refreshes the ID token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the credentials could + not be refreshed. + ValueError: If extracting expiry from the obtained ID token fails. + """ + if self._use_metadata_identity_endpoint: + self.token, self.expiry = self._call_metadata_identity_endpoint(request) + else: + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = _client.id_token_jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry @property @_helpers.copy_docstring(credentials.Signing) def signer(self): return self._signer - @_helpers.copy_docstring(credentials.Signing) def sign_bytes(self, message): + """Signs the given message. + + Args: + message (bytes): The message to sign. + + Returns: + bytes: The message's cryptographic signature. + + Raises: + ValueError: + Signer is not available if metadata identity endpoint is used. + """ + if self._use_metadata_identity_endpoint: + raise ValueError( + "Signer is not available if metadata identity endpoint is used" + ) return self._signer.sign(message) @property diff --git a/system_tests/test_compute_engine.py b/system_tests/test_compute_engine.py index 3217c958a..bcfdfd604 100644 --- a/system_tests/test_compute_engine.py +++ b/system_tests/test_compute_engine.py @@ -18,6 +18,7 @@ from google.auth import compute_engine from google.auth import _helpers from google.auth import exceptions +from google.auth import jwt from google.auth.compute_engine import _metadata @@ -48,3 +49,14 @@ def test_default(verify_refresh): assert project_id is not None assert isinstance(credentials, compute_engine.Credentials) verify_refresh(credentials) + + +def test_id_token_from_metadata(http_request): + credentials = compute_engine.IDTokenCredentials( + http_request, "target_audience", use_metadata_identity_endpoint=True + ) + credentials.refresh(http_request) + + _, payload, _, _ = jwt._unverified_decode(credentials.token) + assert payload["aud"] == "target_audience" + assert payload["exp"] == credentials.expiry diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py index b861984e0..264235e49 100644 --- a/tests/compute_engine/test_credentials.py +++ b/tests/compute_engine/test_credentials.py @@ -25,6 +25,24 @@ from google.auth.compute_engine import credentials from google.auth.transport import requests +SAMPLE_ID_TOKEN_EXP = 1584393400 + +# header: {"alg": "RS256", "typ": "JWT", "kid": "1"} +# payload: {"iss": "issuer", "iat": 1584393348, "sub": "subject", +# "exp": 1584393400,"aud": "audience"} +SAMPLE_ID_TOKEN = ( + b"eyJhbGciOiAiUlMyNTYiLCAidHlwIjogIkpXVCIsICJraWQiOiAiMSJ9." + b"eyJpc3MiOiAiaXNzdWVyIiwgImlhdCI6IDE1ODQzOTMzNDgsICJzdWIiO" + b"iAic3ViamVjdCIsICJleHAiOiAxNTg0MzkzNDAwLCAiYXVkIjogImF1ZG" + b"llbmNlIn0." + b"OquNjHKhTmlgCk361omRo18F_uY-7y0f_AmLbzW062Q1Zr61HAwHYP5FM" + b"316CK4_0cH8MUNGASsvZc3VqXAqub6PUTfhemH8pFEwBdAdG0LhrNkU0H" + b"WN1YpT55IiQ31esLdL5q-qDsOPpNZJUti1y1lAreM5nIn2srdWzGXGs4i" + b"TRQsn0XkNUCL4RErpciXmjfhMrPkcAjKA-mXQm2fa4jmTlEZFqFmUlym1" + b"ozJ0yf5grjN6AslN4OGvAv1pS-_Ko_pGBS6IQtSBC6vVKCUuBfaqNjykg" + b"bsxbLa6Fp0SYeYwO8ifEnkRvasVpc1WTQqfRB2JCj5pTBDzJpIpFCMmnQ" +) + class TestCredentials(object): credentials = None @@ -238,6 +256,26 @@ def test_additional_claims(self, sign, get, utcnow): "foo": "bar", } + def test_token_uri(self): + request = mock.create_autospec(transport.Request, instance=True) + + self.credentials = credentials.IDTokenCredentials( + request=request, + signer=mock.Mock(), + service_account_email="foo@example.com", + target_audience="https://audience.com", + ) + assert self.credentials._token_uri == credentials._DEFAULT_TOKEN_URI + + self.credentials = credentials.IDTokenCredentials( + request=request, + signer=mock.Mock(), + service_account_email="foo@example.com", + target_audience="https://audience.com", + token_uri="https://example.com/token", + ) + assert self.credentials._token_uri == "https://example.com/token" + @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.utcfromtimestamp(0), @@ -469,3 +507,104 @@ def test_sign_bytes(self, sign, get): # The JWT token signature is 'signature' encoded in base 64: assert signature == b"signature" + + @mock.patch( + "google.auth.compute_engine._metadata.get_service_account_info", autospec=True + ) + @mock.patch("google.auth.compute_engine._metadata.get", autospec=True) + def test_get_id_token_from_metadata(self, get, get_service_account_info): + get.return_value = SAMPLE_ID_TOKEN + get_service_account_info.return_value = {"email": "foo@example.com"} + + cred = credentials.IDTokenCredentials( + mock.Mock(), "audience", use_metadata_identity_endpoint=True + ) + cred.refresh(request=mock.Mock()) + + assert cred.token == SAMPLE_ID_TOKEN + assert cred.expiry == SAMPLE_ID_TOKEN_EXP + assert cred._use_metadata_identity_endpoint + assert cred._signer is None + assert cred._token_uri is None + assert cred._service_account_email == "foo@example.com" + assert cred._target_audience == "audience" + with pytest.raises(ValueError): + cred.sign_bytes(b"bytes") + + @mock.patch( + "google.auth.compute_engine._metadata.get_service_account_info", autospec=True + ) + def test_with_target_audience_for_metadata(self, get_service_account_info): + get_service_account_info.return_value = {"email": "foo@example.com"} + + cred = credentials.IDTokenCredentials( + mock.Mock(), "audience", use_metadata_identity_endpoint=True + ) + cred = cred.with_target_audience("new_audience") + + assert cred._target_audience == "new_audience" + assert cred._use_metadata_identity_endpoint + assert cred._signer is None + assert cred._token_uri is None + assert cred._service_account_email == "foo@example.com" + + @mock.patch( + "google.auth.compute_engine._metadata.get_service_account_info", autospec=True + ) + @mock.patch("google.auth.compute_engine._metadata.get", autospec=True) + def test_invalid_id_token_from_metadata(self, get, get_service_account_info): + get.return_value = "invalid_id_token" + get_service_account_info.return_value = {"email": "foo@example.com"} + + cred = credentials.IDTokenCredentials( + mock.Mock(), "audience", use_metadata_identity_endpoint=True + ) + + with pytest.raises(ValueError): + cred.refresh(request=mock.Mock()) + + @mock.patch( + "google.auth.compute_engine._metadata.get_service_account_info", autospec=True + ) + @mock.patch("google.auth.compute_engine._metadata.get", autospec=True) + def test_transport_error_from_metadata(self, get, get_service_account_info): + get.side_effect = exceptions.TransportError("transport error") + get_service_account_info.return_value = {"email": "foo@example.com"} + + cred = credentials.IDTokenCredentials( + mock.Mock(), "audience", use_metadata_identity_endpoint=True + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + cred.refresh(request=mock.Mock()) + assert excinfo.match(r"transport error") + + def test_get_id_token_from_metadata_constructor(self): + with pytest.raises(ValueError): + credentials.IDTokenCredentials( + mock.Mock(), + "audience", + use_metadata_identity_endpoint=True, + token_uri="token_uri", + ) + with pytest.raises(ValueError): + credentials.IDTokenCredentials( + mock.Mock(), + "audience", + use_metadata_identity_endpoint=True, + signer=mock.Mock(), + ) + with pytest.raises(ValueError): + credentials.IDTokenCredentials( + mock.Mock(), + "audience", + use_metadata_identity_endpoint=True, + additional_claims={"key", "value"}, + ) + with pytest.raises(ValueError): + credentials.IDTokenCredentials( + mock.Mock(), + "audience", + use_metadata_identity_endpoint=True, + service_account_email="foo@example.com", + )