@@ -452,6 +452,8 @@ def _decrypt(self, destination, source):
452452if _lib .RSA_ENABLED :
453453 class _Rsa (object ): # pylint: disable=too-few-public-methods
454454 RSA_MIN_PAD_SIZE = 11
455+ _mgf = None
456+ _hash_type = None
455457
456458 def __init__ (self ):
457459 self .native_object = _ffi .new ("RsaKey *" )
@@ -473,11 +475,30 @@ def __del__(self):
473475 if self .native_object :
474476 self ._delete (self .native_object )
475477
478+ def set_mgf (self , mgf ):
479+ self ._mgf = mgf
480+
481+ def _get_mgf (self ):
482+ if self ._hash_type == _lib .WC_HASH_TYPE_SHA :
483+ self ._mgf = _lib .WC_MGF1SHA1
484+ elif self ._hash_type == _lib .WC_HASH_TYPE_SHA224 :
485+ self ._mgf = _lib .WC_MGF1SHA224
486+ elif self ._hash_type == _lib .WC_HASH_TYPE_SHA256 :
487+ self ._mgf = _lib .WC_MGF1SHA256
488+ elif self ._hash_type == _lib .WC_HASH_TYPE_SHA384 :
489+ self ._mgf = _lib .WC_MGF1SHA384
490+ elif self ._hash_type == _lib .WC_HASH_TYPE_SHA512 :
491+ self ._mgf = _lib .WC_MGF1SHA512
492+ else :
493+ self ._mgf = _lib .WC_MGF1NONE
494+
495+
476496
477497 class RsaPublic (_Rsa ):
478- def __init__ (self , key = None ):
498+ def __init__ (self , key = None , hash_type = None ):
479499 if key != None :
480500 key = t2b (key )
501+ self ._hash_type = hash_type
481502
482503 _Rsa .__init__ (self )
483504
@@ -524,17 +545,18 @@ def encrypt(self, plaintext):
524545
525546 return _ffi .buffer (ciphertext )[:]
526547
527- def encrypt_oaep (self , plaintext , hash_type , mgf , label ):
548+ def encrypt_oaep (self , plaintext , label = "" ):
528549 plaintext = t2b (plaintext )
529550 label = t2b (label )
530551 ciphertext = _ffi .new ("byte[%d]" % self .output_size )
531-
552+ if self ._mgf is None :
553+ self ._get_mgf ()
532554 ret = _lib .wc_RsaPublicEncrypt_ex (plaintext , len (plaintext ),
533555 ciphertext , self .output_size ,
534556 self .native_object ,
535557 self ._random .native_object ,
536- _lib .WC_RSA_OAEP_PAD , hash_type ,
537- mgf , label , len (label ))
558+ _lib .WC_RSA_OAEP_PAD , self . _hash_type ,
559+ self . _mgf , label , len (label ))
538560
539561 if ret != self .output_size : # pragma: no cover
540562 raise WolfCryptError ("Encryption error (%d)" % ret )
@@ -563,7 +585,7 @@ def verify(self, signature):
563585 return _ffi .buffer (plaintext , ret )[:]
564586
565587 if _lib .RSA_PSS_ENABLED :
566- def verify_pss (self , plaintext , signature , hash_type , mgf ):
588+ def verify_pss (self , plaintext , signature ):
567589 """
568590 Verifies **signature**, using the public key data in the
569591 object. The signature's length must be equal to:
@@ -574,17 +596,19 @@ def verify_pss(self, plaintext, signature, hash_type, mgf):
574596 """
575597 plaintext = t2b (plaintext )
576598 signature = t2b (signature )
599+ if self ._mgf is None :
600+ self ._get_mgf ()
577601 verify = _ffi .new ("byte[%d]" % self .output_size )
578602
579603 ret = _lib .wc_RsaPSS_Verify (signature , len (signature ),
580604 verify , self .output_size ,
581- hash_type , mgf ,
605+ self . _hash_type , self . _mgf ,
582606 self .native_object )
583607
584608 if ret < 0 : # pragma: no cover
585609 raise WolfCryptError ("Verify error (%d)" % ret )
586610 ret = _lib .wc_RsaPSS_CheckPadding (plaintext , len (plaintext ),
587- verify , ret , hash_type )
611+ verify , ret , self . _hash_type )
588612
589613 return ret
590614
@@ -613,10 +637,10 @@ def make_key(cls, size, rng=Random()):
613637
614638 return rsa
615639
616- def __init__ (self , key = None ): # pylint: disable=super-init-not-called
640+ def __init__ (self , key = None , hash_type = None ): # pylint: disable=super-init-not-called
617641
618642 _Rsa .__init__ (self ) # pylint: disable=non-parent-init-called
619-
643+ self . _hash_type = hash_type
620644 idx = _ffi .new ("word32*" )
621645 idx [0 ] = 0
622646
@@ -692,7 +716,7 @@ def decrypt(self, ciphertext):
692716
693717 return _ffi .buffer (plaintext , ret )[:]
694718
695- def decrypt_oaep (self , ciphertext , hash_type , mgf , label ):
719+ def decrypt_oaep (self , ciphertext , label = "" ):
696720 """
697721 Decrypts **ciphertext**, using the private key data in the
698722 object. The ciphertext's length must be equal to:
@@ -704,11 +728,13 @@ def decrypt_oaep(self, ciphertext, hash_type, mgf, label):
704728 ciphertext = t2b (ciphertext )
705729 label = t2b (label )
706730 plaintext = _ffi .new ("byte[%d]" % self .output_size )
731+ if self ._mgf is None :
732+ self ._get_mgf ()
707733 ret = _lib .wc_RsaPrivateDecrypt_ex (ciphertext , len (ciphertext ),
708734 plaintext , self .output_size ,
709735 self .native_object ,
710- _lib .WC_RSA_OAEP_PAD , hash_type ,
711- mgf , label , len (label ))
736+ _lib .WC_RSA_OAEP_PAD , self . _hash_type ,
737+ self . _mgf , label , len (label ))
712738
713739 if ret < 0 : # pragma: no cover
714740 raise WolfCryptError ("Decryption error (%d)" % ret )
@@ -738,7 +764,7 @@ def sign(self, plaintext):
738764 return _ffi .buffer (signature , self .output_size )[:]
739765
740766 if _lib .RSA_PSS_ENABLED :
741- def sign_pss (self , plaintext , hash_type , mgf ):
767+ def sign_pss (self , plaintext ):
742768 """
743769 Signs **plaintext**, using the private key data in the object.
744770 The plaintext's length must not be greater than:
@@ -749,10 +775,11 @@ def sign_pss(self, plaintext, hash_type, mgf):
749775 """
750776 plaintext = t2b (plaintext )
751777 signature = _ffi .new ("byte[%d]" % self .output_size )
752-
778+ if self ._mgf is None :
779+ self ._get_mgf ()
753780 ret = _lib .wc_RsaPSS_Sign (plaintext , len (plaintext ),
754781 signature , self .output_size ,
755- hash_type , mgf ,
782+ self . _hash_type , self . _mgf ,
756783 self .native_object ,
757784 self ._random .native_object )
758785
0 commit comments