diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py new file mode 100644 index 000000000..1ce9fa554 --- /dev/null +++ b/google/auth/transport/_mtls_helper.py @@ -0,0 +1,116 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for getting mTLS cert and key, for internal use only.""" + +import json +import logging +from os import path +import re +import subprocess + +CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json" +_CERT_PROVIDER_COMMAND = "cert_provider_command" +_CERT_REGEX = re.compile( + b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL +) + +# support various format of key files, e.g. +# "-----BEGIN PRIVATE KEY-----...", +# "-----BEGIN EC PRIVATE KEY-----...", +# "-----BEGIN RSA PRIVATE KEY-----..." +_KEY_REGEX = re.compile( + b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?", + re.DOTALL, +) + +_LOGGER = logging.getLogger(__name__) + + +def _check_dca_metadata_path(metadata_path): + """Checks for context aware metadata. If it exists, returns the absolute path; + otherwise returns None. + + Args: + metadata_path (str): context aware metadata path. + + Returns: + str: absolute path if exists and None otherwise. + """ + metadata_path = path.expanduser(metadata_path) + if not path.exists(metadata_path): + _LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path) + return None + return metadata_path + + +def _read_dca_metadata_file(metadata_path): + """Loads context aware metadata from the given path. + + Args: + metadata_path (str): context aware metadata path. + + Returns: + Dict[str, str]: The metadata. + + Raises: + ValueError: If failed to parse metadata as JSON. + """ + with open(metadata_path) as f: + metadata = json.load(f) + + return metadata + + +def get_client_ssl_credentials(metadata_json): + """Returns the client side mTLS cert and key. + + Args: + metadata_json (Dict[str, str]): metadata JSON file which contains the cert + provider command. + + Returns: + Tuple[bytes, bytes]: client certificate and key, both in PEM format. + + Raises: + OSError: If the cert provider command failed to run. + RuntimeError: If the cert provider command has a runtime error. + ValueError: If the metadata json file doesn't contain the cert provider command or if the command doesn't produce both the client certificate and client key. + """ + # TODO: implement an in-memory cache of cert and key so we don't have to + # run cert provider command every time. + + # Check the cert provider command existence in the metadata json file. + if _CERT_PROVIDER_COMMAND not in metadata_json: + raise ValueError("Cert provider command is not found") + + # Execute the command. It throws OsError in case of system failure. + command = metadata_json[_CERT_PROVIDER_COMMAND] + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + + # Check cert provider command execution error. + if process.returncode != 0: + raise RuntimeError( + "Cert provider command returns non-zero status code %s" % process.returncode + ) + + # Extract certificate (chain) and key. + cert_match = re.findall(_CERT_REGEX, stdout) + if len(cert_match) != 1: + raise ValueError("Client SSL certificate is missing or invalid") + key_match = re.findall(_KEY_REGEX, stdout) + if len(key_match) != 1: + raise ValueError("Client SSL key is missing or invalid") + return cert_match[0], key_match[0] diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index fb90fbb4b..ca387392e 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -17,9 +17,12 @@ from __future__ import absolute_import from concurrent import futures +import logging import six +from google.auth.transport import _mtls_helper + try: import grpc except ImportError as caught_exc: # pragma: NO COVER @@ -31,6 +34,8 @@ caught_exc, ) +_LOGGER = logging.getLogger(__name__) + class AuthMetadataPlugin(grpc.AuthMetadataPlugin): """A `gRPC AuthMetadataPlugin`_ that inserts the credentials into each @@ -92,7 +97,12 @@ def __del__(self): def secure_authorized_channel( - credentials, request, target, ssl_credentials=None, **kwargs + credentials, + request, + target, + ssl_credentials=None, + client_cert_callback=None, + **kwargs ): """Creates a secure authorized gRPC channel. @@ -114,11 +124,86 @@ def secure_authorized_channel( # Create a channel. channel = google.auth.transport.grpc.secure_authorized_channel( - credentials, 'speech.googleapis.com:443', request) + credentials, regular_endpoint, request, + ssl_credentials=grpc.ssl_channel_credentials()) # Use the channel to create a stub. cloud_speech.create_Speech_stub(channel) + Usage: + + There are actually a couple of options to create a channel, depending on if + you want to create a regular or mutual TLS channel. + + First let's list the endpoints (regular vs mutual TLS) to choose from:: + + regular_endpoint = 'speech.googleapis.com:443' + mtls_endpoint = 'speech.mtls.googleapis.com:443' + + Option 1: create a regular (non-mutual) TLS channel by explicitly setting + the ssl_credentials:: + + regular_ssl_credentials = grpc.ssl_channel_credentials() + + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, regular_endpoint, request, + ssl_credentials=regular_ssl_credentials) + + Option 2: create a mutual TLS channel by calling a callback which returns + the client side certificate and the key:: + + def my_client_cert_callback(): + code_to_load_client_cert_and_key() + if loaded: + return (pem_cert_bytes, pem_key_bytes) + raise MyClientCertFailureException() + + try: + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, mtls_endpoint, request, + client_cert_callback=my_client_cert_callback) + except MyClientCertFailureException: + # handle the exception + + Option 3: use application default SSL credentials. It searches and uses + the command in a context aware metadata file, which is available on devices + with endpoint verification support. + See https://cloud.google.com/endpoint-verification/docs/overview:: + + try: + default_ssl_credentials = SslCredentials() + except: + # Exception can be raised if the context aware metadata is malformed. + # See :class:`SslCredentials` for the possible exceptions. + + # Choose the endpoint based on the SSL credentials type. + if default_ssl_credentials.is_mtls: + endpoint_to_use = mtls_endpoint + else: + endpoint_to_use = regular_endpoint + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, endpoint_to_use, request, + ssl_credentials=default_ssl_credentials) + + Option 4: not setting ssl_credentials and client_cert_callback. For devices + without endpoint verification support, a regular TLS channel is created; + otherwise, a mutual TLS channel is created, however, the call should be + wrapped in a try/except block in case of malformed context aware metadata. + + The following code uses regular_endpoint, it works the same no matter the + created channle is regular or mutual TLS. Regular endpoint ignores client + certificate and key:: + + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, regular_endpoint, request) + + The following code uses mtls_endpoint, if the created channle is regular, + and API mtls_endpoint is confgured to require client SSL credentials, API + calls using this channel will be rejected:: + + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, mtls_endpoint, request) + Args: credentials (google.auth.credentials.Credentials): The credentials to add to requests. @@ -129,10 +214,33 @@ def secure_authorized_channel( target (str): The host and port of the service. ssl_credentials (grpc.ChannelCredentials): Optional SSL channel credentials. This can be used to specify different certificates. + This argument is mutually exclusive with client_cert_callback; + providing both will raise an exception. + If ssl_credentials and client_cert_callback are None, application + default SSL credentials will be used. + client_cert_callback (Callable[[], (bytes, bytes)]): Optional + callback function to obtain client certicate and key for mutual TLS + connection. This argument is mutually exclusive with + ssl_credentials; providing both will raise an exception. + If ssl_credentials and client_cert_callback are None, application + default SSL credentials will be used. kwargs: Additional arguments to pass to :func:`grpc.secure_channel`. Returns: grpc.Channel: The created gRPC channel. + + Raises: + OSError: If the cert provider command launch fails during the application + default SSL credentials loading process on devices with endpoint + verification support. + RuntimeError: If the cert provider command has a runtime error during the + application default SSL credentials loading process on devices with + endpoint verification support. + ValueError: + If the context aware metadata file is malformed or if the cert provider + command doesn't produce both client certificate and key during the + application default SSL credentials loading process on devices with + endpoint verification support. """ # Create the metadata plugin for inserting the authorization header. metadata_plugin = AuthMetadataPlugin(credentials, request) @@ -140,8 +248,24 @@ def secure_authorized_channel( # Create a set of grpc.CallCredentials using the metadata plugin. google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin) - if ssl_credentials is None: - ssl_credentials = grpc.ssl_channel_credentials() + if ssl_credentials and client_cert_callback: + raise ValueError( + "Received both ssl_credentials and client_cert_callback; " + "these are mutually exclusive." + ) + + # If SSL credentials are not explicitly set, try client_cert_callback and ADC. + if not ssl_credentials: + if client_cert_callback: + # Use the callback if provided. + cert, key = client_cert_callback() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + # Use application default SSL credentials. + adc_ssl_credentils = SslCredentials() + ssl_credentials = adc_ssl_credentils.ssl_credentials # Combine the ssl credentials and the authorization credentials. composite_credentials = grpc.composite_channel_credentials( @@ -149,3 +273,59 @@ def secure_authorized_channel( ) return grpc.secure_channel(target, composite_credentials, **kwargs) + + +class SslCredentials: + """Class for application default SSL credentials. + + For devices with endpoint verification support, a device certificate will be + automatically loaded and mutual TLS will be established. + See https://cloud.google.com/endpoint-verification/docs/overview. + """ + + def __init__(self): + # Load client SSL credentials. + self._context_aware_metadata_path = _mtls_helper._check_dca_metadata_path( + _mtls_helper.CONTEXT_AWARE_METADATA_PATH + ) + if self._context_aware_metadata_path: + self._is_mtls = True + else: + self._is_mtls = False + + @property + def ssl_credentials(self): + """Get the created SSL channel credentials. + + For devices with endpoint verification support, if the device certificate + loading has any problems, corresponding exceptions will be raised. For + a device without endpoint verification support, no exceptions will be + raised. + + Returns: + grpc.ChannelCredentials: The created grpc channel credentials. + + Raises: + OSError: If the cert provider command launch fails. + RuntimeError: If the cert provider command has a runtime error. + ValueError: + If the context aware metadata file is malformed or if the cert provider + command doesn't produce both the client certificate and key. + """ + if self._context_aware_metadata_path: + metadata = _mtls_helper._read_dca_metadata_file( + self._context_aware_metadata_path + ) + cert, key = _mtls_helper.get_client_ssl_credentials(metadata) + self._ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_credentials = grpc.ssl_channel_credentials() + + return self._ssl_credentials + + @property + def is_mtls(self): + """Indicates if the created SSL channel credentials is mutual TLS.""" + return self._is_mtls diff --git a/tests/data/context_aware_metadata.json b/tests/data/context_aware_metadata.json new file mode 100644 index 000000000..ec40e783f --- /dev/null +++ b/tests/data/context_aware_metadata.json @@ -0,0 +1,6 @@ +{ + "cert_provider_command":[ + "/opt/google/endpoint-verification/bin/SecureConnectHelper", + "--print_certificate"], + "device_resource_ids":["11111111-1111-1111"] +} diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py new file mode 100644 index 000000000..6e7175f17 --- /dev/null +++ b/tests/transport/test__mtls_helper.py @@ -0,0 +1,177 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +import mock +import pytest + +from google.auth.transport import _mtls_helper + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") + +with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: + PUBLIC_CERT_BYTES = fh.read() + +CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]} + +CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {} + + +def check_cert_and_key(content, expected_cert, expected_key): + success = True + + cert_match = re.findall(_mtls_helper._CERT_REGEX, content) + success = success and len(cert_match) == 1 and cert_match[0] == expected_cert + + key_match = re.findall(_mtls_helper._KEY_REGEX, content) + success = success and len(key_match) == 1 and key_match[0] == expected_key + + return success + + +class TestCertAndKeyRegex(object): + def test_cert_and_key(self): + # Test single cert and single key + check_cert_and_key( + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES + ) + check_cert_and_key( + PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES + ) + + # Test cert chain and single key + check_cert_and_key( + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, + PRIVATE_KEY_BYTES, + ) + check_cert_and_key( + PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, + PRIVATE_KEY_BYTES, + ) + + def test_key(self): + # Create some fake keys for regex check. + KEY = b"""-----BEGIN PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END PRIVATE KEY-----""" + RSA_KEY = b"""-----BEGIN RSA PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END RSA PRIVATE KEY-----""" + EC_KEY = b"""-----BEGIN EC PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END EC PRIVATE KEY-----""" + + check_cert_and_key(PUBLIC_CERT_BYTES + KEY, PUBLIC_CERT_BYTES, KEY) + check_cert_and_key(PUBLIC_CERT_BYTES + RSA_KEY, PUBLIC_CERT_BYTES, RSA_KEY) + check_cert_and_key(PUBLIC_CERT_BYTES + EC_KEY, PUBLIC_CERT_BYTES, EC_KEY) + + +class TestCheckaMetadataPath(object): + def test_success(self): + metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json") + returned_path = _mtls_helper._check_dca_metadata_path(metadata_path) + assert returned_path is not None + + def test_failure(self): + metadata_path = os.path.join(DATA_DIR, "not_exists.json") + returned_path = _mtls_helper._check_dca_metadata_path(metadata_path) + assert returned_path is None + + +class TestReadMetadataFile(object): + def test_success(self): + metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json") + metadata = _mtls_helper._read_dca_metadata_file(metadata_path) + + assert "cert_provider_command" in metadata + + def test_file_not_json(self): + # read a file which is not json format. + metadata_path = os.path.join(DATA_DIR, "privatekey.pem") + with pytest.raises(ValueError): + _mtls_helper._read_dca_metadata_file(metadata_path) + + +class TestGetClientSslCredentials(object): + def create_mock_process(self, output, error): + # There are two steps to execute a script with subprocess.Popen. + # (1) process = subprocess.Popen([comannds]) + # (2) stdout, stderr = process.communicate() + # This function creates a mock process which can be returned by a mock + # subprocess.Popen. The mock process returns the given output and error + # when mock_process.communicate() is called. + mock_process = mock.Mock() + attrs = {"communicate.return_value": (output, error), "returncode": 0} + mock_process.configure_mock(**attrs) + return mock_process + + @mock.patch("subprocess.Popen", autospec=True) + def test_success(self, mock_popen): + mock_popen.return_value = self.create_mock_process( + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, b"" + ) + cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + assert cert == PUBLIC_CERT_BYTES + assert key == PRIVATE_KEY_BYTES + + @mock.patch("subprocess.Popen", autospec=True) + def test_success_with_cert_chain(self, mock_popen): + PUBLIC_CERT_CHAIN_BYTES = PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + mock_popen.return_value = self.create_mock_process( + PUBLIC_CERT_CHAIN_BYTES + PRIVATE_KEY_BYTES, b"" + ) + cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + assert cert == PUBLIC_CERT_CHAIN_BYTES + assert key == PRIVATE_KEY_BYTES + + def test_missing_cert_provider_command(self): + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND + ) + + @mock.patch("subprocess.Popen", autospec=True) + def test_missing_cert(self, mock_popen): + mock_popen.return_value = self.create_mock_process(PRIVATE_KEY_BYTES, b"") + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + + @mock.patch("subprocess.Popen", autospec=True) + def test_missing_key(self, mock_popen): + mock_popen.return_value = self.create_mock_process(PUBLIC_CERT_BYTES, b"") + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + + @mock.patch("subprocess.Popen", autospec=True) + def test_cert_provider_returns_error(self, mock_popen): + mock_popen.return_value = self.create_mock_process(b"", b"some error") + mock_popen.return_value.returncode = 1 + with pytest.raises(RuntimeError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + + @mock.patch("subprocess.Popen", autospec=True) + def test_popen_raise_exception(self, mock_popen): + mock_popen.side_effect = OSError() + with pytest.raises(OSError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py index 857c32bb9..23e62a213 100644 --- a/tests/transport/test_grpc.py +++ b/tests/transport/test_grpc.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +import os import time import mock @@ -31,6 +32,12 @@ except ImportError: # pragma: NO COVER HAS_GRPC = False +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +METADATA_PATH = os.path.join(DATA_DIR, "context_aware_metadata.json") +with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() +with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: + PUBLIC_CERT_BYTES = fh.read() pytestmark = pytest.mark.skipif(not HAS_GRPC, reason="gRPC is unavailable.") @@ -87,70 +94,251 @@ def test_call_refresh(self): ) +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) @mock.patch("grpc.composite_channel_credentials", autospec=True) @mock.patch("grpc.metadata_call_credentials", autospec=True) @mock.patch("grpc.ssl_channel_credentials", autospec=True) @mock.patch("grpc.secure_channel", autospec=True) -def test_secure_authorized_channel( - secure_channel, - ssl_channel_credentials, - metadata_call_credentials, - composite_channel_credentials, -): - credentials = CredentialsStub() - request = mock.create_autospec(transport.Request) - target = "example.com:80" - - channel = google.auth.transport.grpc.secure_authorized_channel( - credentials, request, target, options=mock.sentinel.options +class TestSecureAuthorizedChannel(object): + @mock.patch( + "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True ) + @mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True + ) + def test_secure_authorized_channel_adc( + self, + check_dca_metadata_path, + read_dca_metadata_file, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = CredentialsStub() + request = mock.create_autospec(transport.Request) + target = "example.com:80" + + # Mock the context aware metadata and client cert/key so mTLS SSL channel + # will be used. + check_dca_metadata_path.return_value = METADATA_PATH + read_dca_metadata_file.return_value = { + "cert_provider_command": ["some command"] + } + get_client_ssl_credentials.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES) + + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, options=mock.sentinel.options + ) - # Check the auth plugin construction. - auth_plugin = metadata_call_credentials.call_args[0][0] - assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin) - assert auth_plugin._credentials == credentials - assert auth_plugin._request == request + # Check the auth plugin construction. + auth_plugin = metadata_call_credentials.call_args[0][0] + assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin) + assert auth_plugin._credentials == credentials + assert auth_plugin._request == request - # Check the ssl channel call. - assert ssl_channel_credentials.called + # Check the ssl channel call. + ssl_channel_credentials.assert_called_once_with( + certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES + ) - # Check the composite credentials call. - composite_channel_credentials.assert_called_once_with( - ssl_channel_credentials.return_value, metadata_call_credentials.return_value - ) + # Check the composite credentials call. + composite_channel_credentials.assert_called_once_with( + ssl_channel_credentials.return_value, metadata_call_credentials.return_value + ) - # Check the channel call. - secure_channel.assert_called_once_with( - target, - composite_channel_credentials.return_value, - options=mock.sentinel.options, + # Check the channel call. + secure_channel.assert_called_once_with( + target, + composite_channel_credentials.return_value, + options=mock.sentinel.options, + ) + assert channel == secure_channel.return_value + + def test_secure_authorized_channel_explicit_ssl( + self, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = mock.Mock() + request = mock.Mock() + target = "example.com:80" + ssl_credentials = mock.Mock() + + google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, ssl_credentials=ssl_credentials + ) + + # Since explicit SSL credentials are provided, get_client_ssl_credentials + # shouldn't be called. + assert not get_client_ssl_credentials.called + + # Check the ssl channel call. + assert not ssl_channel_credentials.called + + # Check the composite credentials call. + composite_channel_credentials.assert_called_once_with( + ssl_credentials, metadata_call_credentials.return_value + ) + + def test_secure_authorized_channel_mutual_exclusive( + self, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = mock.Mock() + request = mock.Mock() + target = "example.com:80" + ssl_credentials = mock.Mock() + client_cert_callback = mock.Mock() + + with pytest.raises(ValueError): + google.auth.transport.grpc.secure_authorized_channel( + credentials, + request, + target, + ssl_credentials=ssl_credentials, + client_cert_callback=client_cert_callback, + ) + + def test_secure_authorized_channel_with_client_cert_callback_success( + self, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = mock.Mock() + request = mock.Mock() + target = "example.com:80" + client_cert_callback = mock.Mock() + client_cert_callback.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES) + + google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, client_cert_callback=client_cert_callback + ) + + client_cert_callback.assert_called_once() + + # Check we are using the cert and key provided by client_cert_callback. + ssl_channel_credentials.assert_called_once_with( + certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES + ) + + # Check the composite credentials call. + composite_channel_credentials.assert_called_once_with( + ssl_channel_credentials.return_value, metadata_call_credentials.return_value + ) + + @mock.patch( + "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True + ) + @mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True ) - assert channel == secure_channel.return_value + def test_secure_authorized_channel_with_client_cert_callback_failure( + self, + check_dca_metadata_path, + read_dca_metadata_file, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = mock.Mock() + request = mock.Mock() + target = "example.com:80" + + client_cert_callback = mock.Mock() + client_cert_callback.side_effect = Exception("callback exception") + + with pytest.raises(Exception) as excinfo: + google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, client_cert_callback=client_cert_callback + ) + + assert str(excinfo.value) == "callback exception" -@mock.patch("grpc.composite_channel_credentials", autospec=True) -@mock.patch("grpc.metadata_call_credentials", autospec=True) @mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("grpc.secure_channel", autospec=True) -def test_secure_authorized_channel_explicit_ssl( - secure_channel, - ssl_channel_credentials, - metadata_call_credentials, - composite_channel_credentials, -): - credentials = mock.Mock() - request = mock.Mock() - target = "example.com:80" - ssl_credentials = mock.Mock() - - google.auth.transport.grpc.secure_authorized_channel( - credentials, request, target, ssl_credentials=ssl_credentials - ) +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True) +@mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True +) +class TestSslCredentials(object): + def test_no_context_aware_metadata( + self, + mock_check_dca_metadata_path, + mock_read_dca_metadata_file, + mock_get_client_ssl_credentials, + mock_ssl_channel_credentials, + ): + # Mock that the metadata file doesn't exist. + mock_check_dca_metadata_path.return_value = None + + ssl_credentials = google.auth.transport.grpc.SslCredentials() + + # Since no context aware metadata is found, we wouldn't call + # get_client_ssl_credentials, and the SSL channel credentials created is + # non mTLS. + assert ssl_credentials.ssl_credentials is not None + assert not ssl_credentials.is_mtls + mock_get_client_ssl_credentials.assert_not_called() + mock_ssl_channel_credentials.assert_called_once_with() + + def test_get_client_ssl_credentials_failure( + self, + mock_check_dca_metadata_path, + mock_read_dca_metadata_file, + mock_get_client_ssl_credentials, + mock_ssl_channel_credentials, + ): + mock_check_dca_metadata_path.return_value = METADATA_PATH + mock_read_dca_metadata_file.return_value = { + "cert_provider_command": ["some command"] + } + + # Mock that client cert and key are not loaded and exception is raised. + mock_get_client_ssl_credentials.side_effect = ValueError() + + with pytest.raises(ValueError): + assert google.auth.transport.grpc.SslCredentials().ssl_credentials + + def test_get_client_ssl_credentials_success( + self, + mock_check_dca_metadata_path, + mock_read_dca_metadata_file, + mock_get_client_ssl_credentials, + mock_ssl_channel_credentials, + ): + mock_check_dca_metadata_path.return_value = METADATA_PATH + mock_read_dca_metadata_file.return_value = { + "cert_provider_command": ["some command"] + } + mock_get_client_ssl_credentials.return_value = ( + PUBLIC_CERT_BYTES, + PRIVATE_KEY_BYTES, + ) - # Check the ssl channel call. - assert not ssl_channel_credentials.called + ssl_credentials = google.auth.transport.grpc.SslCredentials() - # Check the composite credentials call. - composite_channel_credentials.assert_called_once_with( - ssl_credentials, metadata_call_credentials.return_value - ) + assert ssl_credentials.ssl_credentials is not None + assert ssl_credentials.is_mtls + mock_get_client_ssl_credentials.assert_called_once() + mock_ssl_channel_credentials.assert_called_once_with( + certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES + )