Skip to content

Commit fbefe7e

Browse files
LinuxJedidanielinux
authored andcommitted
Simplify OAEP and PSS
Makes things a little bit more like similar APIs. * Hash type is now set in constructor. * MGF is set automtically or manually with `set_mgf()` * Label defaults to empty
1 parent e34a0ec commit fbefe7e

File tree

2 files changed

+73
-30
lines changed

2 files changed

+73
-30
lines changed

tests/test_ciphers.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,14 @@ def test_chacha_enc_dec(chacha_obj):
333333
def rsa_private(vectors):
334334
return RsaPrivate(vectors[RsaPrivate].key)
335335

336+
@pytest.fixture
337+
def rsa_private_oaep(vectors):
338+
return RsaPrivate(vectors[RsaPrivate].key, hash_type=HASH_TYPE_SHA)
339+
340+
@pytest.fixture
341+
def rsa_private_pss(vectors):
342+
return RsaPrivate(vectors[RsaPrivate].key, hash_type=HASH_TYPE_SHA256)
343+
336344
@pytest.fixture
337345
def rsa_private_pkcs8(vectors):
338346
return RsaPrivate(vectors[RsaPrivate].pkcs8_key)
@@ -341,6 +349,14 @@ def rsa_private_pkcs8(vectors):
341349
def rsa_public(vectors):
342350
return RsaPublic(vectors[RsaPublic].key)
343351

352+
@pytest.fixture
353+
def rsa_public_oaep(vectors):
354+
return RsaPublic(vectors[RsaPublic].key, hash_type=HASH_TYPE_SHA)
355+
356+
@pytest.fixture
357+
def rsa_public_pss(vectors):
358+
return RsaPublic(vectors[RsaPublic].key, hash_type=HASH_TYPE_SHA256)
359+
344360
@pytest.fixture
345361
def rsa_private_pem(vectors):
346362
with open(vectors[RsaPrivate].pem, "rb") as f:
@@ -382,21 +398,21 @@ def test_rsa_encrypt_decrypt(rsa_private, rsa_public):
382398
assert 1024 / 8 == len(ciphertext) == rsa_private.output_size
383399
assert plaintext == rsa_private.decrypt(ciphertext)
384400

385-
def test_rsa_encrypt_decrypt_pad_oaep(rsa_private, rsa_public):
401+
def test_rsa_encrypt_decrypt_pad_oaep(rsa_private_oaep, rsa_public_oaep):
386402
plaintext = t2b("Everyone gets Friday off.")
387403

388404
# normal usage, encrypt with public, decrypt with private
389-
ciphertext = rsa_public.encrypt_oaep(plaintext, HASH_TYPE_SHA, MGF1SHA1, "")
405+
ciphertext = rsa_public_oaep.encrypt_oaep(plaintext)
390406

391-
assert 1024 / 8 == len(ciphertext) == rsa_public.output_size
392-
assert plaintext == rsa_private.decrypt_oaep(ciphertext, HASH_TYPE_SHA, MGF1SHA1, "")
407+
assert 1024 / 8 == len(ciphertext) == rsa_public_oaep.output_size
408+
assert plaintext == rsa_private_oaep.decrypt_oaep(ciphertext)
393409

394410
# private object holds both private and public info, so it can also encrypt
395411
# using the known public key.
396-
ciphertext = rsa_private.encrypt_oaep(plaintext, HASH_TYPE_SHA, MGF1SHA1, "")
412+
ciphertext = rsa_private_oaep.encrypt_oaep(plaintext)
397413

398-
assert 1024 / 8 == len(ciphertext) == rsa_private.output_size
399-
assert plaintext == rsa_private.decrypt_oaep(ciphertext, HASH_TYPE_SHA, MGF1SHA1, "")
414+
assert 1024 / 8 == len(ciphertext) == rsa_private_oaep.output_size
415+
assert plaintext == rsa_private_oaep.decrypt_oaep(ciphertext)
400416

401417

402418
def test_rsa_pkcs8_encrypt_decrypt(rsa_private_pkcs8, rsa_public):
@@ -433,21 +449,21 @@ def test_rsa_sign_verify(rsa_private, rsa_public):
433449
assert plaintext == rsa_private.verify(signature)
434450

