Skip to content

Commit

Permalink
feat: adding domain-wide delegation flow in impersonated credential (#…
Browse files Browse the repository at this point in the history
…1624)

* Adding a flow in impersonated credentials to check if a subject is specificed for domain-wide delegation auth.

* Adding a flow in impersonated credentials to check if a subject is specificed for domain-wide delegation auth.

* Minor fixes to dwd flow in impersonation

* Adding a flow in impersonated credentials to check if a subject is specificed for domain-wide delegation auth.

* deleted repeated

* delete repeated code

* Fixing where source credentials authentication header info is, and target scopes.

* Formatted code to uniform standard

* Fixing lint and coverage failures from kokoro tests

---------

Co-authored-by: Brian Jung <brianhmj@google.com>
Co-authored-by: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 17, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 1972c7b commit 34ee3fe
Showing 3 changed files with 225 additions and 1 deletion.
5 changes: 5 additions & 0 deletions google/auth/iam.py
Original file line number Diff line number Diff line change
@@ -48,6 +48,11 @@
+ "/serviceAccounts/{}:signBlob"
)

_IAM_SIGNJWT_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
+ "/serviceAccounts/{}:signJwt"
)

_IAM_IDTOKEN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/"
+ "projects/-/serviceAccounts/{}:generateIdToken"
101 changes: 100 additions & 1 deletion google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
@@ -38,12 +38,15 @@
from google.auth import iam
from google.auth import jwt
from google.auth import metrics
from google.oauth2 import _client


_REFRESH_ERROR = "Unable to acquire impersonated credentials"

_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds

_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"


