diff --git a/tests/tlstest.py b/tests/tlstest.py index c257169c..9ca6dbe4 100755 --- a/tests/tlstest.py +++ b/tests/tlstest.py @@ -728,6 +728,38 @@ def connect(): connection.handshakeClientCert(settings=settings) assert connection.session.serverCertChain is None assert connection.ecdhCurve is not None + assert connection.session.cipherSuite in \ + constants.CipherSuite.sha384PrfSuites + testConnClient(connection) + connection.close() + + test_no += 1 + + print("Test {0} - good PSK SHA-256 PRF".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.pskConfigs = [(b'test', b'\x00secret', 'sha256')] + connection.handshakeClientCert(settings=settings) + assert connection.session.serverCertChain is None + assert connection.ecdhCurve is not None + assert connection.session.cipherSuite in \ + constants.CipherSuite.sha256PrfSuites + testConnClient(connection) + connection.close() + + test_no += 1 + + print("Test {0} - good PSK default PRF".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.pskConfigs = [(b'test', b'\x00secret', 'sha256')] + connection.handshakeClientCert(settings=settings) + assert connection.session.serverCertChain is None + assert connection.ecdhCurve is not None + assert connection.session.cipherSuite in \ + constants.CipherSuite.sha256PrfSuites testConnClient(connection) connection.close() @@ -761,6 +793,21 @@ def connect(): test_no += 1 + print("Test {0} - bad PSK X.509 fallback".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.pskConfigs = [(b'bad identity', b'\x00secret', 'sha256')] + connection.handshakeClientCert(settings=settings) + assert connection.session.serverCertChain + assert connection.ecdhCurve is not None + assert connection.session.cipherSuite in \ + constants.CipherSuite.sha384PrfSuites + testConnClient(connection) + connection.close() + + test_no += 1 + print("Test {0} - good SRP (db)".format(test_no)) print("client {0} - waiting for synchro".format(time.time())) try: @@ -2353,6 +2400,30 @@ def connect(): test_no += 1 + print("Test {0} - good PSK SHA-256 PRF".format(test_no)) + synchro.send(b'R') + settings = HandshakeSettings() + settings.pskConfigs = [(b'test', b'\x00secret', 'sha256')] + connection = connect() + connection.handshakeServer(certChain=x509Chain, privateKey=x509Key, + settings=settings) + testConnServer(connection) + connection.close() + + test_no += 1 + + print("Test {0} - good PSK default PRF".format(test_no)) + synchro.send(b'R') + settings = HandshakeSettings() + settings.pskConfigs = [(b'test', b'\x00secret')] + connection = connect() + connection.handshakeServer(certChain=x509Chain, privateKey=x509Key, + settings=settings) + testConnServer(connection) + connection.close() + + test_no += 1 + print("Test {0} - good PSK, no DH".format(test_no)) synchro.send(b'R') settings = HandshakeSettings() @@ -2378,6 +2449,18 @@ def connect(): test_no += 1 + print("Test {0} - bad PSK X.509 fallback".format(test_no)) + synchro.send(b'R') + settings = HandshakeSettings() + settings.pskConfigs = [(b'test', b'\x00secret', 'sha256')] + connection = connect() + connection.handshakeServer(certChain=x509Chain, privateKey=x509Key, + settings=settings) + testConnServer(connection) + connection.close() + + test_no += 1 + print("Test {0} - good SRP (db)".format(test_no)) try: import logging diff --git a/tlslite/constants.py b/tlslite/constants.py index f6b59a1a..74904c49 100644 --- a/tlslite/constants.py +++ b/tlslite/constants.py @@ -1205,7 +1205,7 @@ class CipherSuite: aeadSuites.extend(chacha20Suites) aeadSuites.extend(chacha20draft00Suites) - #: TLS1.2 with SHA384 PRF + #: any with SHA384 PRF sha384PrfSuites = [] sha384PrfSuites.extend(sha384Suites) sha384PrfSuites.extend(aes256GcmSuites) @@ -1227,6 +1227,12 @@ class CipherSuite: tls12Suites.extend(sha384Suites) tls12Suites.extend(aeadSuites) + #: any that will end up using SHA256 PRF in TLS 1.2 or later + sha256PrfSuites = [] + sha256PrfSuites.extend(tls12Suites) + for i in sha384PrfSuites: + sha256PrfSuites.remove(i) + #: TLS1.3 specific ciphersuites tls13Suites = [] @@ -1280,6 +1286,23 @@ def filter_for_certificate(suites, cert_chain): includeSuites.update(CipherSuite.ecdhAnonSuites) return [s for s in suites if s in includeSuites] + @staticmethod + def filter_for_prfs(suites, prfs): + """Return a copy of suites without ciphers incompatible with the + specified prfs (sha256 or sha384)""" + includeSuites = set() + prfs = set(prfs) + if None in prfs: + prfs.update(["sha256"]) + prfs.remove(None) + assert len(prfs) <= 2, prfs + + if "sha256" in prfs: + includeSuites.update(CipherSuite.sha256PrfSuites) + if "sha384" in prfs: + includeSuites.update(CipherSuite.sha384PrfSuites) + return [s for s in suites if s in includeSuites] + @staticmethod def _filterSuites(suites, settings, version=None): if version is None: diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py index 352a566d..c02c540b 100644 --- a/tlslite/handshakesettings.py +++ b/tlslite/handshakesettings.py @@ -298,7 +298,7 @@ class HandshakeSettings(object): (bytearray, can be empty for TLS 1.2 and earlier), second element is the binary secret (bytearray), third is an optional parameter specifying the PRF hash to be used in TLS 1.3 (``sha256`` or - ``sha384``) + ``sha384``, with ``sha256`` being the default) :vartype ticketKeys: list(bytearray) :ivar ticketKeys: keys to be used for encrypting and decrypting session diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 809401c8..270a0723 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -3960,6 +3960,8 @@ def _server_select_certificate(self, settings, client_hello, else: client_sigalgs = [] + client_psks = client_hello.getExtension(ExtensionType.pre_shared_key) + # Get all the certificates we can offer alt_certs = ((X509CertChain(i.certificates), i.key) for vh in settings.virtual_hosts for i in vh.keys) @@ -3967,7 +3969,6 @@ def _server_select_certificate(self, settings, client_hello, for cert, key in chain([(cert_chain, private_key)], alt_certs)] for cert, key in certs: - # Check if this is the last (cert, key) pair we have to check if (cert, key) == certs[-1]: last_cert = True @@ -3977,10 +3978,23 @@ def _server_select_certificate(self, settings, client_hello, try: # Find a suitable ciphersuite based on the certificate ciphers = CipherSuite.filter_for_certificate(cipher_suites, cert) + # but if we have matching PSKs, prefer those + if settings.pskConfigs and client_psks: + client_identities = [ + i.identity for i in client_psks.identities] + psks_prfs = [i[2] if len(i) == 3 else None for i in + settings.pskConfigs if + i[0] in client_identities] + if psks_prfs: + ciphers = CipherSuite.filter_for_prfs(ciphers, + psks_prfs) for cipher in ciphers: + # select first mutually supported if cipher in client_hello.cipher_suites: break else: + # abort with context-specific alert if client indicated + # support for FFDHE groups if client_groups and \ any(i in range(256, 512) for i in client_groups) and \ any(i in CipherSuite.dhAllSuites