diff --git a/src/saml2/aes.py b/src/saml2/aes.py index ee6a944a6..0aa9c35ec 100644 --- a/src/saml2/aes.py +++ b/src/saml2/aes.py @@ -11,36 +11,27 @@ POSTFIX_MODE = { 'cbc': modes.CBC, 'cfb': modes.CFB, - 'ecb': modes.ECB, } AES_BLOCK_SIZE = int(algorithms.AES.block_size / 8) class AESCipher(object): - def __init__(self, key, iv=None): + def __init__(self, key): """ :param key: The encryption key - :param iv: Init vector :return: AESCipher instance """ self.key = key - self.iv = iv - def build_cipher(self, iv=None, alg='aes_128_cbc'): + def build_cipher(self, alg='aes_128_cbc'): """ - :param iv: init vector :param alg: cipher algorithm :return: A Cipher instance """ typ, bits, cmode = alg.lower().split('_') bits = int(bits) - - if not iv: - if self.iv: - iv = self.iv - else: - iv = os.urandom(AES_BLOCK_SIZE) + iv = os.urandom(AES_BLOCK_SIZE) if len(iv) != AES_BLOCK_SIZE: raise Exception('Wrong iv size: {}'.format(len(iv))) @@ -63,11 +54,10 @@ def build_cipher(self, iv=None, alg='aes_128_cbc'): return cipher, iv - def encrypt(self, msg, iv=None, alg='aes_128_cbc', padding='PKCS#7', - b64enc=True, block_size=AES_BLOCK_SIZE): + def encrypt(self, msg, alg='aes_128_cbc', padding='PKCS#7', b64enc=True, + block_size=AES_BLOCK_SIZE): """ :param key: The encryption key - :param iv: init vector :param msg: Message to be encrypted :param padding: Which padding that should be used :param b64enc: Whether the result should be base64encoded @@ -87,7 +77,7 @@ def encrypt(self, msg, iv=None, alg='aes_128_cbc', padding='PKCS#7', c = chr(plen).encode() msg += c * plen - cipher, iv = self.build_cipher(iv, alg) + cipher, iv = self.build_cipher(alg) encryptor = cipher.encryptor() cmsg = iv + encryptor.update(msg) + encryptor.finalize() @@ -98,20 +88,15 @@ def encrypt(self, msg, iv=None, alg='aes_128_cbc', padding='PKCS#7', return enc_msg - def decrypt(self, msg, iv=None, alg='aes_128_cbc', padding='PKCS#7', - b64dec=True): + def decrypt(self, msg, alg='aes_128_cbc', padding='PKCS#7', b64dec=True): """ :param key: The encryption key - :param iv: init vector :param msg: Base64 encoded message to be decrypted :return: The decrypted message """ data = b64decode(msg) if b64dec else msg - _iv = data[:AES_BLOCK_SIZE] - if iv: - assert iv == _iv - cipher, iv = self.build_cipher(iv, alg=alg) + cipher, iv = self.build_cipher(alg=alg) decryptor = cipher.decryptor() res = decryptor.update(data)[AES_BLOCK_SIZE:] + decryptor.finalize() if padding in ['PKCS#5', 'PKCS#7']: @@ -122,20 +107,19 @@ def decrypt(self, msg, iv=None, alg='aes_128_cbc', padding='PKCS#7', def run_test(): key = b'1234523451234545' # 16 byte key - iv = os.urandom(AES_BLOCK_SIZE) # Iff padded, the message doesn't have to be multiple of 16 in length original_msg = b'ToBeOrNotTobe W.S.' aes = AESCipher(key) - encrypted_msg = aes.encrypt(original_msg, iv) - decrypted_msg = aes.decrypt(encrypted_msg, iv) + encrypted_msg = aes.encrypt(original_msg) + decrypted_msg = aes.decrypt(encrypted_msg) assert decrypted_msg == original_msg encrypted_msg = aes.encrypt(original_msg) decrypted_msg = aes.decrypt(encrypted_msg) assert decrypted_msg == original_msg - aes = AESCipher(key, iv) + aes = AESCipher(key) encrypted_msg = aes.encrypt(original_msg) decrypted_msg = aes.decrypt(encrypted_msg) assert decrypted_msg == original_msg diff --git a/src/saml2/authn.py b/src/saml2/authn.py index 32f91247e..049622e7c 100644 --- a/src/saml2/authn.py +++ b/src/saml2/authn.py @@ -120,7 +120,7 @@ def __init__(self, srv, mako_template, template_lookup, pwd, return_to): self.return_to = return_to self.active = {} self.query_param = "upm_answer" - self.aes = AESCipher(self.srv.symkey.encode(), srv.iv) + self.aes = AESCipher(self.srv.symkey.encode()) def __call__(self, cookie=None, policy_url=None, logo_url=None, query="", **kwargs): diff --git a/src/saml2/server.py b/src/saml2/server.py index 0e7e04033..0a2943f21 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -83,12 +83,9 @@ def __init__(self, config_file="", config=None, cache=None, stype="idp", self.init_config(stype) self.cache = cache self.ticket = {} - # self.session_db = self.choose_session_storage() - # Needed for self.symkey = symkey self.seed = rndstr() - self.iv = os.urandom(16) self.lock = threading.Lock() def getvalid_certificate_str(self):