Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Check digest function to prevent error on OTP Generation #170

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/pyotp/hotp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ def __init__(
"""
if digest is None:
digest = hashlib.sha1
elif digest in [
hashlib.md5,
hashlib.shake_128
]:
raise ValueError("selected digest function must generate digest size greater than or equals to 18 bytes")
Comment on lines +32 to +36

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I, for some reason (like auditing instrumentation), wrap md5 in a custom call, this won't catch it. Also, the list of overspecific; there might be more of such algorithms (in the future?), and we don't want to miss those.

It will thus be best to check the .digest_size attribute, as done in otp.py:generate_otp().


self.initial_count = initial_count
super().__init__(s=s, digits=digits, digest=digest, name=name, issuer=issuer)
Expand Down
7 changes: 7 additions & 0 deletions src/pyotp/otp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def __init__(
if digits > 10:
raise ValueError("digits must be no greater than 10")
self.digest = digest
if digest in [
hashlib.md5,
hashlib.shake_128
]:
raise ValueError("selected digest function must generate digest size greater than or equals to 18 bytes")
self.secret = s
self.name = name or "Secret"
self.issuer = issuer
Expand All @@ -33,6 +38,8 @@ def generate_otp(self, input: int) -> str:
if input < 0:
raise ValueError("input must be positive integer")
hasher = hmac.new(self.byte_secret(), self.int_to_bytestring(input), self.digest)
if hasher.digest_size < 18:
raise ValueError("digest size is lower than 18 bytes, which will trigger error on otp generation")
hmac_hash = bytearray(hasher.digest())
offset = hmac_hash[-1] & 0xF
code = (
Expand Down
5 changes: 5 additions & 0 deletions src/pyotp/totp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def __init__(
"""
if digest is None:
digest = hashlib.sha1
elif digest in [
hashlib.md5,
hashlib.shake_128
]:
raise ValueError("selected digest function must generate digest size greater than or equals to 18 bytes")
Comment on lines +35 to +39

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can go away if the check is one in otp.py:generate_otp(): It is not necessary, and there should only be one place where the relevant check is done.


self.interval = interval
super().__init__(s=s, digits=digits, digest=digest, name=name, issuer=issuer)
Expand Down
10 changes: 10 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,16 @@ def test_valid_window(self):
self.assertTrue(totp.verify("681610", 200, 1))
self.assertFalse(totp.verify("195979", 200, 1))

class DigestFunctionTest(unittest.TestCase):
def test_md5(self):
with self.assertRaises(ValueError) as cm:
pyotp.OTP(s="secret", digest=hashlib.md5)
self.assertEqual("selected digest function must generate digest size greater than or equals to 18 bytes", str(cm.exception))

def test_shake128(self):
with self.assertRaises(ValueError) as cm:
pyotp.OTP(s="secret", digest=hashlib.shake_128)
self.assertEqual("selected digest function must generate digest size greater than or equals to 18 bytes", str(cm.exception))

class ParseUriTest(unittest.TestCase):
def test_invalids(self):
Expand Down
Loading