Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: making iam endpoint universe-aware #1604

Merged
merged 11 commits into from
Oct 19, 2024
13 changes: 6 additions & 7 deletions google/auth/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,19 @@
http_client.GATEWAY_TIMEOUT,
}


_IAM_SCOPE = ["https://www.googleapis.com/auth/iam"]

_IAM_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
"https://iamcredentials.{}/v1/projects/-"
+ "/serviceAccounts/{}:generateAccessToken"
)

_IAM_SIGN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
+ "/serviceAccounts/{}:signBlob"
"https://iamcredentials.{}/v1/projects/-" + "/serviceAccounts/{}:signBlob"
)

_IAM_IDTOKEN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/"
+ "projects/-/serviceAccounts/{}:generateIdToken"
"https://iamcredentials.{}/v1/" + "projects/-/serviceAccounts/{}:generateIdToken"
)


Expand Down Expand Up @@ -90,7 +87,9 @@ def _make_signing_request(self, message):
message = _helpers.to_bytes(message)

method = "POST"
url = _IAM_SIGN_ENDPOINT.format(self._service_account_email)
url = _IAM_SIGN_ENDPOINT.format(
self._credentials.universe_domain, self._service_account_email
)
headers = {"Content-Type": "application/json"}
body = json.dumps(
{"payload": base64.b64encode(message).decode("utf-8")}
Expand Down
16 changes: 12 additions & 4 deletions google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@


def _make_iam_token_request(
request, principal, headers, body, iam_endpoint_override=None
request, principal, headers, body, universe_domain, iam_endpoint_override=None
):
"""Makes a request to the Google Cloud IAM service for an access token.
Args:
Expand All @@ -67,7 +67,9 @@ def _make_iam_token_request(
`iamcredentials.googleapis.com` is not enabled or the
`Service Account Token Creator` is not assigned
"""
iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(principal)
iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(
universe_domain, principal
)

body = json.dumps(body).encode("utf-8")

Expand Down Expand Up @@ -219,6 +221,8 @@ def __init__(
and self._source_credentials._always_use_jwt_access
):
self._source_credentials._create_self_signed_jwt(None)

self._universe_domain = source_credentials.universe_domain
self._target_principal = target_principal
self._target_scopes = target_scopes
self._delegates = delegates
Expand Down Expand Up @@ -271,13 +275,16 @@ def _update_token(self, request):
principal=self._target_principal,
headers=headers,
body=body,
universe_domain=self.universe_domain,
iam_endpoint_override=self._iam_endpoint_override,
)

def sign_bytes(self, message):
from google.auth.transport.requests import AuthorizedSession

iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.format(self._target_principal)
iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.format(
self.universe_domain, self._target_principal
)