def _make_iam_token_request(
request,
@@ -177,6 +180,7 @@ def __init__(
target_principal,
target_scopes,
delegates=None,
subject=None,
lifetime=_DEFAULT_TOKEN_LIFETIME_SECS,
quota_project_id=None,
iam_endpoint_override=None,
@@ -204,9 +208,12 @@ def __init__(
quota_project_id (Optional[str]): The project ID used for quota and billing.
This project may be different from the project used to
create the credentials.
iam_endpoint_override (Optiona[str]): The full IAM endpoint override
iam_endpoint_override (Optional[str]): The full IAM endpoint override
with the target_principal embedded. This is useful when supporting
impersonation with regional endpoints.
subject (Optional[str]): sub field of a JWT. This field should only be set
if you wish to impersonate as a user. This feature is useful when
using domain wide delegation.
"""

super(Credentials, self).__init__()
@@ -231,6 +238,7 @@ def __init__(
self._target_principal = target_principal
self._target_scopes = target_scopes
self._delegates = delegates
self._subject = subject
self._lifetime = lifetime or _DEFAULT_TOKEN_LIFETIME_SECS
self.token = None
self.expiry = _helpers.utcnow()
@@ -275,6 +283,39 @@ def _update_token(self, request):
# Apply the source credentials authentication info.
self._source_credentials.apply(headers)

# If a subject is specified a domain-wide delegation auth-flow is initiated
# to impersonate as the provided subject (user).
if self._subject:
if self.universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN:
raise exceptions.GoogleAuthError(
"Domain-wide delegation is not supported in universes other "
+ "than googleapis.com"
)

now = _helpers.utcnow()
payload = {
"iss": self._target_principal,
"scope": _helpers.scopes_to_string(self._target_scopes or ()),
"sub": self._subject,
"aud": _GOOGLE_OAUTH2_TOKEN_ENDPOINT,
"iat": _helpers.datetime_to_secs(now),
"exp": _helpers.datetime_to_secs(now) + _DEFAULT_TOKEN_LIFETIME_SECS,
}

assertion = _sign_jwt_request(
request=request,
principal=self._target_principal,
headers=headers,
payload=payload,
delegates=self._delegates,
)

self.token, self.expiry, _ = _client.jwt_grant(
request, _GOOGLE_OAUTH2_TOKEN_ENDPOINT, assertion
)

return

self.token, self.expiry = _make_iam_token_request(
request=request,
principal=self._target_principal,
@@ -478,3 +519,61 @@ def refresh(self, request):
self.expiry = datetime.utcfromtimestamp(
jwt.decode(id_token, verify=False)["exp"]
)


def _sign_jwt_request(request, principal, headers, payload, delegates=[]):
"""Makes a request to the Google Cloud IAM service to sign a JWT using a
service account's system-managed private key.
Args:
request (Request): The Request object to use.
principal (str): The principal to request an access token for.
headers (Mapping[str, str]): Map of headers to transmit.
payload (Mapping[str, str]): The JWT payload to sign. Must be a
serialized JSON object that contains a JWT Claims Set.
delegates (Sequence[str]): The chained list of delegates required
to grant the final access_token. If set, the sequence of
identities must have "Service Account Token Creator" capability
granted to the prceeding identity. For example, if set to
[serviceAccountB, serviceAccountC], the source_credential
must have the Token Creator role on serviceAccountB.
serviceAccountB must have the Token Creator on
serviceAccountC.
Finally, C must have Token Creator on target_principal.
If left unset, source_credential must have that role on
target_principal.
Raises:
google.auth.exceptions.TransportError: Raised if there is an underlying
HTTP connection error
google.auth.exceptions.RefreshError: Raised if the impersonated
credentials are not available. Common reasons are
`iamcredentials.googleapis.com` is not enabled or the
`Service Account Token Creator` is not assigned
"""
iam_endpoint = iam._IAM_SIGNJWT_ENDPOINT.format(principal)

body = {"delegates": delegates, "payload": json.dumps(payload)}
body = json.dumps(body).encode("utf-8")

response = request(url=iam_endpoint, method="POST", headers=headers, body=body)

# support both string and bytes type response.data
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
else response.data
)

if response.status != http_client.OK:
raise exceptions.RefreshError(_REFRESH_ERROR, response_body)

try:
jwt_response = json.loads(response_body)
signed_jwt = jwt_response["signedJwt"]
return signed_jwt

except (KeyError, ValueError) as caught_exc:
new_exc = exceptions.RefreshError(
"{}: No signed JWT in response.".format(_REFRESH_ERROR), response_body
)
raise new_exc from caught_exc
120 changes: 120 additions & 0 deletions tests/test_impersonated_credentials.py
Original file line number Diff line number Diff line change
@@ -71,6 +71,17 @@ def mock_donor_credentials():
yield grant


@pytest.fixture
def mock_dwd_credentials():
with mock.patch("google.oauth2._client.jwt_grant", autospec=True) as grant:
grant.return_value = (
"1/fFAGRNJasdfz70BzhT3Zg",
_helpers.utcnow() + datetime.timedelta(seconds=500),
{},
)
yield grant


class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
@@ -123,6 +134,7 @@ def make_credentials(
source_credentials=SOURCE_CREDENTIALS,
lifetime=LIFETIME,
target_principal=TARGET_PRINCIPAL,
subject=None,
iam_endpoint_override=None,
):

@@ -132,6 +144,7 @@ def make_credentials(
target_scopes=self.TARGET_SCOPES,
delegates=self.DELEGATES,
lifetime=lifetime,
subject=subject,
iam_endpoint_override=iam_endpoint_override,
)

@@ -238,6 +251,28 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
== ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE
)

@pytest.mark.parametrize("use_data_bytes", [True, False])
def test_refresh_with_subject_success(self, use_data_bytes, mock_dwd_credentials):
credentials = self.make_credentials(subject="test@email.com", lifetime=None)

response_body = {"signedJwt": "example_signed_jwt"}

request = self.make_request(
data=json.dumps(response_body),
status=http_client.OK,
use_data_bytes=use_data_bytes,
)

with mock.patch(
"google.auth.metrics.token_request_access_token_impersonate",
return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
):
credentials.refresh(request)

assert credentials.valid
assert not credentials.expired
assert credentials.token == "1/fFAGRNJasdfz70BzhT3Zg"

@pytest.mark.parametrize("use_data_bytes", [True, False])
def test_refresh_success_nonGdu(self, use_data_bytes, mock_donor_credentials):
source_credentials = service_account.Credentials(
@@ -418,6 +453,33 @@ def test_refresh_failure_http_error(self, mock_donor_credentials):
assert not credentials.valid
assert credentials.expired

def test_refresh_failure_subject_with_nondefault_domain(
self, mock_donor_credentials
):
source_credentials = service_account.Credentials(
SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
source_credentials=source_credentials, subject="test@email.com"
)

expire_time = (_helpers.utcnow().replace(microsecond=0)).isoformat("T") + "Z"
response_body = {"accessToken": "token", "expireTime": expire_time}
request = self.make_request(
data=json.dumps(response_body), status=http_client.OK
)

with pytest.raises(exceptions.GoogleAuthError) as excinfo:
credentials.refresh(request)

assert excinfo.match(
"Domain-wide delegation is not supported in universes other "
+ "than googleapis.com"
)

assert not credentials.valid
assert credentials.expired

def test_expired(self):
credentials = self.make_credentials(lifetime=None)
assert credentials.expired
@@ -810,3 +872,61 @@ def test_id_token_with_quota_project(
id_creds.refresh(request)

assert id_creds.quota_project_id == "project-foo"

def test_sign_jwt_request_success(self):
principal = "foo@example.com"
expected_signed_jwt = "correct_signed_jwt"

response_body = {"keyId": "1", "signedJwt": expected_signed_jwt}
request = self.make_request(
data=json.dumps(response_body), status=http_client.OK
)

signed_jwt = impersonated_credentials._sign_jwt_request(
request=request, principal=principal, headers={}, payload={}
)

assert signed_jwt == expected_signed_jwt
request.assert_called_once_with(
url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@example.com:signJwt",
method="POST",
headers={},
body=json.dumps({"delegates": [], "payload": json.dumps({})}).encode(
"utf-8"
),
)

def test_sign_jwt_request_http_error(self):
principal = "foo@example.com"

request = self.make_request(
data="error_message", status=http_client.BAD_REQUEST
)

with pytest.raises(exceptions.RefreshError) as excinfo:
_ = impersonated_credentials._sign_jwt_request(
request=request, principal=principal, headers={}, payload={}
)

assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

assert excinfo.value.args[0] == "Unable to acquire impersonated credentials"
assert excinfo.value.args[1] == "error_message"

def test_sign_jwt_request_invalid_response_error(self):
principal = "foo@example.com"

request = self.make_request(data="invalid_data", status=http_client.OK)

with pytest.raises(exceptions.RefreshError) as excinfo:
_ = impersonated_credentials._sign_jwt_request(
request=request, principal=principal, headers={}, payload={}
)

assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

assert (
excinfo.value.args[0]
== "Unable to acquire impersonated credentials: No signed JWT in response."
)
assert excinfo.value.args[1] == "invalid_data"

0 comments on commit 34ee3fe

Please # to comment.