diff --git a/.gitignore b/.gitignore index c09abac..8d946db 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.pyc *.pyo *.egg-info +*.coverage __pycache__ bin build diff --git a/pyproject.toml b/pyproject.toml index bd50594..4039df2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ version = {attr = "otpauth.__version__"} where = ["src"] [tool.pytest.ini_options] -pythonpath = ["src", "."] +pythonpath = ["src"] testpaths = ["tests"] filterwarnings = ["error"] diff --git a/src/otpauth/core.py b/src/otpauth/core.py index 1a8362a..904cb49 100644 --- a/src/otpauth/core.py +++ b/src/otpauth/core.py @@ -1,4 +1,5 @@ import base64 +import typing as t from urllib.parse import quote from abc import ABCMeta, abstractmethod @@ -16,28 +17,34 @@ def __init__(self, secret: bytes, digit: int = 6, algorithm: str = "SHA1"): self._b32_secret = None @property - def b32_secret(self) -> bytes: + def b32_secret(self) -> str: if self._b32_secret: return self._b32_secret secret = base64.b32encode(self.secret) - self._b32_secret = secret.rstrip(b'=') + self._b32_secret = secret.rstrip(b'=').decode("ascii") return self._b32_secret @classmethod - def from_b32encode(cls, secret: bytes): + def from_b32encode(cls, secret: t.AnyStr): + if isinstance(secret, str): + secret = secret.encode("utf-8") + b32_secret = secret.rstrip(b'=') # add padding back secret += b'=' * (-len(secret) % 8) raw_secret = base64.b32decode(secret) + obj = cls(raw_secret) - obj._b32_secret = b32_secret + obj._b32_secret = b32_secret.decode("ascii") return obj def _get_base_uri(self, label: str, issuer: str) -> str: - label = quote(label, safe="/@") - return f"otpauth://hotp/{label}?secret={self.b32_secret}&issuer={issuer}&algorithm={self.algorithm}&digits={self.digit}" + label = quote(label, safe="/@:") + issuer = quote(issuer, safe="") + _type = self.TYPE.lower() + return f"otpauth://{_type}/{label}?secret={self.b32_secret}&issuer={issuer}&algorithm={self.algorithm}&digits={self.digit}" @abstractmethod def generate(self, *args, **kwargs) -> int: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_hotp.py b/tests/test_hotp.py new file mode 100644 index 0000000..2c0cf2c --- /dev/null +++ b/tests/test_hotp.py @@ -0,0 +1,46 @@ +import unittest +from otpauth import HOTP + + +class TestHOTP(unittest.TestCase): + def setUp(self): + self.hotp = HOTP(b"python") + + def test_generate(self): + value = self.hotp.generate(0) + self.assertEqual(value, 170566) + + def test_verify(self): + # due to number too long + self.assertFalse(self.hotp.verify(12345678, 0)) + + # due to not match + self.assertFalse(self.hotp.verify(12345, 0)) + + self.assertTrue(self.hotp.verify(170566, 0)) + + def test_to_uri(self): + uri = self.hotp.to_uri("Typlog:lepture.com", "Authlib", 0) + expected = "otpauth://hotp/Typlog:lepture.com?secret=OB4XI2DPNY&issuer=Authlib&algorithm=SHA1&digits=6&counter=0" + self.assertEqual(uri, expected) + + def test_from_b32encode(self): + expected = "otpauth://hotp/Typlog:lepture.com?secret=OB4XI2DPNY&issuer=Authlib&algorithm=SHA1&digits=6&counter=0" + + hotp = HOTP.from_b32encode("OB4XI2DPNY") + value = hotp.generate(0) + self.assertEqual(value, 170566) + uri = self.hotp.to_uri("Typlog:lepture.com", "Authlib", 0) + self.assertEqual(uri, expected) + + hotp = HOTP.from_b32encode("OB4XI2DPNY======") + value = hotp.generate(0) + self.assertEqual(value, 170566) + uri = self.hotp.to_uri("Typlog:lepture.com", "Authlib", 0) + self.assertEqual(uri, expected) + + hotp = HOTP.from_b32encode(b"OB4XI2DPNY======") + value = hotp.generate(0) + self.assertEqual(value, 170566) + uri = self.hotp.to_uri("Typlog:lepture.com", "Authlib", 0) + self.assertEqual(uri, expected) diff --git a/tests/test_totp.py b/tests/test_totp.py new file mode 100644 index 0000000..7224cb4 --- /dev/null +++ b/tests/test_totp.py @@ -0,0 +1,32 @@ +import unittest +import time +from otpauth import TOTP + +FIXED_TIME = 1679576495 + + +class TestTOTP(unittest.TestCase): + def setUp(self): + self.totp = TOTP(b"python") + + def test_generate(self): + value = self.totp.generate(FIXED_TIME) + self.assertEqual(value, 129815) + + def test_verify(self): + # due to number too long + self.assertFalse(self.totp.verify(12345678, FIXED_TIME)) + + # due to not match + self.assertFalse(self.totp.verify(12345, FIXED_TIME)) + + self.assertTrue(self.totp.verify(129815, FIXED_TIME)) + + def test_to_uri(self): + uri = self.totp.to_uri("Typlog:lepture.com", "Authlib") + expected = "otpauth://totp/Typlog:lepture.com?secret=OB4XI2DPNY&issuer=Authlib&algorithm=SHA1&digits=6&period=30" + self.assertEqual(uri, expected) + + def test_current_timestamp(self): + value = self.totp.generate() + self.assertTrue(self.totp.verify(value))