body = {
"payload": base64.b64encode(message).decode("utf-8"),
Expand Down Expand Up @@ -428,7 +435,8 @@ def refresh(self, request):
from google.auth.transport.requests import AuthorizedSession

iam_sign_endpoint = iam._IAM_IDTOKEN_ENDPOINT.format(
self._target_credentials.signer_email
self._target_credentials.universe_domain,
self._target_credentials.signer_email,
)

body = {
Expand Down
9 changes: 7 additions & 2 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,12 @@ def jwt_grant(request, token_uri, assertion, can_retry=True):


def call_iam_generate_id_token_endpoint(
request, iam_id_token_endpoint, signer_email, audience, access_token
request,
iam_id_token_endpoint,
signer_email,
audience,
access_token,
universe_domain,
):
"""Call iam.generateIdToken endpoint to get ID token.

Expand All @@ -339,7 +344,7 @@ def call_iam_generate_id_token_endpoint(

response_data = _token_endpoint_request(
request,
iam_id_token_endpoint.format(signer_email),
iam_id_token_endpoint.format(universe_domain, signer_email),
body,
access_token=access_token,
use_json=True,
Expand Down
1 change: 1 addition & 0 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ def _refresh_with_iam_endpoint(self, request):
self.signer_email,
self._target_audience,
jwt_credentials.token.decode(),
self._universe_domain,
)

@_helpers.copy_docstring(credentials.Credentials)
Expand Down
20 changes: 20 additions & 0 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,16 @@ def test_with_target_audience_integration(self):
},
)

# mock information about universe_domain
responses.add(
responses.GET,
"http://metadata.google.internal/computeMetadata/v1/universe/"
"universe_domain",
status=200,
content_type="application/json",
json={},
)

# mock token for credentials
responses.add(
responses.GET,
Expand Down Expand Up @@ -659,6 +669,16 @@ def test_with_quota_project_integration(self):
},
)

# stubby response about universe_domain
responses.add(
responses.GET,
"http://metadata.google.internal/computeMetadata/v1/universe/"
"universe_domain",
status=200,
content_type="application/json",
json={},
)

# mock sign blob endpoint
signature = base64.b64encode(b"some-signature").decode("utf-8")
responses.add(
Expand Down
2 changes: 2 additions & 0 deletions tests/oauth2/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def test_call_iam_generate_id_token_endpoint():
"fake_email",
"fake_audience",
"fake_access_token",
"googleapis.com",
)

assert (
Expand Down Expand Up @@ -361,6 +362,7 @@ def test_call_iam_generate_id_token_endpoint_no_id_token():
"fake_email",
"fake_audience",
"fake_access_token",
"googleapis.com",
)
assert excinfo.match("No ID token in response")

Expand Down
8 changes: 5 additions & 3 deletions tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint):
)
request = mock.Mock()
credentials.refresh(request)
req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[
0
]
assert req == request
Expand All @@ -798,6 +798,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint):
assert target_audience == "https://example.com"
decoded_access_token = jwt.decode(access_token, verify=False)
assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam"
assert universe_domain == "googleapis.com"

@mock.patch(
"google.oauth2._client.call_iam_generate_id_token_endpoint", autospec=True
Expand All @@ -811,18 +812,19 @@ def test_refresh_iam_flow_non_gdu(self, call_iam_generate_id_token_endpoint):
)
request = mock.Mock()
credentials.refresh(request)
req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[
0
]
assert req == request
assert (
iam_endpoint
== "https://iamcredentials.fake-universe/v1/projects/-/serviceAccounts/{}:generateIdToken"
== "https://iamcredentials.{}/v1/projects/-/serviceAccounts/{}:generateIdToken"
)
assert signer_email == "service-account@example.com"
assert target_audience == "https://example.com"
decoded_access_token = jwt.decode(access_token, verify=False)
assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam"
assert universe_domain == "fake-universe"

@mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True)
def test_before_request_refreshes(self, id_token_jwt_grant):
Expand Down
130 changes: 125 additions & 5 deletions tests/test_impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ def test_get_cred_info(self):
"principal": "impersonated@project.iam.gserviceaccount.com",
}

def test_universe_domain_matching_source(self):
source_credentials = service_account.Credentials(
SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(source_credentials=source_credentials)
assert credentials.universe_domain == "foo.bar"

def test__make_copy_get_cred_info(self):
credentials = self.make_credentials()
credentials._cred_file_path = "/path/to/file"
Expand Down Expand Up @@ -231,6 +238,38 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
== ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE
)

@pytest.mark.parametrize("use_data_bytes", [True, False])
def test_refresh_success_nonGdu(self, use_data_bytes, mock_donor_credentials):
source_credentials = service_account.Credentials(
SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
lifetime=None, source_credentials=source_credentials
)
token = "token"

expire_time = (
_helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
).isoformat("T") + "Z"
response_body = {"accessToken": token, "expireTime": expire_time}

request = self.make_request(
data=json.dumps(response_body),
status=http_client.OK,
use_data_bytes=use_data_bytes,
)

credentials.refresh(request)

assert credentials.valid
assert not credentials.expired
# Confirm override endpoint used.
request_kwargs = request.call_args[1]
assert (
request_kwargs["url"]
== "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:generateAccessToken"
)

@pytest.mark.parametrize("use_data_bytes", [True, False])
def test_refresh_success_iam_endpoint_override(
self, use_data_bytes, mock_donor_credentials
Expand Down Expand Up @@ -397,6 +436,38 @@ def test_service_account_email(self):

def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
credentials = self.make_credentials(lifetime=None)
expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:signBlob"
self._sign_bytes_helper(
credentials,
mock_donor_credentials,
mock_authorizedsession_sign,
expected_url,
)

def test_sign_bytes_nonGdu(
self, mock_donor_credentials, mock_authorizedsession_sign
):
source_credentials = service_account.Credentials(
SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
lifetime=None, source_credentials=source_credentials
)
expected_url = "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:signBlob"
self._sign_bytes_helper(
credentials,
mock_donor_credentials,
mock_authorizedsession_sign,
expected_url,
)

def _sign_bytes_helper(
self,
credentials,
mock_donor_credentials,
mock_authorizedsession_sign,
expected_url,
):
token = "token"

expire_time = (
Expand All @@ -412,11 +483,19 @@ def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
request.return_value = response

credentials.refresh(request)

assert credentials.valid
assert not credentials.expired

signature = credentials.sign_bytes(b"signed bytes")
mock_authorizedsession_sign.assert_called_with(
mock.ANY,
"POST",
expected_url,
None,
json={"payload": "c2lnbmVkIGJ5dGVz", "delegates": []},
headers={"Content-Type": "application/json"},
)

assert signature == b"signature"

def test_sign_bytes_failure(self):
Expand Down Expand Up @@ -563,6 +642,45 @@ def test_id_token_from_credential(
self, mock_donor_credentials, mock_authorizedsession_idtoken
):
credentials = self.make_credentials(lifetime=None)
target_credentials = self.make_credentials(lifetime=None)
expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:generateIdToken"
self._test_id_token_helper(
credentials,
target_credentials,
mock_donor_credentials,
mock_authorizedsession_idtoken,
expected_url,
)

def test_id_token_from_credential_nonGdu(
self, mock_donor_credentials, mock_authorizedsession_idtoken
):
source_credentials = service_account.Credentials(
SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
lifetime=None, source_credentials=source_credentials
)
target_credentials = self.make_credentials(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is same cred created 2 times?

lifetime=None, source_credentials=source_credentials
)
expected_url = "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:generateIdToken"
self._test_id_token_helper(
credentials,
target_credentials,
mock_donor_credentials,
mock_authorizedsession_idtoken,
expected_url,
)

def _test_id_token_helper(
self,
credentials,
target_credentials,
mock_donor_credentials,
mock_authorizedsession_idtoken,
expected_url,
):
token = "token"
target_audience = "https://foo.bar"

Expand All @@ -580,17 +698,19 @@ def test_id_token_from_credential(
assert credentials.valid
assert not credentials.expired

new_credentials = self.make_credentials(lifetime=None)

id_creds = impersonated_credentials.IDTokenCredentials(
credentials, target_audience=target_audience, include_email=True
)
id_creds = id_creds.from_credentials(target_credentials=new_credentials)
id_creds = id_creds.from_credentials(target_credentials=target_credentials)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this line existing? IIUC, removing this will have no impact.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to remove the target_audience?

id_creds.refresh(request)

args = mock_authorizedsession_idtoken.call_args.args

assert args[2] == expected_url

assert id_creds.token == ID_TOKEN_DATA
assert id_creds._include_email is True
assert id_creds._target_credentials is new_credentials
assert id_creds._target_credentials is target_credentials

def test_id_token_with_target_audience(
self, mock_donor_credentials, mock_authorizedsession_idtoken
Expand Down