Skip to content

Commit

Permalink
support for SEC/X9.62 formatted keys
Browse files Browse the repository at this point in the history
Adds support for encoding and decoding verifying keys in format
specified in SEC 1 or in X9.62. Specifically the uncompressed point
encoding and the compressed point encoding
  • Loading branch information
tomato42 committed Oct 1, 2019
1 parent bcf6afe commit d47a238
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 19 deletions.
82 changes: 73 additions & 9 deletions src/ecdsa/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from . import ecdsa
from . import der
from . import rfc6979
from . import ellipticcurve
from .curves import NIST192p, find_curve
from .numbertheory import square_root_mod_prime, SquareRootError
from .ecdsa import RSZeroError
from .util import string_to_number, number_to_string, randrange
from .util import sigencode_string, sigdecode_string
Expand All @@ -23,6 +25,10 @@ class BadDigestError(Exception):
pass


class MalformedPointError(AssertionError):
pass


class VerifyingKey:
def __init__(self, _error__please_use_generate=None):
if not _error__please_use_generate:
Expand All @@ -38,9 +44,8 @@ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1):
self.pubkey.order = curve.order
return self

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
validate_point=True):
@staticmethod
def _from_raw_encoding(string, curve, validate_point):
order = curve.order
assert (len(string) == curve.verifying_key_length), \
(len(string), curve.verifying_key_length)
Expand All @@ -52,8 +57,50 @@ def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
y = string_to_number(ys)
if validate_point:
assert ecdsa.point_is_valid(curve.generator, x, y)
from . import ellipticcurve
point = ellipticcurve.Point(curve.curve, x, y, order)
return ellipticcurve.Point(curve.curve, x, y, order)

@staticmethod
def _from_compressed(string, curve, validate_point):
if string[:1] not in (b('\x02'), b('\x03')):
raise MalformedPointError("Malformed compressed point encoding")

is_even = string[:1] == b('\x02')
x = string_to_number(string[1:])
order = curve.order
p = curve.curve.p()
alpha = (pow(x, 3, p) + (curve.curve.a() * x) + curve.curve.b()) % p
try:
beta = square_root_mod_prime(alpha, p)
except SquareRootError as e:
raise MalformedPointError(
"Encoding does not correspond to a point on curve", e)
if is_even == bool(beta & 1):
y = p - beta
else:
y = beta
if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
raise MalformedPointError("Point does not lie on curve")
return ellipticcurve.Point(curve.curve, x, y, order)

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
validate_point=True):
sig_len = len(string)
if sig_len == curve.verifying_key_length:
point = klass._from_raw_encoding(string, curve, validate_point)
elif sig_len == curve.verifying_key_length + 1:
if string[:1] != b('\x04'):
raise MalformedPointError(
"Invalid uncompressed encoding of the public point")
point = klass._from_raw_encoding(string[1:], curve, validate_point)
elif sig_len == curve.baselen + 1:
point = klass._from_compressed(string, curve, validate_point)
else:
raise MalformedPointError(
"Length of string does not match lengths of "
"any of the supported encodings of {0} "
"curve.".format(curve.name))

return klass.from_public_point(point, curve, hashfunc)

