diff --git a/auth0/v3/authentication/async_token_verifier.py b/auth0/v3/authentication/async_token_verifier.py new file mode 100644 index 00000000..11d0f995 --- /dev/null +++ b/auth0/v3/authentication/async_token_verifier.py @@ -0,0 +1,182 @@ +"""Token Verifier module""" +from .. import TokenValidationError +from ..rest_async import AsyncRestClient +from .token_verifier import AsymmetricSignatureVerifier, JwksFetcher, TokenVerifier + + +class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier): + """Async verifier for RSA signatures, which rely on public key certificates. + + Args: + jwks_url (str): The url where the JWK set is located. + algorithm (str, optional): The expected signing algorithm. Defaults to "RS256". + """ + + def __init__(self, jwks_url, algorithm="RS256"): + super(AsyncAsymmetricSignatureVerifier, self).__init__(jwks_url, algorithm) + self._fetcher = AsyncJwksFetcher(jwks_url) + + def set_session(self, session): + """Set Client Session to improve performance by reusing session. + + Args: + session (aiohttp.ClientSession): The client session which should be closed + manually or within context manager. + """ + self._fetcher.set_session(session) + + async def _fetch_key(self, key_id=None): + """Request the JWKS. + + Args: + key_id (str): The key's key id.""" + return await self._fetcher.get_key(key_id) + + async def verify_signature(self, token): + """Verifies the signature of the given JSON web token. + + Args: + token (str): The JWT to get its signature verified. + + Raises: + TokenValidationError: if the token cannot be decoded, the algorithm is invalid + or the token's signature doesn't match the calculated one. + """ + kid = self._get_kid(token) + secret_or_certificate = await self._fetch_key(key_id=kid) + + return self._decode_jwt(token, secret_or_certificate) + + +class AsyncJwksFetcher(JwksFetcher): + """Class that async fetches and holds a JSON web key set. + This class makes use of an in-memory cache. For it to work properly, define this instance once and re-use it. + + Args: + jwks_url (str): The url where the JWK set is located. + cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. + """ + + def __init__(self, *args, **kwargs): + super(AsyncJwksFetcher, self).__init__(*args, **kwargs) + self._async_client = AsyncRestClient(None) + + def set_session(self, session): + """Set Client Session to improve performance by reusing session. + + Args: + session (aiohttp.ClientSession): The client session which should be closed + manually or within context manager. + """ + self._async_client.set_session(session) + + async def _fetch_jwks(self, force=False): + """Attempts to obtain the JWK set from the cache, as long as it's still valid. + When not, it will perform a network request to the jwks_url to obtain a fresh result + and update the cache value with it. + + Args: + force (bool, optional): whether to ignore the cache and force a network request or not. Defaults to False. + """ + if force or self._cache_expired(): + self._cache_value = {} + try: + jwks = await self._async_client.get(self._jwks_url) + self._cache_jwks(jwks) + except: # noqa: E722 + return self._cache_value + return self._cache_value + + self._cache_is_fresh = False + return self._cache_value + + async def get_key(self, key_id): + """Obtains the JWK associated with the given key id. + + Args: + key_id (str): The id of the key to fetch. + + Returns: + the JWK associated with the given key id. + + Raises: + TokenValidationError: when a key with that id cannot be found + """ + keys = await self._fetch_jwks() + + if keys and key_id in keys: + return keys[key_id] + + if not self._cache_is_fresh: + keys = await self._fetch_jwks(force=True) + if keys and key_id in keys: + return keys[key_id] + raise TokenValidationError( + 'RSA Public Key with ID "{}" was not found.'.format(key_id) + ) + + +class AsyncTokenVerifier(TokenVerifier): + """Class that verifies ID tokens following the steps defined in the OpenID Connect spec. + An OpenID Connect ID token is not meant to be consumed until it's verified. + + Args: + signature_verifier (AsyncAsymmetricSignatureVerifier): The instance that knows how to verify the signature. + issuer (str): The expected issuer claim value. + audience (str): The expected audience claim value. + leeway (int, optional): The clock skew to accept when verifying date related claims in seconds. + Defaults to 60 seconds. + """ + + def __init__(self, signature_verifier, issuer, audience, leeway=0): + if not signature_verifier or not isinstance( + signature_verifier, AsyncAsymmetricSignatureVerifier + ): + raise TypeError( + "signature_verifier must be an instance of AsyncAsymmetricSignatureVerifier." + ) + + self.iss = issuer + self.aud = audience + self.leeway = leeway + self._sv = signature_verifier + self._clock = None # legacy testing requirement + + def set_session(self, session): + """Set Client Session to improve performance by reusing session. + + Args: + session (aiohttp.ClientSession): The client session which should be closed + manually or within context manager. + """ + self._sv.set_session(session) + + async def verify(self, token, nonce=None, max_age=None, organization=None): + """Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec. + + Args: + token (str): The JWT to verify. + nonce (str, optional): The nonce value sent during authentication. + max_age (int, optional): The max_age value sent during authentication. + organization (str, optional): The expected organization ID (org_id) claim value. This should be specified + when logging in to an organization. + + Returns: + the decoded payload from the token + + Raises: + TokenValidationError: when the token cannot be decoded, the token signing algorithm is not the expected one, + the token signature is invalid or the token has a claim missing or with unexpected value. + """ + + # Verify token presence + if not token or not isinstance(token, str): + raise TokenValidationError("ID token is required but missing.") + + # Verify algorithm and signature + payload = await self._sv.verify_signature(token) + + # Verify claims + self._verify_payload(payload, nonce, max_age, organization) + + return payload diff --git a/auth0/v3/authentication/token_verifier.py b/auth0/v3/authentication/token_verifier.py index 17b040b9..5e44e5d2 100644 --- a/auth0/v3/authentication/token_verifier.py +++ b/auth0/v3/authentication/token_verifier.py @@ -45,15 +45,18 @@ def _fetch_key(self, key_id=None): """ raise NotImplementedError - def verify_signature(self, token): - """Verifies the signature of the given JSON web token. + def _get_kid(self, token): + """Gets the key id from the kid claim of the header of the token Args: - token (str): The JWT to get its signature verified. + token (str): The JWT to get the header from. Raises: TokenValidationError: if the token cannot be decoded, the algorithm is invalid or the token's signature doesn't match the calculated one. + + Returns: + the key id or None """ try: header = jwt.get_unverified_header(token) @@ -67,9 +70,19 @@ def verify_signature(self, token): 'to be signed with "{}"'.format(alg, self._algorithm) ) - kid = header.get("kid", None) - secret_or_certificate = self._fetch_key(key_id=kid) + return header.get("kid", None) + + def _decode_jwt(self, token, secret_or_certificate): + """Verifies and decodes the given JSON web token with the given public key or shared secret. + + Args: + token (str): The JWT to get its signature verified. + secret_or_certificate (str): The public key or shared secret. + Raises: + TokenValidationError: if the token cannot be decoded, the algorithm is invalid + or the token's signature doesn't match the calculated one. + """ try: decoded = jwt.decode( jwt=token, @@ -81,6 +94,21 @@ def verify_signature(self, token): raise TokenValidationError("Invalid token signature.") return decoded + def verify_signature(self, token): + """Verifies the signature of the given JSON web token. + + Args: + token (str): The JWT to get its signature verified. + + Raises: + TokenValidationError: if the token cannot be decoded, the algorithm is invalid + or the token's signature doesn't match the calculated one. + """ + kid = self._get_kid(token) + secret_or_certificate = self._fetch_key(key_id=kid) + + return self._decode_jwt(token, secret_or_certificate) + class SymmetricSignatureVerifier(SignatureVerifier): """Verifier for HMAC signatures, which rely on shared secrets. @@ -136,6 +164,24 @@ def _init_cache(self, cache_ttl): self._cache_ttl = cache_ttl self._cache_is_fresh = False + def _cache_expired(self): + """Checks if the cache is expired + + Returns: + True if it should use the cache. + """ + return self._cache_date + self._cache_ttl < time.time() + + def _cache_jwks(self, jwks): + """Cache the response of the JWKS request + + Args: + jwks (dict): The JWKS + """ + self._cache_value = self._parse_jwks(jwks) + self._cache_is_fresh = True + self._cache_date = time.time() + def _fetch_jwks(self, force=False): """Attempts to obtain the JWK set from the cache, as long as it's still valid. When not, it will perform a network request to the jwks_url to obtain a fresh result @@ -144,23 +190,15 @@ def _fetch_jwks(self, force=False): Args: force (bool, optional): whether to ignore the cache and force a network request or not. Defaults to False. """ - has_expired = self._cache_date + self._cache_ttl < time.time() - - if not force and not has_expired: - # Return from cache - self._cache_is_fresh = False + if force or self._cache_expired(): + self._cache_value = {} + response = requests.get(self._jwks_url) + if response.ok: + jwks = response.json() + self._cache_jwks(jwks) return self._cache_value - # Invalidate cache and fetch fresh data - self._cache_value = {} - response = requests.get(self._jwks_url) - - if response.ok: - # Update cache - jwks = response.json() - self._cache_value = self._parse_jwks(jwks) - self._cache_is_fresh = True - self._cache_date = time.time() + self._cache_is_fresh = False return self._cache_value @staticmethod diff --git a/auth0/v3/rest_async.py b/auth0/v3/rest_async.py index 40493930..7648c5b5 100644 --- a/auth0/v3/rest_async.py +++ b/auth0/v3/rest_async.py @@ -1,13 +1,10 @@ import asyncio -import json import aiohttp from auth0.v3.exceptions import RateLimitError -from .rest import EmptyResponse, JsonResponse, PlainResponse -from .rest import Response as _Response -from .rest import RestClient +from .rest import EmptyResponse, JsonResponse, PlainResponse, RestClient def _clean_params(params): diff --git a/auth0/v3/test_async/test_async_token_verifier.py b/auth0/v3/test_async/test_async_token_verifier.py new file mode 100644 index 00000000..fb6d0e26 --- /dev/null +++ b/auth0/v3/test_async/test_async_token_verifier.py @@ -0,0 +1,275 @@ +import time +import unittest + +import jwt +from aioresponses import aioresponses +from callee import Attrs +from cryptography.hazmat.primitives import serialization +from mock import ANY + +from .. import TokenValidationError +from ..authentication.async_token_verifier import ( + AsyncAsymmetricSignatureVerifier, + AsyncJwksFetcher, + AsyncTokenVerifier, +) +from ..test.authentication.test_token_verifier import ( + JWKS_RESPONSE_MULTIPLE_KEYS, + JWKS_RESPONSE_SINGLE_KEY, + RSA_PUB_KEY_1_JWK, + RSA_PUB_KEY_1_PEM, + RSA_PUB_KEY_2_PEM, +) +from .test_asyncify import get_callback + +JWKS_URI = "https://example.auth0.com/.well-known/jwks.json" + +PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQDfytWVSk/4Z6rNu8UZ7C4tnU9x0vj5FCaj4awKZlxVgOR1Kcen +QqDOxJdrXXanTBJbZwh8pk+HpWvqDVgVmKhnt+OkgF//hIXZoJMhDOFVzX504kiZ +cu3bu7kFs+PUfKw5s59tmETFPseA/fIrad9YXHisMkNmPWhuKYJ3WfZAaQIDAQAB +AoGADPSfHL9qlcTanIJsTK3hln5u5PYDt9e0zPP5k7iNS93kW+wJROOUj6PN6EdG +4TSEM4ppcV3naMDo2GnhWY624P6LUB+CbDFzjQKq805vrxJuFnq50blscwVK/ffP +kODBm/gwk+FaliRpQTDAAPWkKbkRfkmPx4JMEmTDBQ45diECQQDxw3qp2+wa5WP5 +9w7AYrDPq4Fd6gIFcmxracROUcdhhMmVHKA9DzTWY46cSoWZoChYhQhhyj8dlP8q +El8aevN9AkEA7PhxcNyff8aehqEQ/Z38bm3P+GgB9EkRinjesba2CqhEI5okzvb7 +OIYdszgQUBqGKlST0a7s9KuTpd7moyy8XQJAY8hjk0HCxCMTTXMLspnJEh1eKo3P +wcHFP9wKeqzEFtrAfHuxIyJok2fJz3XuiEaTAF3/5KSdwi7h1dJ5UCuY3QJAM9rF +0CGnEWngJKu4MRdSNsP232+7Bb67hOagLJlDyp85keTYKyXmoV7PvvkEsNKtCzRI +yHiTx5KIE6LsK0bNzQJBAMV+1KyI8ua1XmqLDaOexvBPM86HnuP+8u5CthgrXyGm +nh9gurwbs/lBRYV/d4XBLj+dzHb2zC0Jo7u96wrOObw= +-----END RSA PRIVATE KEY-----""" + +PUBLIC_KEY = { + "kty": "RSA", + "e": "AQAB", + "kid": "kid-1", + "n": "38rVlUpP-GeqzbvFGewuLZ1PcdL4-RQmo-GsCmZcVYDkdSnHp0KgzsSXa112p0wSW2cIfKZPh6Vr6g1YFZioZ7fjpIBf_4SF2aCTIQzhVc1-dOJImXLt27u5BbPj1HysObOfbZhExT7HgP3yK2nfWFx4rDJDZj1obimCd1n2QGk", +} + + +def get_pem_bytes(rsa_public_key): + return rsa_public_key.public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo + ) + + +class TestAsyncAsymmetricSignatureVerifier(unittest.IsolatedAsyncioTestCase): + @aioresponses() + async def test_async_asymmetric_verifier_fetches_key(self, mocked): + callback, mock = get_callback(200, JWKS_RESPONSE_SINGLE_KEY) + mocked.get(JWKS_URI, callback=callback) + + verifier = AsyncAsymmetricSignatureVerifier(JWKS_URI) + + key = await verifier._fetch_key("test-key-1") + + self.assertEqual(get_pem_bytes(key), RSA_PUB_KEY_1_PEM) + + +class TestAsyncJwksFetcher(unittest.IsolatedAsyncioTestCase): + @aioresponses() + async def test_async_get_jwks_json_twice_on_cache_expired(self, mocked): + fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1) + + callback, mock = get_callback(200, JWKS_RESPONSE_SINGLE_KEY) + mocked.get(JWKS_URI, callback=callback) + mocked.get(JWKS_URI, callback=callback) + + key_1 = await fetcher.get_key("test-key-1") + expected_key_1_pem = get_pem_bytes(key_1) + self.assertEqual(expected_key_1_pem, RSA_PUB_KEY_1_PEM) + + mock.assert_called_with( + Attrs(path="/.well-known/jwks.json"), + allow_redirects=True, + params=None, + headers=ANY, + timeout=ANY, + ) + self.assertEqual(mock.call_count, 1) + + time.sleep(2) + + # 2 seconds has passed, cache should be expired + key_1 = await fetcher.get_key("test-key-1") + expected_key_1_pem = get_pem_bytes(key_1) + self.assertEqual(expected_key_1_pem, RSA_PUB_KEY_1_PEM) + + mock.assert_called_with( + Attrs(path="/.well-known/jwks.json"), + allow_redirects=True, + params=None, + headers=ANY, + timeout=ANY, + ) + self.assertEqual(mock.call_count, 2) + + @aioresponses() + async def test_async_get_jwks_json_once_on_cache_hit(self, mocked): + fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1) + + callback, mock = get_callback(200, JWKS_RESPONSE_MULTIPLE_KEYS) + mocked.get(JWKS_URI, callback=callback) + mocked.get(JWKS_URI, callback=callback) + + key_1 = await fetcher.get_key("test-key-1") + key_2 = await fetcher.get_key("test-key-2") + expected_key_1_pem = get_pem_bytes(key_1) + expected_key_2_pem = get_pem_bytes(key_2) + self.assertEqual(expected_key_1_pem, RSA_PUB_KEY_1_PEM) + self.assertEqual(expected_key_2_pem, RSA_PUB_KEY_2_PEM) + + mock.assert_called_with( + Attrs(path="/.well-known/jwks.json"), + allow_redirects=True, + params=None, + headers=ANY, + timeout=ANY, + ) + self.assertEqual(mock.call_count, 1) + + @aioresponses() + async def test_async_fetches_jwks_json_forced_on_cache_miss(self, mocked): + fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1) + + callback, mock = get_callback(200, {"keys": [RSA_PUB_KEY_1_JWK]}) + mocked.get(JWKS_URI, callback=callback) + + # Triggers the first call + key_1 = await fetcher.get_key("test-key-1") + expected_key_1_pem = get_pem_bytes(key_1) + self.assertEqual(expected_key_1_pem, RSA_PUB_KEY_1_PEM) + + mock.assert_called_with( + Attrs(path="/.well-known/jwks.json"), + allow_redirects=True, + params=None, + headers=ANY, + timeout=ANY, + ) + self.assertEqual(mock.call_count, 1) + + callback, mock = get_callback(200, JWKS_RESPONSE_MULTIPLE_KEYS) + mocked.get(JWKS_URI, callback=callback) + + # Triggers the second call + key_2 = await fetcher.get_key("test-key-2") + expected_key_2_pem = get_pem_bytes(key_2) + self.assertEqual(expected_key_2_pem, RSA_PUB_KEY_2_PEM) + + mock.assert_called_with( + Attrs(path="/.well-known/jwks.json"), + allow_redirects=True, + params=None, + headers=ANY, + timeout=ANY, + ) + self.assertEqual(mock.call_count, 1) + + @aioresponses() + async def test_async_fetches_jwks_json_once_on_cache_miss(self, mocked): + fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1) + + callback, mock = get_callback(200, JWKS_RESPONSE_SINGLE_KEY) + mocked.get(JWKS_URI, callback=callback) + + with self.assertRaises(Exception) as err: + await fetcher.get_key("missing-key") + + mock.assert_called_with( + Attrs(path="/.well-known/jwks.json"), + allow_redirects=True, + params=None, + headers=ANY, + timeout=ANY, + ) + self.assertEqual( + str(err.exception), 'RSA Public Key with ID "missing-key" was not found.' + ) + self.assertEqual(mock.call_count, 1) + + @aioresponses() + async def test_async_fails_to_fetch_jwks_json_after_retrying_twice(self, mocked): + fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1) + + callback, mock = get_callback(500, {}) + mocked.get(JWKS_URI, callback=callback) + mocked.get(JWKS_URI, callback=callback) + + with self.assertRaises(Exception) as err: + await fetcher.get_key("id1") + + mock.assert_called_with( + Attrs(path="/.well-known/jwks.json"), + allow_redirects=True, + params=None, + headers=ANY, + timeout=ANY, + ) + self.assertEqual( + str(err.exception), 'RSA Public Key with ID "id1" was not found.' + ) + self.assertEqual(mock.call_count, 2) + + +class TestAsyncTokenVerifier(unittest.IsolatedAsyncioTestCase): + @aioresponses() + async def test_RS256_token_signature_passes(self, mocked): + callback, mock = get_callback(200, {"keys": [PUBLIC_KEY]}) + mocked.get(JWKS_URI, callback=callback) + + issuer = "https://tokens-test.auth0.com/" + audience = "tokens-test-123" + token = jwt.encode( + { + "iss": issuer, + "sub": "auth0|123456789", + "aud": audience, + "exp": int(time.time()) + 86400, + "iat": int(time.time()), + }, + PRIVATE_KEY, + algorithm="RS256", + headers={"kid": "kid-1"}, + ) + + tv = AsyncTokenVerifier( + signature_verifier=AsyncAsymmetricSignatureVerifier(JWKS_URI), + issuer=issuer, + audience=audience, + ) + payload = await tv.verify(token) + self.assertEqual(payload["sub"], "auth0|123456789") + + @aioresponses() + async def test_RS256_token_signature_fails(self, mocked): + callback, mock = get_callback( + 200, {"keys": [RSA_PUB_KEY_1_JWK]} + ) # different pub key + mocked.get(JWKS_URI, callback=callback) + + issuer = "https://tokens-test.auth0.com/" + audience = "tokens-test-123" + token = jwt.encode( + { + "iss": issuer, + "sub": "auth0|123456789", + "aud": audience, + "exp": int(time.time()) + 86400, + "iat": int(time.time()), + }, + PRIVATE_KEY, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + tv = AsyncTokenVerifier( + signature_verifier=AsyncAsymmetricSignatureVerifier(JWKS_URI), + issuer=issuer, + audience=audience, + ) + + with self.assertRaises(TokenValidationError) as err: + await tv.verify(token) + self.assertEqual(str(err.exception), "Invalid token signature.") diff --git a/auth0/v3/test_async/test_asyncify.py b/auth0/v3/test_async/test_asyncify.py index f8a7a0c5..439f61c1 100644 --- a/auth0/v3/test_async/test_asyncify.py +++ b/auth0/v3/test_async/test_asyncify.py @@ -39,8 +39,10 @@ } -def get_callback(status=200): - mock = MagicMock(return_value=CallbackResult(status=status, payload=payload)) +def get_callback(status=200, response=None): + mock = MagicMock( + return_value=CallbackResult(status=status, payload=response or payload) + ) def callback(url, **kwargs): return mock(url, **kwargs)