Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: adds support for X509 workload credential type #1541

Merged
merged 11 commits into from
Jul 2, 2024
18 changes: 16 additions & 2 deletions google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import copy
from dataclasses import dataclass
import datetime
import functools
import io
import json
import re
Expand All @@ -40,6 +41,7 @@
from google.auth import exceptions
from google.auth import impersonated_credentials
from google.auth import metrics
from google.auth.transport.requests import _MutualTlsAdapter
from google.oauth2 import sts
from google.oauth2 import utils

Expand Down Expand Up @@ -393,12 +395,18 @@ def get_project_id(self, request):
@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
scopes = self._scopes if self._scopes is not None else self._default_scopes
auth_request = request

# if mtls is required, wrap the incoming request in a partial to set the cert.
if self._should_add_mtls():
print("mtls yeah")
auth_request = functools.partial(request, cert=self._get_mtls_cert())

if self._should_initialize_impersonated_credentials():
self._impersonated_credentials = self._initialize_impersonated_credentials()

if self._impersonated_credentials:
self._impersonated_credentials.refresh(request)
self._impersonated_credentials.refresh(auth_request)
self.token = self._impersonated_credentials.token
self.expiry = self._impersonated_credentials.expiry
else:
Expand All @@ -414,7 +422,7 @@ def refresh(self, request):
)
}
response_data = self._sts_client.exchange_token(
request=request,
request=auth_request,
grant_type=_STS_GRANT_TYPE,
subject_token=self.retrieve_subject_token(request),
subject_token_type=self._subject_token_type,
Expand Down Expand Up @@ -523,6 +531,12 @@ def _create_default_metrics_options(self):

return metrics_options

def _should_add_mtls(self):
return False

def _get_mtls_cert(self):
raise NotImplementedError("_get_mtls_cert must be implemented.")

@classmethod
def from_info(cls, info, **kwargs):
"""Creates a Credentials instance from parsed external account info.
Expand Down
131 changes: 101 additions & 30 deletions google/auth/identity_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from google.auth import _helpers
from google.auth import exceptions
from google.auth import external_account
from google.auth.transport import _mtls_helper


class SubjectTokenSupplier(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -141,6 +142,14 @@ def get_subject_token(self, context, request):
)


class _X509Supplier(SubjectTokenSupplier):
""" Internal implementation of subject token supplier for X509 workload credentials, always returns an empty string."""

@_helpers.copy_docstring(SubjectTokenSupplier)
def get_subject_token(self, context, request):
return ""


def _parse_token_data(token_content, format_type="text", subject_token_field_name=None):
if format_type == "text":
token = token_content.content
Expand Down Expand Up @@ -247,6 +256,7 @@ def __init__(
self._subject_token_supplier = subject_token_supplier
self._credential_source_file = None
self._credential_source_url = None
self._credential_source_certificate = None
else:
if not isinstance(credential_source, Mapping):
self._credential_source_executable = None
Expand All @@ -255,76 +265,93 @@ def __init__(
)
self._credential_source_file = credential_source.get("file")
self._credential_source_url = credential_source.get("url")
self._credential_source_headers = credential_source.get("headers")
credential_source_format = credential_source.get("format", {})
# Get credential_source format type. When not provided, this
# defaults to text.
self._credential_source_format_type = (
credential_source_format.get("type") or "text"
)
self._credential_source_certificate = credential_source.get("certificate")

# environment_id is only supported in AWS or dedicated future external
# account credentials.
if "environment_id" in credential_source:
raise exceptions.MalformedError(
"Invalid Identity Pool credential_source field 'environment_id'"
)
if self._credential_source_format_type not in ["text", "json"]:
raise exceptions.MalformedError(
"Invalid credential_source format '{}'".format(
self._credential_source_format_type

# check that only one of file, url, or certificate are provided.
if (
sum(
map(
bool,
[
self._credential_source_file,
self._credential_source_url,
self._credential_source_certificate,
],
)
)
# For JSON types, get the required subject_token field name.
if self._credential_source_format_type == "json":
self._credential_source_field_name = credential_source_format.get(
"subject_token_field_name"
)
if self._credential_source_field_name is None:
raise exceptions.MalformedError(
"Missing subject_token_field_name for JSON credential_source format"
)
else:
self._credential_source_field_name = None

if self._credential_source_file and self._credential_source_url:
> 1
):
raise exceptions.MalformedError(
"Ambiguous credential_source. 'file' is mutually exclusive with 'url'."
"Ambiguous credential_source. 'file', 'url', and 'certificate' are mutually exclusive.."
)
if not self._credential_source_file and not self._credential_source_url:
if (
not self._credential_source_file
and not self._credential_source_url
and not self._credential_source_certificate
):
raise exceptions.MalformedError(
"Missing credential_source. A 'file' or 'url' must be provided."
"Missing credential_source. A 'file', 'url', or 'certificate' must be provided."
)

if self._credential_source_certificate:
self._validate_certificate_credential_source()
else:
self._validate_file_url_credential_source(credential_source)

if self._credential_source_file:
self._subject_token_supplier = _FileSupplier(
self._credential_source_file,
self._credential_source_format_type,
self._credential_source_field_name,
)
else:
elif self._credential_source_url:
self._subject_token_supplier = _UrlSupplier(
self._credential_source_url,
self._credential_source_format_type,
self._credential_source_field_name,
self._credential_source_headers,
)
else:
self._subject_token_supplier = _X509Supplier()

@_helpers.copy_docstring(external_account.Credentials)
def retrieve_subject_token(self, request):
return self._subject_token_supplier.get_subject_token(
self._supplier_context, request
)

def _get_mtls_cert(self):
if self._credential_source_certificate == None:
raise exceptions.RefreshError(
'The credential is not configured to use mtls requests. The credential should include a "certificate" section in the credential source.'
)
else:
return _mtls_helper._get_workload_cert_and_key_paths(
self._certificate_config_location
)

def _should_add_mtls(self):
return self._credential_source_certificate is not None

def _create_default_metrics_options(self):
metrics_options = super(Credentials, self)._create_default_metrics_options()
# Check that credential source is a dict before checking for file vs url. This check needs to be done
# Check that credential source is a dict before checking for credential type. This check needs to be done
# here because the external_account credential constructor needs to pass the metrics options to the
# impersonated credential object before the identity_pool credentials are validated.
if isinstance(self._credential_source, Mapping):
if self._credential_source.get("file"):
metrics_options["source"] = "file"
else:
elif self._credential_source.get("url"):
metrics_options["source"] = "url"
else:
metrics_options["source"] = "x509"
else:
metrics_options["source"] = "programmatic"
return metrics_options
Expand All @@ -339,6 +366,50 @@ def _constructor_args(self):
args.update({"subject_token_supplier": self._subject_token_supplier})
return args

def _validate_certificate_credential_source(self):
self._certificate_config_location = self._credential_source_certificate.get(
"certificate_config_location"
)
use_default = self._credential_source_certificate.get(
"use_default_certificate_config"
)
if self._certificate_config_location:
if use_default:
raise exceptions.MalformedError(
"Invalid certificate configuration, certificate_config_location cannot be specified when use_default_certificate_config = true."
)
else:
if not use_default:
raise exceptions.MalformedError(
"Invalid certificate configuration, use_default_certificate_config should be true if no certificate_config_location is provided."
)

def _validate_file_url_credential_source(self, credential_source):
self._credential_source_headers = credential_source.get("headers")
credential_source_format = credential_source.get("format", {})
# Get credential_source format type. When not provided, this
# defaults to text.
self._credential_source_format_type = (
credential_source_format.get("type") or "text"
)
if self._credential_source_format_type not in ["text", "json"]:
raise exceptions.MalformedError(
"Invalid credential_source format '{}'".format(
self._credential_source_format_type
)
)
# For JSON types, get the required subject_token field name.
if self._credential_source_format_type == "json":
self._credential_source_field_name = credential_source_format.get(
"subject_token_field_name"
)
if self._credential_source_field_name is None:
raise exceptions.MalformedError(
"Missing subject_token_field_name for JSON credential_source format"
)
else:
self._credential_source_field_name = None

@classmethod
def from_info(cls, info, **kwargs):
"""Creates an Identity Pool Credentials instance from parsed external account info.
Expand Down
75 changes: 43 additions & 32 deletions google/auth/transport/_mtls_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,50 @@ def _get_workload_cert_and_key(certificate_config_path=None):
google.auth.exceptions.ClientCertError: if problems occurs when retrieving
the certificate or key information.
"""
absolute_path = _get_cert_config_path(certificate_config_path)

cert_path, key_path = _get_workload_cert_and_key_paths(certificate_config_path)

if cert_path is None and key_path is None:
return None, None

return _read_cert_and_key_files(cert_path, key_path)


def _get_cert_config_path(certificate_config_path=None):
"""Gets the certificate configuration full path using the following order of precedence:

1: Explicit override, if set
2: Environment variable, if set
3: Well-known location

Returns "None" if the selected config file does not exist.

Args:
certificate_config_path (string): The certificate config path. If provided, the well known
location and environment variable will be ignored.

Returns:
The absolute path of the certificate config file, and None if the file does not exist.
"""

if certificate_config_path is None:
env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None)
if env_path is not None and env_path != "":
certificate_config_path = env_path
else:
certificate_config_path = _CERTIFICATE_CONFIGURATION_DEFAULT_PATH

certificate_config_path = path.expanduser(certificate_config_path)
if not path.exists(certificate_config_path):
return None
return certificate_config_path


def _get_workload_cert_and_key_paths(config_path):
absolute_path = _get_cert_config_path(config_path)
if absolute_path is None:
return None, None

data = _load_json_file(absolute_path)

if "cert_configs" not in data:
Expand Down Expand Up @@ -142,37 +183,7 @@ def _get_workload_cert_and_key(certificate_config_path=None):
)
key_path = workload["key_path"]

return _read_cert_and_key_files(cert_path, key_path)


def _get_cert_config_path(certificate_config_path=None):
"""Gets the certificate configuration full path using the following order of precedence:

1: Explicit override, if set
2: Environment variable, if set
3: Well-known location

Returns "None" if the selected config file does not exist.

Args:
certificate_config_path (string): The certificate config path. If provided, the well known
location and environment variable will be ignored.

Returns:
The absolute path of the certificate config file, and None if the file does not exist.
"""

if certificate_config_path is None:
env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None)
if env_path is not None and env_path != "":
certificate_config_path = env_path
else:
certificate_config_path = _CERTIFICATE_CONFIGURATION_DEFAULT_PATH

certificate_config_path = path.expanduser(certificate_config_path)
if not path.exists(certificate_config_path):
return None
return certificate_config_path
return cert_path, key_path


def _read_cert_and_key_files(cert_path, key_path):
Expand Down
Loading