@classmethod
Expand Down Expand Up @@ -110,15 +157,32 @@ def from_public_key_recovery_with_digest(klass, signature, digest, curve, hashfu
verifying_keys = [klass.from_public_point(pk.point, curve, hashfunc) for pk in pks]
return verifying_keys

def to_string(self):
# VerifyingKey.from_string(vk.to_string()) == vk as long as the
# curves are the same: the curve itself is not included in the
# serialized form
def _raw_encode(self):
order = self.pubkey.order
x_str = number_to_string(self.pubkey.point.x(), order)
y_str = number_to_string(self.pubkey.point.y(), order)
return x_str + y_str

def _compressed_encode(self):
order = self.pubkey.order
x_str = number_to_string(self.pubkey.point.x(), order)
if self.pubkey.point.y() & 1:
return b('\x03') + x_str
else:
return b('\x02') + x_str

def to_string(self, encoding="raw"):
# VerifyingKey.from_string(vk.to_string()) == vk as long as the
# curves are the same: the curve itself is not included in the
# serialized form
assert encoding in ("raw", "uncompressed", "compressed")
if encoding == "raw":
return self._raw_encode()
elif encoding == "uncompressed":
return b('\x04') + self._raw_encode()
else:
return self._compressed_encode()

def to_pem(self):
return der.topem(self.to_der(), "PUBLIC KEY")

Expand Down
23 changes: 16 additions & 7 deletions src/ecdsa/numbertheory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@

from __future__ import division

from six import integer_types
from six import integer_types, PY3
from six.moves import reduce
try:
xrange
except NameError:
xrange = range

import math

Expand Down Expand Up @@ -62,7 +66,7 @@ def polynomial_reduce_mod(poly, polymod, p):

while len(poly) >= len(polymod):
if poly[-1] != 0:
for i in range(2, len(polymod) + 1):
for i in xrange(2, len(polymod) + 1):
poly[-i] = (poly[-i] - poly[-1] * polymod[-i]) % p
poly = poly[0:-1]

Expand All @@ -86,8 +90,8 @@ def polynomial_multiply_mod(m1, m2, polymod, p):

# Add together all the cross-terms:

for i in range(len(m1)):
for j in range(len(m2)):
for i in xrange(len(m1)):
for j in xrange(len(m2)):
prod[i + j] = (prod[i + j] + m1[i] * m2[j]) % p

return polynomial_reduce_mod(prod, polymod, p)
Expand Down Expand Up @@ -187,7 +191,12 @@ def square_root_mod_prime(a, p):
return (2 * a * modular_exp(4 * a, (p - 5) // 8, p)) % p
raise RuntimeError("Shouldn't get here.")

for b in range(2, p):
if PY3:
range_top = p
else:
# xrange on python2 can take integers representable as C long only
range_top = min(0x7fffffff, p)
for b in xrange(2, range_top):
if jacobi(b * b - 4 * a, p) == -1:
f = (a, -b, 1)
ff = polynomial_exp_mod((0, 1), (p + 1) // 2, f, p)
Expand Down Expand Up @@ -355,7 +364,7 @@ def carmichael_of_factorized(f_list):
return 1

result = carmichael_of_ppower(f_list[0])
for i in range(1, len(f_list)):
for i in xrange(1, len(f_list)):
result = lcm(result, carmichael_of_ppower(f_list[i]))

return result
Expand Down Expand Up @@ -477,7 +486,7 @@ def is_prime(n):
while (r % 2) == 0:
s = s + 1
r = r // 2
for i in range(t):
for i in xrange(t):
a = smallprimes[i]
y = modular_exp(a, r, n)
if y != 1 and y != n - 1:
Expand Down
104 changes: 101 additions & 3 deletions src/ecdsa/test_pyecdsa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import with_statement, division

import unittest
try:
import unittest2 as unittest
except ImportError:
import unittest
import os
import time
import shutil
Expand All @@ -11,12 +14,14 @@

from six import b, print_, binary_type
from .keys import SigningKey, VerifyingKey
from .keys import BadSignatureError
from .keys import BadSignatureError, MalformedPointError
from . import util
from .util import sigencode_der, sigencode_strings
from .util import sigdecode_der, sigdecode_strings
from .util import number_to_string
from .curves import Curve, UnknownCurveError
from .curves import NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1
from .curves import NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, \
SECP256k1, curves
from .ellipticcurve import Point
from . import der
from . import rfc6979
Expand Down Expand Up @@ -367,6 +372,99 @@ def test_public_key_recovery_with_custom_hash(self):
self.assertTrue(vk.pubkey.point in
[recovered_vk.pubkey.point for recovered_vk in recovered_vks])

def test_encoding(self):
sk = SigningKey.from_secret_exponent(123456789)
vk = sk.verifying_key

exp = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')
self.assertEqual(vk.to_string(), exp)
self.assertEqual(vk.to_string('uncompressed'), b('\x04') + exp)
self.assertEqual(vk.to_string('compressed'), b('\x02') + exp[:24])

def test_decoding(self):
sk = SigningKey.from_secret_exponent(123456789)
vk = sk.verifying_key

enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

from_raw = VerifyingKey.from_string(enc)
self.assertEqual(from_raw.pubkey.point, vk.pubkey.point)

from_uncompressed = VerifyingKey.from_string(b('\x04') + enc)
self.assertEqual(from_uncompressed.pubkey.point, vk.pubkey.point)

from_compressed = VerifyingKey.from_string(b('\x02') + enc[:24])
self.assertEqual(from_compressed.pubkey.point, vk.pubkey.point)

def test_decoding_with_malformed_uncompressed(self):
enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x02') + enc)

def test_decoding_with_malformed_compressed(self):
enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x01') + enc[:24])

def test_decoding_with_point_at_infinity(self):
# decoding it is unsupported, as it's not necessary to encode it
with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x00'))

def test_not_lying_on_curve(self):
enc = number_to_string(NIST192p.order, NIST192p.order+1)

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x02') + enc)


@pytest.mark.parametrize("val,even",
[(i, j) for i in range(256) for j in [True, False]])
def test_VerifyingKey_decode_with_small_values(val, even):
enc = number_to_string(val, NIST192p.order)

if even:
enc = b('\x02') + enc
else:
enc = b('\x03') + enc

# small values can both be actual valid public keys and not, verify that
# only expected exceptions are raised if they are not
try:
vk = VerifyingKey.from_string(enc)
assert isinstance(vk, VerifyingKey)
except MalformedPointError:
assert True


params = []
for curve in curves:
for enc in ["raw", "uncompressed", "compressed"]:
params.append(pytest.param(curve, enc, id="{0}-{1}".format(
curve.name, enc)))


@pytest.mark.parametrize("curve,encoding", params)
def test_VerifyingKey_encode_decode(curve, encoding):
sk = SigningKey.generate(curve=curve)
vk = sk.verifying_key

encoded = vk.to_string(encoding)

from_enc = VerifyingKey.from_string(encoded, curve=curve)

assert vk.pubkey.point == from_enc.pubkey.point


class OpenSSL(unittest.TestCase):
# test interoperability with OpenSSL tools. Note that openssl's ECDSA
Expand Down

0 comments on commit d47a238

Please # to comment.