Skip to content

Commit 44d6fd4

Browse files
committed
implement deterministic implicit rejection for RSA decryption
1 parent 0cfb77d commit 44d6fd4

File tree

2 files changed

+1391
-15
lines changed

2 files changed

+1391
-15
lines changed

tlslite/utils/rsakey.py

Lines changed: 157 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from . import tlshashlib as hashlib
88
from ..errors import MaskTooLongError, MessageTooLongError, EncodingError, \
99
InvalidSignature, UnknownRSAType
10+
from .constanttime import ct_isnonzero_u32, ct_neq_u32, ct_lsb_prop_u8, \
11+
ct_lsb_prop_u16, ct_lt_u32
1012

1113

1214
class RSAKey(object):
@@ -45,6 +47,7 @@ def __init__(self, n=0, e=0, key_type="rsa"):
4547
self.e = e
4648
# pylint: enable=invalid-name
4749
self.key_type = key_type
50+
self._key_hash = None
4851
raise NotImplementedError()
4952

5053
def __len__(self):
@@ -389,38 +392,177 @@ def encrypt(self, bytes):
389392
paddedBytes = self._addPKCS1Padding(bytes, 2)
390393
return self._raw_public_key_op_bytes(paddedBytes)
391394

395+
def _dec_prf(self, key, label, out_len):
396+
"""PRF for deterministic implicit rejection in the RSA decryption.
397+
398+
:param bytes key: key to use for derivation
399+
:param bytes label: name of the keystream generated
400+
:param int out_len: length of output, in bits
401+
:rtype: bytes
402+
:returns: a random bytestring
403+
"""
404+
out = bytearray()
405+
406+
if out_len % 8 != 0:
407+
raise ValueError("only multiples of 8 supported as output size")
408+
409+
iterator = 0
410+
while len(out) < out_len // 8:
411+
out += secureHMAC(
412+
key,
413+
numberToByteArray(iterator, 2) + label +
414+
numberToByteArray(out_len, 2),
415+
"sha256")
416+
iterator += 1
417+
418+
return out[:out_len//8]
419+
392420
def decrypt(self, encBytes):
393421
"""Decrypt the passed-in bytes.
394422
395423
This requires the key to have a private component. It performs
396-
PKCS1 decryption of the passed-in data.
424+
PKCS#1 v1.5 decryption operation of the passed-in data.
425+
426+
Note: as a workaround against Bleichenbacher-like attacks, it will
427+
return a deterministically selected random message in case the padding
428+
checks failed. It returns an error (None) only in case the ciphertext
429+
is of incorrect length or encodes an integer bigger than the modulus
430+
of the key (i.e. it's publically invalid).
397431
398432
:type encBytes: bytes-like object
399433
:param encBytes: The value which will be decrypted.
400434
401435
:rtype: bytearray or None
402-
:returns: A PKCS1 decryption of the passed-in data or None if
403-
the data is not properly formatted.
436+
:returns: A PKCS#1 v1.5 decryption of the passed-in data or None if
437+
the provided data is not properly formatted. Note: encrypting
438+
an empty string is correct, so it may return an empty bytearray
439+
for some ciphertexts.
404440
"""
405441
if not self.hasPrivateKey():
406442
raise AssertionError()
407443
if self.key_type != "rsa":
408444
raise ValueError("Decryption requires RSA key, \"{0}\" present"
409445
.format(self.key_type))
410446
try:
411-
decBytes = self._raw_private_key_op_bytes(encBytes)
447+
dec_bytes = self._raw_private_key_op_bytes(encBytes)
412448
except ValueError:
449+
# _raw_private_key_op_bytes fails only when encBytes >= self.n,
450+
# or when len(encBytes) != numBytes(self.n) and that's public
451+
# information, so we don't have to handle it
452+
# in sidechannel secure way
413453
return None
414-
#Check first two bytes
415-
if decBytes[0] != 0 or decBytes[1] != 2:
416-
return None
417-
#Scan through for zero separator
418-
for x in range(1, len(decBytes)-1):
419-
if decBytes[x]== 0:
420-
break
421-
else:
422-
return None
423-
return decBytes[x+1:] #Return everything after the separator
454+
455+
###################
456+
# here be dragons #
457+
###################
458+
# While the code is written as-if it was side-channel secure, in
459+
# practice, because of cPython implementation details IT IS NOT
460+
# see:
461+
# https://securitypitfalls.wordpress.com/2018/08/03/constant-time-compare-in-python/
462+
463+
n = self.n
464+
465+
# maximum length we can return is reduced by the mandatory prefix:
466+
# (0x00 0x02), 8 bytes of padding, so this is the position of the
467+
# null separator byte, as counted from the last position
468+
max_sep_offset = numBytes(n) - 10
469+
470+
# the private exponent (d) doesn't change so `_key_hash` doesn't
471+
# change, calculate it only once
472+
if not hasattr(self, '_key_hash') or not self._key_hash:
473+
self._key_hash = secureHash(numberToByteArray(self.d, numBytes(n)),
474+
"sha256")
475+
476+
kdk = secureHMAC(self._key_hash, encBytes, "sha256")
477+
478+
# we need 128 2-byte numbers, encoded as the number of bits
479+
length_randoms = self._dec_prf(kdk, b"length", 128 * 2 * 8)
480+
481+
message_random = self._dec_prf(kdk, b"message", numBytes(n) * 8)
482+
483+
# select the last length that's not too large to return
484+
synth_length = 0
485+
length_rand_iter = iter(length_randoms)
486+
length_mask = (1 << numBits(max_sep_offset)) - 1
487+
for high, low in zip(length_rand_iter, length_rand_iter):
488+
# interpret the two bytes from the PRF output as 16-bit big-endian
489+
# integer
490+
len_candidate = (high << 8) + low
491+
len_candidate &= length_mask
492+
# equivalent to:
493+
# if len_candidate < max_sep_offset:
494+
# synth_length = len_candidate
495+
mask = ct_lt_u32(len_candidate, max_sep_offset)
496+
mask = ct_lsb_prop_u16(mask)
497+
synth_length = synth_length & (0xffff ^ mask) \
498+
| len_candidate & mask
499+
500+
synth_msg_start = numBytes(n) - synth_length
501+
502+
error_detected = 0
503+
504+
# enumerate over all decrypted bytes
505+
em_bytes = enumerate(dec_bytes)
506+
# first check if first two bytes specify PKCS#1 v1.5 encryption padding
507+
_, val = next(em_bytes)
508+
error_detected |= ct_isnonzero_u32(val)
509+
_, val = next(em_bytes)
510+
error_detected |= ct_neq_u32(val, 0x02)
511+
# then look for for the null separator byte among the padding bytes
512+
# but inspect all decrypted bytes, even if we already find the
513+
# separator earlier
514+
msg_start = 0
515+
for pos, val in em_bytes:
516+
# padding must be at least 8 bytes long, fail if any of the first
517+
# 8 bytes of it are zero
518+
# equivalent to:
519+
# if pos < 10 and not val:
520+
# error_detected = 0x01
521+
error_detected |= ct_lt_u32(pos, 10) & (1 ^ ct_isnonzero_u32(val))
522+
523+
# update the msg_start only once; when it's 0
524+
# (pos+1) because we want to skip the null separator
525+
# equivalent to:
526+
# if pos >= 10 and not msg_start and not val:
527+
# msg_start = pos+1
528+
mask = (1 ^ ct_lt_u32(pos, 10)) & (1 ^ ct_isnonzero_u32(val)) \
529+
& (1 ^ ct_isnonzero_u32(msg_start))
530+
mask = ct_lsb_prop_u16(mask)
531+
msg_start = msg_start & (0xffff ^ mask) | (pos+1) & mask
532+
533+
# if separator wasn't found, it's an error
534+
# equivalent to:
535+
# if not msg_start:
536+
# error_detected = 0x01
537+
error_detected |= 1 ^ ct_isnonzero_u32(msg_start)
538+
539+
# equivalent to:
540+
# if error_detected:
541+
# ret_msg_start = synth_msg_start
542+
# else:
543+
# ret_msg_start = msg_start
544+
mask = ct_lsb_prop_u16(error_detected)
545+
ret_msg_start = msg_start & (0xffff ^ mask) | synth_msg_start & mask
546+
547+
# as at this point the length doesn't leak the information if the
548+
# padding was correct or not, we don't have to worry about the
549+
# length of the returned value (and thus the size of the buffer we
550+
# pass to the caller); but we still need to read both buffers
551+
# to ensure that the memory access patern is preserved (that both
552+
# buffers are accessed, not just the one we return)
553+
554+
# equivalent to:
555+
# if error_detected:
556+
# return message_random[ret_msg_start:]
557+
# else:
558+
# return dec_bytes[ret_msg_start:]
559+
mask = ct_lsb_prop_u8(error_detected)
560+
not_mask = 0xff ^ mask
561+
ret = bytearray(
562+
x & not_mask | y & mask for x, y in
563+
zip(dec_bytes[ret_msg_start:], message_random[ret_msg_start:]))
564+
565+
return ret
424566

425567
def _rawPrivateKeyOp(self, message):
426568
raise NotImplementedError()
@@ -443,7 +585,7 @@ def _raw_public_key_op_bytes(self, ciphertext):
443585
if len(ciphertext) != numBytes(n):
444586
raise ValueError("Message has incorrect length for the key size")
445587
c_int = bytesToNumber(ciphertext)
446-
if c_int > n:
588+
if c_int >= n:
447589
raise ValueError("Provided message value exceeds modulus")
448590
enc_int = self._rawPublicKeyOp(c_int)
449591
return numberToByteArray(enc_int, numBytes(n))

0 commit comments

Comments
 (0)