435451
if _lib.RSA_PSS_ENABLED:
436-
def test_rsa_pss_sign_verify(rsa_private, rsa_public):
452+
def test_rsa_pss_sign_verify(rsa_private_pss, rsa_public_pss):
437453
plaintext = t2b("Everyone gets Friday off yippee.")
438454

439455
# normal usage, sign with private, verify with public
440-
signature = rsa_private.sign_pss(plaintext, HASH_TYPE_SHA256, MGF1SHA256)
456+
signature = rsa_private_pss.sign_pss(plaintext)
441457

442-
assert 1024 / 8 == len(signature) == rsa_private.output_size
443-
assert 0 == rsa_public.verify_pss(plaintext, signature, HASH_TYPE_SHA256, MGF1SHA256)
458+
assert 1024 / 8 == len(signature) == rsa_private_pss.output_size
459+
assert 0 == rsa_public_pss.verify_pss(plaintext, signature)
444460

445461
# private object holds both private and public info, so it can also verify
446462
# using the known public key.
447-
signature = rsa_private.sign_pss(plaintext, HASH_TYPE_SHA256, MGF1SHA256)
463+
signature = rsa_private_pss.sign_pss(plaintext)
448464

449-
assert 1024 / 8 == len(signature) == rsa_private.output_size
450-
assert 0 == rsa_private.verify_pss(plaintext, signature, HASH_TYPE_SHA256, MGF1SHA256)
465+
assert 1024 / 8 == len(signature) == rsa_private_pss.output_size
466+
assert 0 == rsa_private_pss.verify_pss(plaintext, signature)
451467

452468
def test_rsa_sign_verify_pem(rsa_private_pem, rsa_public_pem):
453469
plaintext = t2b("Everyone gets Friday off.")

wolfcrypt/ciphers.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,8 @@ def _decrypt(self, destination, source):
452452
if _lib.RSA_ENABLED:
453453
class _Rsa(object): # pylint: disable=too-few-public-methods
454454
RSA_MIN_PAD_SIZE = 11
455+
_mgf = None
456+
_hash_type = None
455457

456458
def __init__(self):
457459
self.native_object = _ffi.new("RsaKey *")
@@ -473,11 +475,30 @@ def __del__(self):
473475
if self.native_object:
474476
self._delete(self.native_object)
475477

478+
def set_mgf(self, mgf):
479+
self._mgf = mgf
480+
481+
def _get_mgf(self):
482+
if self._hash_type == _lib.WC_HASH_TYPE_SHA:
483+
self._mgf = _lib.WC_MGF1SHA1
484+
elif self._hash_type == _lib.WC_HASH_TYPE_SHA224:
485+
self._mgf = _lib.WC_MGF1SHA224
486+
elif self._hash_type == _lib.WC_HASH_TYPE_SHA256:
487+
self._mgf = _lib.WC_MGF1SHA256
488+
elif self._hash_type == _lib.WC_HASH_TYPE_SHA384:
489+
self._mgf = _lib.WC_MGF1SHA384
490+
elif self._hash_type == _lib.WC_HASH_TYPE_SHA512:
491+
self._mgf = _lib.WC_MGF1SHA512
492+
else:
493+
self._mgf = _lib.WC_MGF1NONE
494+
495+
476496

477497
class RsaPublic(_Rsa):
478-
def __init__(self, key=None):
498+
def __init__(self, key=None, hash_type=None):
479499
if key != None:
480500
key = t2b(key)
501+
self._hash_type = hash_type
481502

482503
_Rsa.__init__(self)
483504

@@ -524,17 +545,18 @@ def encrypt(self, plaintext):
524545

525546
return _ffi.buffer(ciphertext)[:]
526547

527-
def encrypt_oaep(self, plaintext, hash_type, mgf, label):
548+
def encrypt_oaep(self, plaintext, label=""):
528549
plaintext = t2b(plaintext)
529550
label = t2b(label)
530551
ciphertext = _ffi.new("byte[%d]" % self.output_size)
531-
552+
if self._mgf is None:
553+
self._get_mgf()
532554
ret = _lib.wc_RsaPublicEncrypt_ex(plaintext, len(plaintext),
533555
ciphertext, self.output_size,
534556
self.native_object,
535557
self._random.native_object,
536-
_lib.WC_RSA_OAEP_PAD, hash_type,
537-
mgf, label, len(label))
558+
_lib.WC_RSA_OAEP_PAD, self._hash_type,
559+
self._mgf, label, len(label))
538560

