Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[SDK-3714] Async token verifier #445

Merged
merged 4 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions auth0/v3/authentication/async_token_verifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""Token Verifier module"""
from .. import TokenValidationError
from ..rest_async import AsyncRestClient
from .token_verifier import AsymmetricSignatureVerifier, JwksFetcher, TokenVerifier


class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):
"""Async verifier for RSA signatures, which rely on public key certificates.

Args:
jwks_url (str): The url where the JWK set is located.
algorithm (str, optional): The expected signing algorithm. Defaults to "RS256".
"""

def __init__(self, jwks_url, algorithm="RS256"):
super(AsyncAsymmetricSignatureVerifier, self).__init__(jwks_url, algorithm)
self._fetcher = AsyncJwksFetcher(jwks_url)

def set_session(self, session):
"""Set Client Session to improve performance by reusing session.

Args:
session (aiohttp.ClientSession): The client session which should be closed
manually or within context manager.
"""
self._fetcher.set_session(session)

async def _fetch_key(self, key_id=None):
"""Request the JWKS.

Args:
key_id (str): The key's key id."""
return await self._fetcher.get_key(key_id)

async def verify_signature(self, token):
"""Verifies the signature of the given JSON web token.

Args:
token (str): The JWT to get its signature verified.

Raises:
TokenValidationError: if the token cannot be decoded, the algorithm is invalid
or the token's signature doesn't match the calculated one.
"""
kid = self._get_kid(token)
secret_or_certificate = await self._fetch_key(key_id=kid)

return self._decode_jwt(token, secret_or_certificate)


class AsyncJwksFetcher(JwksFetcher):
"""Class that async fetches and holds a JSON web key set.
This class makes use of an in-memory cache. For it to work properly, define this instance once and re-use it.

Args:
jwks_url (str): The url where the JWK set is located.
cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds.
"""

def __init__(self, *args, **kwargs):
super(AsyncJwksFetcher, self).__init__(*args, **kwargs)
self._async_client = AsyncRestClient(None)

def set_session(self, session):
"""Set Client Session to improve performance by reusing session.

Args:
session (aiohttp.ClientSession): The client session which should be closed
manually or within context manager.
"""
self._async_client.set_session(session)

async def _fetch_jwks(self, force=False):
"""Attempts to obtain the JWK set from the cache, as long as it's still valid.
When not, it will perform a network request to the jwks_url to obtain a fresh result
and update the cache value with it.

Args:
force (bool, optional): whether to ignore the cache and force a network request or not. Defaults to False.
"""
if force or self._cache_expired():
self._cache_value = {}
try:
jwks = await self._async_client.get(self._jwks_url)
self._cache_jwks(jwks)
except: # noqa: E722
return self._cache_value
return self._cache_value

self._cache_is_fresh = False
return self._cache_value

async def get_key(self, key_id):
"""Obtains the JWK associated with the given key id.

Args:
key_id (str): The id of the key to fetch.

Returns:
the JWK associated with the given key id.

Raises:
TokenValidationError: when a key with that id cannot be found
"""
keys = await self._fetch_jwks()

if keys and key_id in keys:
return keys[key_id]

if not self._cache_is_fresh:
keys = await self._fetch_jwks(force=True)
if keys and key_id in keys:
return keys[key_id]
raise TokenValidationError(
'RSA Public Key with ID "{}" was not found.'.format(key_id)
)


class AsyncTokenVerifier(TokenVerifier):
"""Class that verifies ID tokens following the steps defined in the OpenID Connect spec.
An OpenID Connect ID token is not meant to be consumed until it's verified.

Args:
signature_verifier (AsyncAsymmetricSignatureVerifier): The instance that knows how to verify the signature.
issuer (str): The expected issuer claim value.
audience (str): The expected audience claim value.
leeway (int, optional): The clock skew to accept when verifying date related claims in seconds.
Defaults to 60 seconds.
"""

def __init__(self, signature_verifier, issuer, audience, leeway=0):
if not signature_verifier or not isinstance(
signature_verifier, AsyncAsymmetricSignatureVerifier
):
raise TypeError(
"signature_verifier must be an instance of AsyncAsymmetricSignatureVerifier."
)

self.iss = issuer
self.aud = audience
self.leeway = leeway
self._sv = signature_verifier
self._clock = None # legacy testing requirement

def set_session(self, session):
"""Set Client Session to improve performance by reusing session.

Args:
session (aiohttp.ClientSession): The client session which should be closed
manually or within context manager.
"""
self._sv.set_session(session)

