77from . import tlshashlib as hashlib
88from ..errors import MaskTooLongError , MessageTooLongError , EncodingError , \
99 InvalidSignature , UnknownRSAType
10+ from .constanttime import ct_isnonzero_u32 , ct_neq_u32 , ct_lsb_prop_u8 , \
11+ ct_lsb_prop_u16 , ct_lt_u32
1012
1113
1214class RSAKey (object ):
@@ -45,6 +47,7 @@ def __init__(self, n=0, e=0, key_type="rsa"):
4547 self .e = e
4648 # pylint: enable=invalid-name
4749 self .key_type = key_type
50+ self ._key_hash = None
4851 raise NotImplementedError ()
4952
5053 def __len__ (self ):
@@ -389,38 +392,177 @@ def encrypt(self, bytes):
389392 paddedBytes = self ._addPKCS1Padding (bytes , 2 )
390393 return self ._raw_public_key_op_bytes (paddedBytes )
391394
395+ def _dec_prf (self , key , label , out_len ):
396+ """PRF for deterministic implicit rejection in the RSA decryption.
397+
398+ :param bytes key: key to use for derivation
399+ :param bytes label: name of the keystream generated
400+ :param int out_len: length of output, in bits
401+ :rtype: bytes
402+ :returns: a random bytestring
403+ """
404+ out = bytearray ()
405+
406+ if out_len % 8 != 0 :
407+ raise ValueError ("only multiples of 8 supported as output size" )
408+
409+ iterator = 0
410+ while len (out ) < out_len // 8 :
411+ out += secureHMAC (
412+ key ,
413+ numberToByteArray (iterator , 2 ) + label +
414+ numberToByteArray (out_len , 2 ),
415+ "sha256" )
416+ iterator += 1
417+
418+ return out [:out_len // 8 ]
419+
392420 def decrypt (self , encBytes ):
393421 """Decrypt the passed-in bytes.
394422
395423 This requires the key to have a private component. It performs
396- PKCS1 decryption of the passed-in data.
424+ PKCS#1 v1.5 decryption operation of the passed-in data.
425+
426+ Note: as a workaround against Bleichenbacher-like attacks, it will
427+ return a deterministically selected random message in case the padding
428+ checks failed. It returns an error (None) only in case the ciphertext
429+ is of incorrect length or encodes an integer bigger than the modulus
430+ of the key (i.e. it's publically invalid).
397431
398432 :type encBytes: bytes-like object
399433 :param encBytes: The value which will be decrypted.
400434
401435 :rtype: bytearray or None
402- :returns: A PKCS1 decryption of the passed-in data or None if
403- the data is not properly formatted.
436+ :returns: A PKCS#1 v1.5 decryption of the passed-in data or None if
437+ the provided data is not properly formatted. Note: encrypting
438+ an empty string is correct, so it may return an empty bytearray
439+ for some ciphertexts.
404440 """
405441 if not self .hasPrivateKey ():
406442 raise AssertionError ()
407443 if self .key_type != "rsa" :
408444 raise ValueError ("Decryption requires RSA key, \" {0}\" present"
409445 .format (self .key_type ))
410446 try :
411- decBytes = self ._raw_private_key_op_bytes (encBytes )
447+ dec_bytes = self ._raw_private_key_op_bytes (encBytes )
412448 except ValueError :
449+ # _raw_private_key_op_bytes fails only when encBytes >= self.n,
450+ # or when len(encBytes) != numBytes(self.n) and that's public
451+ # information, so we don't have to handle it
452+ # in sidechannel secure way
413453 return None
414- #Check first two bytes
415- if decBytes [0 ] != 0 or decBytes [1 ] != 2 :
416- return None
417- #Scan through for zero separator
418- for x in range (1 , len (decBytes )- 1 ):
419- if decBytes [x ]== 0 :
420- break
421- else :
422- return None
423- return decBytes [x + 1 :] #Return everything after the separator
454+
455+ ###################
456+ # here be dragons #
457+ ###################
458+ # While the code is written as-if it was side-channel secure, in
459+ # practice, because of cPython implementation details IT IS NOT
460+ # see:
461+ # https://securitypitfalls.wordpress.com/2018/08/03/constant-time-compare-in-python/
462+
463+ n = self .n
464+
465+ # maximum length we can return is reduced by the mandatory prefix:
466+ # (0x00 0x02), 8 bytes of padding, so this is the position of the
467+ # null separator byte, as counted from the last position
468+ max_sep_offset = numBytes (n ) - 10
469+
470+ # the private exponent (d) doesn't change so `_key_hash` doesn't
471+ # change, calculate it only once
472+ if not hasattr (self , '_key_hash' ) or not self ._key_hash :
473+ self ._key_hash = secureHash (numberToByteArray (self .d , numBytes (n )),
474+ "sha256" )
475+
476+ kdk = secureHMAC (self ._key_hash , encBytes , "sha256" )
477+
478+ # we need 128 2-byte numbers, encoded as the number of bits
479+ length_randoms = self ._dec_prf (kdk , b"length" , 128 * 2 * 8 )
480+
481+ message_random = self ._dec_prf (kdk , b"message" , numBytes (n ) * 8 )
482+
483+ # select the last length that's not too large to return
484+ synth_length = 0
485+ length_rand_iter = iter (length_randoms )
486+ length_mask = (1 << numBits (max_sep_offset )) - 1
487+ for high , low in zip (length_rand_iter , length_rand_iter ):
488+ # interpret the two bytes from the PRF output as 16-bit big-endian
489+ # integer
490+ len_candidate = (high << 8 ) + low
491+ len_candidate &= length_mask
492+ # equivalent to:
493+ # if len_candidate < max_sep_offset:
494+ # synth_length = len_candidate
495+ mask = ct_lt_u32 (len_candidate , max_sep_offset )
496+ mask = ct_lsb_prop_u16 (mask )
497+ synth_length = synth_length & (0xffff ^ mask ) \
498+ | len_candidate & mask
499+
500+ synth_msg_start = numBytes (n ) - synth_length
501+
502+ error_detected = 0
503+
504+ # enumerate over all decrypted bytes
505+ em_bytes = enumerate (dec_bytes )
506+ # first check if first two bytes specify PKCS#1 v1.5 encryption padding
507+ _ , val = next (em_bytes )
508+ error_detected |= ct_isnonzero_u32 (val )
509+ _ , val = next (em_bytes )
510+ error_detected |= ct_neq_u32 (val , 0x02 )
511+ # then look for for the null separator byte among the padding bytes
512+ # but inspect all decrypted bytes, even if we already find the
513+ # separator earlier
514+ msg_start = 0
515+ for pos , val in em_bytes :
516+ # padding must be at least 8 bytes long, fail if any of the first
517+ # 8 bytes of it are zero
518+ # equivalent to:
519+ # if pos < 10 and not val:
520+ # error_detected = 0x01
521+ error_detected |= ct_lt_u32 (pos , 10 ) & (1 ^ ct_isnonzero_u32 (val ))
522+
523+ # update the msg_start only once; when it's 0
524+ # (pos+1) because we want to skip the null separator
525+ # equivalent to:
526+ # if pos >= 10 and not msg_start and not val:
527+ # msg_start = pos+1
528+ mask = (1 ^ ct_lt_u32 (pos , 10 )) & (1 ^ ct_isnonzero_u32 (val )) \
529+ & (1 ^ ct_isnonzero_u32 (msg_start ))
530+ mask = ct_lsb_prop_u16 (mask )
531+ msg_start = msg_start & (0xffff ^ mask ) | (pos + 1 ) & mask
532+
533+ # if separator wasn't found, it's an error
534+ # equivalent to:
535+ # if not msg_start:
536+ # error_detected = 0x01
537+ error_detected |= 1 ^ ct_isnonzero_u32 (msg_start )
538+
539+ # equivalent to:
540+ # if error_detected:
541+ # ret_msg_start = synth_msg_start
542+ # else:
543+ # ret_msg_start = msg_start
544+ mask = ct_lsb_prop_u16 (error_detected )
545+ ret_msg_start = msg_start & (0xffff ^ mask ) | synth_msg_start & mask
546+
547+ # as at this point the length doesn't leak the information if the
548+ # padding was correct or not, we don't have to worry about the
549+ # length of the returned value (and thus the size of the buffer we
550+ # pass to the caller); but we still need to read both buffers
551+ # to ensure that the memory access patern is preserved (that both
552+ # buffers are accessed, not just the one we return)
553+
554+ # equivalent to:
555+ # if error_detected:
556+ # return message_random[ret_msg_start:]
557+ # else:
558+ # return dec_bytes[ret_msg_start:]
559+ mask = ct_lsb_prop_u8 (error_detected )
560+ not_mask = 0xff ^ mask
561+ ret = bytearray (
562+ x & not_mask | y & mask for x , y in
563+ zip (dec_bytes [ret_msg_start :], message_random [ret_msg_start :]))
564+
565+ return ret
424566
425567 def _rawPrivateKeyOp (self , message ):
426568 raise NotImplementedError ()
@@ -443,7 +585,7 @@ def _raw_public_key_op_bytes(self, ciphertext):
443585 if len (ciphertext ) != numBytes (n ):
444586 raise ValueError ("Message has incorrect length for the key size" )
445587 c_int = bytesToNumber (ciphertext )
446- if c_int > n :
588+ if c_int >= n :
447589 raise ValueError ("Provided message value exceeds modulus" )
448590 enc_int = self ._rawPublicKeyOp (c_int )
449591 return numberToByteArray (enc_int , numBytes (n ))
0 commit comments