49
49
class AbstractKey :
50
50
"""Abstract superclass for private and public keys."""
51
51
52
- __slots__ = ('n' , 'e' )
52
+ __slots__ = ('n' , 'e' , 'blindfac' , 'blindfac_inverse' )
53
53
54
54
def __init__ (self , n : int , e : int ) -> None :
55
55
self .n = n
56
56
self .e = e
57
57
58
+ # These will be computed properly on the first call to blind().
59
+ self .blindfac = self .blindfac_inverse = - 1
60
+
58
61
@classmethod
59
62
def _load_pkcs1_pem (cls , keyfile : bytes ) -> 'AbstractKey' :
60
63
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
@@ -145,7 +148,7 @@ def save_pkcs1(self, format: str = 'PEM') -> bytes:
145
148
method = self ._assert_format_exists (format , methods )
146
149
return method ()
147
150
148
- def blind (self , message : int , r : int ) -> int :
151
+ def blind (self , message : int ) -> int :
149
152
"""Performs blinding on the message using random number 'r'.
150
153
151
154
:param message: the message, as integer, to blind.
@@ -159,10 +162,10 @@ def blind(self, message: int, r: int) -> int:
159
162
160
163
See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
161
164
"""
165
+ self ._update_blinding_factor ()
166
+ return (message * pow (self .blindfac , self .e , self .n )) % self .n
162
167
163
- return (message * pow (r , self .e , self .n )) % self .n
164
-
165
- def unblind (self , blinded : int , r : int ) -> int :
168
+ def unblind (self , blinded : int ) -> int :
166
169
"""Performs blinding on the message using random number 'r'.
167
170
168
171
:param blinded: the blinded message, as integer, to unblind.
@@ -174,8 +177,27 @@ def unblind(self, blinded: int, r: int) -> int:
174
177
See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
175
178
"""
176
179
177
- return (rsa . common . inverse ( r , self .n ) * blinded ) % self .n
180
+ return (self .blindfac_inverse * blinded ) % self .n
178
181
182
+ def _initial_blinding_factor (self ) -> int :
183
+ for _ in range (1000 ):
184
+ blind_r = rsa .randnum .randint (self .n - 1 )
185
+ if rsa .prime .are_relatively_prime (self .n , blind_r ):
186
+ return blind_r
187
+ raise RuntimeError ('unable to find blinding factor' )
188
+
189
+ def _update_blinding_factor (self ):
190
+ if self .blindfac < 0 :
191
+ # Compute initial blinding factor, which is rather slow to do.
192
+ self .blindfac = self ._initial_blinding_factor ()
193
+ self .blindfac_inverse = rsa .common .inverse (self .blindfac , self .n )
194
+ else :
195
+ # Reuse previous blinding factor as per section 9 of 'A Timing
196
+ # Attack against RSA with the Chinese Remainder Theorem' by Werner
197
+ # Schindler.
198
+ # See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf
199
+ self .blindfac = pow (self .blindfac , 2 , self .n )
200
+ self .blindfac_inverse = pow (self .blindfac_inverse , 2 , self .n )
179
201
180
202
class PublicKey (AbstractKey ):
181
203
"""Represents a public RSA key.
@@ -414,13 +436,6 @@ def __ne__(self, other: typing.Any) -> bool:
414
436
def __hash__ (self ) -> int :
415
437
return hash ((self .n , self .e , self .d , self .p , self .q , self .exp1 , self .exp2 , self .coef ))
416
438
417
- def _get_blinding_factor (self ) -> int :
418
- for _ in range (1000 ):
419
- blind_r = rsa .randnum .randint (self .n - 1 )
420
- if rsa .prime .are_relatively_prime (self .n , blind_r ):
421
- return blind_r
422
- raise RuntimeError ('unable to find blinding factor' )
423
-
424
439
def blinded_decrypt (self , encrypted : int ) -> int :
425
440
"""Decrypts the message using blinding to prevent side-channel attacks.
426
441
@@ -431,11 +446,9 @@ def blinded_decrypt(self, encrypted: int) -> int:
431
446
:rtype: int
432
447
"""
433
448
434
- blind_r = self ._get_blinding_factor ()
435
- blinded = self .blind (encrypted , blind_r ) # blind before decrypting
449
+ blinded = self .blind (encrypted ) # blind before decrypting
436
450
decrypted = rsa .core .decrypt_int (blinded , self .d , self .n )
437
-
438
- return self .unblind (decrypted , blind_r )
451
+ return self .unblind (decrypted )
439
452
440
453
def blinded_encrypt (self , message : int ) -> int :
441
454
"""Encrypts the message using blinding to prevent side-channel attacks.
@@ -447,10 +460,9 @@ def blinded_encrypt(self, message: int) -> int:
447
460
:rtype: int
448
461
"""
449
462
450
- blind_r = self ._get_blinding_factor ()
451
- blinded = self .blind (message , blind_r ) # blind before encrypting
463
+ blinded = self .blind (message ) # blind before encrypting
452
464
encrypted = rsa .core .encrypt_int (blinded , self .d , self .n )
453
- return self .unblind (encrypted , blind_r )
465
+ return self .unblind (encrypted )
454
466
455
467
@classmethod
456
468
def _load_pkcs1_der (cls , keyfile : bytes ) -> 'PrivateKey' :
0 commit comments