async def verify(self, token, nonce=None, max_age=None, organization=None):
"""Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec.

Args:
token (str): The JWT to verify.
nonce (str, optional): The nonce value sent during authentication.
max_age (int, optional): The max_age value sent during authentication.
organization (str, optional): The expected organization ID (org_id) claim value. This should be specified
when logging in to an organization.

Returns:
the decoded payload from the token

Raises:
TokenValidationError: when the token cannot be decoded, the token signing algorithm is not the expected one,
the token signature is invalid or the token has a claim missing or with unexpected value.
"""

# Verify token presence
if not token or not isinstance(token, str):
raise TokenValidationError("ID token is required but missing.")

# Verify algorithm and signature
payload = await self._sv.verify_signature(token)

# Verify claims
self._verify_payload(payload, nonce, max_age, organization)

return payload
78 changes: 58 additions & 20 deletions auth0/v3/authentication/token_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,18 @@ def _fetch_key(self, key_id=None):
"""
raise NotImplementedError

def verify_signature(self, token):
"""Verifies the signature of the given JSON web token.
def _get_kid(self, token):
"""Gets the key id from the kid claim of the header of the token

Args:
token (str): The JWT to get its signature verified.
token (str): The JWT to get the header from.

Raises:
TokenValidationError: if the token cannot be decoded, the algorithm is invalid
or the token's signature doesn't match the calculated one.

Returns:
the key id or None
"""
try:
header = jwt.get_unverified_header(token)
Expand All @@ -67,9 +70,19 @@ def verify_signature(self, token):
'to be signed with "{}"'.format(alg, self._algorithm)
)

kid = header.get("kid", None)
secret_or_certificate = self._fetch_key(key_id=kid)
return header.get("kid", None)

def _decode_jwt(self, token, secret_or_certificate):
"""Verifies and decodes the given JSON web token with the given public key or shared secret.

Args:
token (str): The JWT to get its signature verified.
secret_or_certificate (str): The public key or shared secret.

Raises:
TokenValidationError: if the token cannot be decoded, the algorithm is invalid
or the token's signature doesn't match the calculated one.
"""
try:
decoded = jwt.decode(
jwt=token,
Expand All @@ -81,6 +94,21 @@ def verify_signature(self, token):
raise TokenValidationError("Invalid token signature.")
return decoded

def verify_signature(self, token):
"""Verifies the signature of the given JSON web token.

Args:
token (str): The JWT to get its signature verified.

Raises:
TokenValidationError: if the token cannot be decoded, the algorithm is invalid
or the token's signature doesn't match the calculated one.
"""
kid = self._get_kid(token)
secret_or_certificate = self._fetch_key(key_id=kid)

return self._decode_jwt(token, secret_or_certificate)


class SymmetricSignatureVerifier(SignatureVerifier):
"""Verifier for HMAC signatures, which rely on shared secrets.
Expand Down Expand Up @@ -136,6 +164,24 @@ def _init_cache(self, cache_ttl):
self._cache_ttl = cache_ttl
self._cache_is_fresh = False

def _cache_expired(self):
"""Checks if the cache is expired

Returns:
True if it should use the cache.
"""
return self._cache_date + self._cache_ttl < time.time()

def _cache_jwks(self, jwks):
"""Cache the response of the JWKS request

Args:
jwks (dict): The JWKS
"""
self._cache_value = self._parse_jwks(jwks)
self._cache_is_fresh = True
self._cache_date = time.time()

def _fetch_jwks(self, force=False):
"""Attempts to obtain the JWK set from the cache, as long as it's still valid.
When not, it will perform a network request to the jwks_url to obtain a fresh result
Expand All @@ -144,23 +190,15 @@ def _fetch_jwks(self, force=False):
Args:
force (bool, optional): whether to ignore the cache and force a network request or not. Defaults to False.
"""
has_expired = self._cache_date + self._cache_ttl < time.time()

if not force and not has_expired:
# Return from cache
self._cache_is_fresh = False
if force or self._cache_expired():
self._cache_value = {}
response = requests.get(self._jwks_url)
if response.ok:
jwks = response.json()
self._cache_jwks(jwks)
return self._cache_value

# Invalidate cache and fetch fresh data
self._cache_value = {}
response = requests.get(self._jwks_url)

if response.ok:
# Update cache
jwks = response.json()
self._cache_value = self._parse_jwks(jwks)
self._cache_is_fresh = True
self._cache_date = time.time()
self._cache_is_fresh = False
return self._cache_value

@staticmethod
Expand Down
5 changes: 1 addition & 4 deletions auth0/v3/rest_async.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import asyncio
import json

import aiohttp

from auth0.v3.exceptions import RateLimitError

from .rest import EmptyResponse, JsonResponse, PlainResponse
from .rest import Response as _Response
from .rest import RestClient
from .rest import EmptyResponse, JsonResponse, PlainResponse, RestClient


def _clean_params(params):
Expand Down
Loading