Skip to content

Commit

Permalink
feat: add helper func to for default encrypted cert (#514)
Browse files Browse the repository at this point in the history
* feat: helper func to for default encrpted cert
  • Loading branch information
arithmetic1728 authored May 28, 2020
1 parent eb7be3f commit f282aa4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
42 changes: 42 additions & 0 deletions google/auth/transport/mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,45 @@ def callback():
return cert_bytes, key_bytes

return callback


def default_client_encrypted_cert_source(cert_path, key_path):
"""Get a callback which returns the default encrpyted client SSL credentials.
Args:
cert_path (str): The cert file path. The default client certificate will
be written to this file when the returned callback is called.
key_path (str): The key file path. The default encrypted client key will
be written to this file when the returned callback is called.
Returns:
Callable[[], [str, str, bytes]]: A callback which generates the default
client certificate, encrpyted private key and passphrase. It writes
the certificate and private key into the cert_path and key_path, and
returns the cert_path, key_path and passphrase bytes.
Raises:
google.auth.exceptions.DefaultClientCertSourceError: If any problem
occurs when loading or saving the client certificate and key.
"""
if not has_default_client_cert_source():
raise exceptions.MutualTLSChannelError(
"Default client encrypted cert source doesn't exist"
)

def callback():
try:
_, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials(
generate_encrypted_key=True
)
with open(cert_path, "wb") as cert_file:
cert_file.write(cert_bytes)
with open(key_path, "wb") as key_file:
key_file.write(key_bytes)
except (exceptions.ClientCertError, OSError) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
six.raise_from(new_exc, caught_exc)

return cert_path, key_path, passphrase_bytes

return callback
28 changes: 28 additions & 0 deletions tests/transport/test_mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,31 @@ def test_default_client_cert_source(
callback = mtls.default_client_cert_source()
with pytest.raises(exceptions.MutualTLSChannelError):
callback()


@mock.patch(
"google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
)
@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True)
def test_default_client_encrypted_cert_source(
has_default_client_cert_source, get_client_ssl_credentials
):
# Test default client cert source doesn't exist.
has_default_client_cert_source.return_value = False
with pytest.raises(exceptions.MutualTLSChannelError):
mtls.default_client_encrypted_cert_source("cert_path", "key_path")

# The following tests will assume default client cert source exists.
has_default_client_cert_source.return_value = True

# Test good callback.
get_client_ssl_credentials.return_value = (True, b"cert", b"key", b"passphrase")
callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
with mock.patch("{}.open".format(__name__), return_value=mock.MagicMock()):
assert callback() == ("cert_path", "key_path", b"passphrase")

# Test bad callback which throws exception.
get_client_ssl_credentials.side_effect = exceptions.ClientCertError()
callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
with pytest.raises(exceptions.MutualTLSChannelError):
callback()

0 comments on commit f282aa4

Please # to comment.