diff --git a/jose/backends/cryptography_backend.py b/jose/backends/cryptography_backend.py index 945349b..291bc8b 100644 --- a/jose/backends/cryptography_backend.py +++ b/jose/backends/cryptography_backend.py @@ -17,6 +17,7 @@ from ..constants import ALGORITHMS from ..exceptions import JWEError, JWKError from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64 +from ..utils import is_pem_format, is_ssh_key from .base import Key _binding = None @@ -555,14 +556,7 @@ def __init__(self, key, algorithm): if isinstance(key, str): key = key.encode("utf-8") - invalid_strings = [ - b"-----BEGIN PUBLIC KEY-----", - b"-----BEGIN RSA PUBLIC KEY-----", - b"-----BEGIN CERTIFICATE-----", - b"ssh-rsa", - ] - - if any(string_value in key for string_value in invalid_strings): + if is_pem_format(key) or is_ssh_key(key): raise JWKError( "The specified key is an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." diff --git a/jose/backends/native.py b/jose/backends/native.py index eb3a6ae..f54d739 100644 --- a/jose/backends/native.py +++ b/jose/backends/native.py @@ -6,6 +6,7 @@ from jose.constants import ALGORITHMS from jose.exceptions import JWKError from jose.utils import base64url_decode, base64url_encode +from jose.utils import is_pem_format, is_ssh_key def get_random_bytes(num_bytes): @@ -36,14 +37,7 @@ def __init__(self, key, algorithm): if isinstance(key, str): key = key.encode("utf-8") - invalid_strings = [ - b"-----BEGIN PUBLIC KEY-----", - b"-----BEGIN RSA PUBLIC KEY-----", - b"-----BEGIN CERTIFICATE-----", - b"ssh-rsa", - ] - - if any(string_value in key for string_value in invalid_strings): + if is_pem_format(key) or is_ssh_key(key): raise JWKError( "The specified key is an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." diff --git a/jose/jwt.py b/jose/jwt.py index b364b4b..ae7e5b3 100644 --- a/jose/jwt.py +++ b/jose/jwt.py @@ -141,6 +141,14 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None verify_signature = defaults.get("verify_signature", True) + # Forbid the usage of the jwt.decode without alogrightms parameter + # See https://github.com/mpdavis/python-jose/issues/346 for more + # information CVE-2024-33663 + if verify_signature and algorithms is None: + raise JWTError("It is required that you pass in a value for " + 'the "algorithms" argument when calling ' + "decode().") + try: payload = jws.verify(token, key, algorithms, verify=verify_signature) except JWSError as e: diff --git a/jose/utils.py b/jose/utils.py index d04c4ac..3114e45 100644 --- a/jose/utils.py +++ b/jose/utils.py @@ -1,3 +1,4 @@ +import re import base64 import struct @@ -105,3 +106,75 @@ def ensure_binary(s): if isinstance(s, str): return s.encode("utf-8", "strict") raise TypeError(f"not expecting type '{type(s)}'") + + +# Based on https://github.com/jpadilla/pyjwt/commit/9c528670c455b8d948aff95ed50e22940d1ad3fc +# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252 +_PEMS = { + b"CERTIFICATE", + b"TRUSTED CERTIFICATE", + b"PRIVATE KEY", + b"PUBLIC KEY", + b"ENCRYPTED PRIVATE KEY", + b"OPENSSH PRIVATE KEY", + b"DSA PRIVATE KEY", + b"RSA PRIVATE KEY", + b"RSA PUBLIC KEY", + b"EC PRIVATE KEY", + b"DH PARAMETERS", + b"NEW CERTIFICATE REQUEST", + b"CERTIFICATE REQUEST", + b"SSH2 PUBLIC KEY", + b"SSH2 ENCRYPTED PRIVATE KEY", + b"X509 CRL", +} + + +_PEM_RE = re.compile( + b"----[- ]BEGIN (" + + b"|".join(_PEMS) + + b""")[- ]----\r? +.+?\r? +----[- ]END \\1[- ]----\r?\n?""", + re.DOTALL, +) + + +def is_pem_format(key): + """ + Return True if the key is PEM format + This function uses the list of valid PEM headers defined in + _PEMS dict. + """ + return bool(_PEM_RE.search(key)) + + +# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 +_CERT_SUFFIX = b"-cert-v01@openssh.com" +_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)") +_SSH_KEY_FORMATS = [ + b"ssh-ed25519", + b"ssh-rsa", + b"ssh-dss", + b"ecdsa-sha2-nistp256", + b"ecdsa-sha2-nistp384", + b"ecdsa-sha2-nistp521", +] + + +def is_ssh_key(key): + """ + Return True if the key is a SSH key + This function uses the list of valid SSH key format defined in + _SSH_KEY_FORMATS dict. + """ + if any(string_value in key for string_value in _SSH_KEY_FORMATS): + return True + + ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) + if ssh_pubkey_match: + key_type = ssh_pubkey_match.group(1) + if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: + return True + + return False diff --git a/tests/algorithms/test_HMAC.py b/tests/algorithms/test_HMAC.py index 2b0859e..15c1cb7 100644 --- a/tests/algorithms/test_HMAC.py +++ b/tests/algorithms/test_HMAC.py @@ -14,14 +14,17 @@ def test_non_string_key(self): def test_RSA_key(self): key = "-----BEGIN PUBLIC KEY-----" + key += "\n\n\n-----END PUBLIC KEY-----" with pytest.raises(JOSEError): HMACKey(key, ALGORITHMS.HS256) key = "-----BEGIN RSA PUBLIC KEY-----" + key += "\n\n\n-----END RSA PUBLIC KEY-----" with pytest.raises(JOSEError): HMACKey(key, ALGORITHMS.HS256) key = "-----BEGIN CERTIFICATE-----" + key += "\n\n\n-----END CERTIFICATE-----" with pytest.raises(JOSEError): HMACKey(key, ALGORITHMS.HS256) diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 8c2e262..d1acbc1 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -5,7 +5,8 @@ import pytest from jose import jws, jwt -from jose.exceptions import JWTError +from jose.constants import ALGORITHMS +from jose.exceptions import JWTError, JWKError @pytest.fixture @@ -56,7 +57,7 @@ def test_no_alg(self, claims, key): ], ) def test_numeric_key(self, key, token): - token_info = jwt.decode(token, key) + token_info = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED) assert token_info == {"name": "test"} def test_invalid_claims_json(self): @@ -108,7 +109,7 @@ def test_no_alg_default_headers(self, claims, key, headers): def test_non_default_headers(self, claims, key, headers): encoded = jwt.encode(claims, key, headers=headers) - decoded = jwt.decode(encoded, key) + decoded = jwt.decode(encoded, key, algorithms=ALGORITHMS.HS256) assert claims == decoded all_headers = jwt.get_unverified_headers(encoded) for k, v in headers.items(): @@ -159,7 +160,7 @@ def test_encode(self, claims, key): def test_decode(self, claims, key): token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" ".eyJhIjoiYiJ9" ".jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8" - decoded = jwt.decode(token, key) + decoded = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED) assert decoded == claims @@ -190,7 +191,7 @@ def test_leeway_is_timedelta(self, claims, key): options = {"leeway": leeway} token = jwt.encode(claims, key) - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_iat_not_int(self, key): claims = {"iat": "test"} @@ -198,7 +199,7 @@ def test_iat_not_int(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_nbf_not_int(self, key): claims = {"nbf": "test"} @@ -206,7 +207,7 @@ def test_nbf_not_int(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_nbf_datetime(self, key): nbf = datetime.utcnow() - timedelta(seconds=5) @@ -214,7 +215,7 @@ def test_nbf_datetime(self, key): claims = {"nbf": nbf} token = jwt.encode(claims, key) - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_nbf_with_leeway(self, key): nbf = datetime.utcnow() + timedelta(seconds=5) @@ -226,7 +227,7 @@ def test_nbf_with_leeway(self, key): options = {"leeway": 10} token = jwt.encode(claims, key) - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_nbf_in_future(self, key): nbf = datetime.utcnow() + timedelta(seconds=5) @@ -236,7 +237,7 @@ def test_nbf_in_future(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_nbf_skip(self, key): nbf = datetime.utcnow() + timedelta(seconds=5) @@ -246,11 +247,11 @@ def test_nbf_skip(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) options = {"verify_nbf": False} - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_exp_not_int(self, key): claims = {"exp": "test"} @@ -258,7 +259,7 @@ def test_exp_not_int(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_exp_datetime(self, key): exp = datetime.utcnow() + timedelta(seconds=5) @@ -266,7 +267,7 @@ def test_exp_datetime(self, key): claims = {"exp": exp} token = jwt.encode(claims, key) - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_exp_with_leeway(self, key): exp = datetime.utcnow() - timedelta(seconds=5) @@ -278,7 +279,7 @@ def test_exp_with_leeway(self, key): options = {"leeway": 10} token = jwt.encode(claims, key) - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_exp_in_past(self, key): exp = datetime.utcnow() - timedelta(seconds=5) @@ -288,7 +289,7 @@ def test_exp_in_past(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_exp_skip(self, key): exp = datetime.utcnow() - timedelta(seconds=5) @@ -298,11 +299,11 @@ def test_exp_skip(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) options = {"verify_exp": False} - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_aud_string(self, key): aud = "audience" @@ -310,7 +311,7 @@ def test_aud_string(self, key): claims = {"aud": aud} token = jwt.encode(claims, key) - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_list(self, key): aud = "audience" @@ -318,7 +319,7 @@ def test_aud_list(self, key): claims = {"aud": [aud]} token = jwt.encode(claims, key) - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_list_multiple(self, key): aud = "audience" @@ -326,7 +327,7 @@ def test_aud_list_multiple(self, key): claims = {"aud": [aud, "another"]} token = jwt.encode(claims, key) - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_list_is_strings(self, key): aud = "audience" @@ -335,7 +336,7 @@ def test_aud_list_is_strings(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_case_sensitive(self, key): aud = "audience" @@ -344,13 +345,13 @@ def test_aud_case_sensitive(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, audience="AUDIENCE") + jwt.decode(token, key, audience="AUDIENCE", algorithms=ALGORITHMS.HS256) def test_aud_empty_claim(self, claims, key): aud = "audience" token = jwt.encode(claims, key) - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_not_string_or_list(self, key): aud = 1 @@ -359,7 +360,7 @@ def test_aud_not_string_or_list(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_aud_given_number(self, key): aud = "audience" @@ -368,7 +369,7 @@ def test_aud_given_number(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, audience=1) + jwt.decode(token, key, audience=1, algorithms=ALGORITHMS.HS256) def test_iss_string(self, key): iss = "issuer" @@ -376,7 +377,7 @@ def test_iss_string(self, key): claims = {"iss": iss} token = jwt.encode(claims, key) - jwt.decode(token, key, issuer=iss) + jwt.decode(token, key, issuer=iss, algorithms=ALGORITHMS.HS256) def test_iss_list(self, key): iss = "issuer" @@ -384,7 +385,7 @@ def test_iss_list(self, key): claims = {"iss": iss} token = jwt.encode(claims, key) - jwt.decode(token, key, issuer=["https://issuer", "issuer"]) + jwt.decode(token, key, issuer=["https://issuer", "issuer"], algorithms=ALGORITHMS.HS256) def test_iss_tuple(self, key): iss = "issuer" @@ -392,7 +393,7 @@ def test_iss_tuple(self, key): claims = {"iss": iss} token = jwt.encode(claims, key) - jwt.decode(token, key, issuer=("https://issuer", "issuer")) + jwt.decode(token, key, issuer=("https://issuer", "issuer"), algorithms=ALGORITHMS.HS256) def test_iss_invalid(self, key): iss = "issuer" @@ -401,7 +402,7 @@ def test_iss_invalid(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, issuer="another") + jwt.decode(token, key, issuer="another", algorithms=ALGORITHMS.HS256) def test_sub_string(self, key): sub = "subject" @@ -409,7 +410,7 @@ def test_sub_string(self, key): claims = {"sub": sub} token = jwt.encode(claims, key) - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_sub_invalid(self, key): sub = 1 @@ -418,7 +419,7 @@ def test_sub_invalid(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_sub_correct(self, key): sub = "subject" @@ -426,7 +427,7 @@ def test_sub_correct(self, key): claims = {"sub": sub} token = jwt.encode(claims, key) - jwt.decode(token, key, subject=sub) + jwt.decode(token, key, subject=sub, algorithms=ALGORITHMS.HS256) def test_sub_incorrect(self, key): sub = "subject" @@ -435,7 +436,7 @@ def test_sub_incorrect(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, subject="another") + jwt.decode(token, key, subject="another", algorithms=ALGORITHMS.HS256) def test_jti_string(self, key): jti = "JWT ID" @@ -443,7 +444,7 @@ def test_jti_string(self, key): claims = {"jti": jti} token = jwt.encode(claims, key) - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_jti_invalid(self, key): jti = 1 @@ -452,33 +453,33 @@ def test_jti_invalid(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_at_hash(self, claims, key): access_token = "" token = jwt.encode(claims, key, access_token=access_token) - payload = jwt.decode(token, key, access_token=access_token) + payload = jwt.decode(token, key, access_token=access_token, algorithms=ALGORITHMS.HS256) assert "at_hash" in payload def test_at_hash_invalid(self, claims, key): token = jwt.encode(claims, key, access_token="") with pytest.raises(JWTError): - jwt.decode(token, key, access_token="") + jwt.decode(token, key, access_token="", algorithms=ALGORITHMS.HS256) def test_at_hash_missing_access_token(self, claims, key): token = jwt.encode(claims, key, access_token="") with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_at_hash_missing_claim(self, claims, key): token = jwt.encode(claims, key) - payload = jwt.decode(token, key, access_token="") + payload = jwt.decode(token, key, access_token="", algorithms=ALGORITHMS.HS256) assert "at_hash" not in payload def test_at_hash_unable_to_calculate(self, claims, key): token = jwt.encode(claims, key, access_token="") with pytest.raises(JWTError): - jwt.decode(token, key, access_token="\xe2") + jwt.decode(token, key, access_token="\xe2", algorithms=ALGORITHMS.HS256) def test_bad_claims(self): bad_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck" @@ -516,9 +517,45 @@ def test_require(self, claims, key, claim, value): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, options=options, audience=str(value)) + jwt.decode(token, key, options=options, audience=str(value), algorithms=ALGORITHMS.HS256) new_claims = dict(claims) new_claims[claim] = value token = jwt.encode(new_claims, key) - jwt.decode(token, key, options=options, audience=str(value)) + jwt.decode(token, key, options=options, audience=str(value), algorithms=ALGORITHMS.HS256) + + def test_CVE_2024_33663(self): + """Test based on https://github.com/mpdavis/python-jose/issues/346""" + from Crypto.PublicKey import ECC + from Crypto.Hash import HMAC, SHA256 + + # ----- SETUP ----- + # generate an asymmetric ECC keypair + # !! signing should only be possible with the private key !! + KEY = ECC.generate(curve='P-256') + + # PUBLIC KEY, AVAILABLE TO USER + # CAN BE RECOVERED THROUGH E.G. PUBKEY RECOVERY WITH TWO SIGNATURES: + # https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm#Public_key_recovery + # https://github.com/FlorianPicca/JWT-Key-Recovery + PUBKEY = KEY.public_key().export_key(format='OpenSSH').encode() + + # ---- CLIENT SIDE ----- + # without knowing the private key, a valid token can be constructed + # YIKES!! + + b64 = lambda x:base64.urlsafe_b64encode(x).replace(b'=',b'') + payload = b64(b'{"alg":"HS256"}') + b'.' + b64(b'{"pwned":true}') + hasher = HMAC.new(PUBKEY, digestmod=SHA256) + hasher.update(payload) + evil_token = payload + b'.' + b64(hasher.digest()) + + # ---- SERVER SIDE ----- + # verify and decode the token using the public key, as is custom + # algorithm field is left unspecified + # but the library will happily still verify without warning, trusting the user-controlled alg field of the token header + with pytest.raises(JWKError): + data = jwt.decode(evil_token, PUBKEY, algorithms=ALGORITHMS.HS256) + + with pytest.raises(JWTError, match='.*required.*"algorithms".*'): + data = jwt.decode(evil_token, PUBKEY)