539561
if ret != self.output_size: # pragma: no cover
540562
raise WolfCryptError("Encryption error (%d)" % ret)
@@ -563,7 +585,7 @@ def verify(self, signature):
563585
return _ffi.buffer(plaintext, ret)[:]
564586

565587
if _lib.RSA_PSS_ENABLED:
566-
def verify_pss(self, plaintext, signature, hash_type, mgf):
588+
def verify_pss(self, plaintext, signature):
567589
"""
568590
Verifies **signature**, using the public key data in the
569591
object. The signature's length must be equal to:
@@ -574,17 +596,19 @@ def verify_pss(self, plaintext, signature, hash_type, mgf):
574596
"""
575597
plaintext = t2b(plaintext)
576598
signature = t2b(signature)
599+
if self._mgf is None:
600+
self._get_mgf()
577601
verify = _ffi.new("byte[%d]" % self.output_size)
578602

579603
ret = _lib.wc_RsaPSS_Verify(signature, len(signature),
580604
verify, self.output_size,
581-
hash_type, mgf,
605+
self._hash_type, self._mgf,
582606
self.native_object)
583607

584608
if ret < 0: # pragma: no cover
585609
raise WolfCryptError("Verify error (%d)" % ret)
586610
ret = _lib.wc_RsaPSS_CheckPadding(plaintext, len(plaintext),
587-
verify, ret, hash_type)
611+
verify, ret, self._hash_type)
588612

589613
return ret
590614

@@ -613,10 +637,10 @@ def make_key(cls, size, rng=Random()):
613637

614638
return rsa
615639

616-
def __init__(self, key = None): # pylint: disable=super-init-not-called
640+
def __init__(self, key=None, hash_type=None): # pylint: disable=super-init-not-called
617641

618642
_Rsa.__init__(self) # pylint: disable=non-parent-init-called
619-
643+
self._hash_type = hash_type
620644
idx = _ffi.new("word32*")
621645
idx[0] = 0
622646

@@ -692,7 +716,7 @@ def decrypt(self, ciphertext):
692716

693717
return _ffi.buffer(plaintext, ret)[:]
694718

695-
def decrypt_oaep(self, ciphertext, hash_type, mgf, label):
719+
def decrypt_oaep(self, ciphertext, label=""):
696720
"""
697721
Decrypts **ciphertext**, using the private key data in the
698722
object. The ciphertext's length must be equal to:
@@ -704,11 +728,13 @@ def decrypt_oaep(self, ciphertext, hash_type, mgf, label):
704728
ciphertext = t2b(ciphertext)
705729
label = t2b(label)
706730
plaintext = _ffi.new("byte[%d]" % self.output_size)
731+
if self._mgf is None:
732+
self._get_mgf()
707733
ret = _lib.wc_RsaPrivateDecrypt_ex(ciphertext, len(ciphertext),
708734
plaintext, self.output_size,
709735
self.native_object,
710-
_lib.WC_RSA_OAEP_PAD, hash_type,
711-
mgf, label, len(label))
736+
_lib.WC_RSA_OAEP_PAD, self._hash_type,
737+
self._mgf, label, len(label))
712738

713739
if ret < 0: # pragma: no cover
714740
raise WolfCryptError("Decryption error (%d)" % ret)
@@ -738,7 +764,7 @@ def sign(self, plaintext):
738764
return _ffi.buffer(signature, self.output_size)[:]
739765

740766
if _lib.RSA_PSS_ENABLED:
741-
def sign_pss(self, plaintext, hash_type, mgf):
767+
def sign_pss(self, plaintext):
742768
"""
743769
Signs **plaintext**, using the private key data in the object.
744770
The plaintext's length must not be greater than:
@@ -749,10 +775,11 @@ def sign_pss(self, plaintext, hash_type, mgf):
749775
"""
750776
plaintext = t2b(plaintext)
751777
signature = _ffi.new("byte[%d]" % self.output_size)
752-
778+
if self._mgf is None:
779+
self._get_mgf()
753780
ret = _lib.wc_RsaPSS_Sign(plaintext, len(plaintext),
754781
signature, self.output_size,
755-
hash_type, mgf,
782+
self._hash_type, self._mgf,
756783
self.native_object,
757784
self._random.native_object)
758785

0 commit comments

Comments
 (0)