@@ -98,7 +98,7 @@ pub static AES_256: HPAlgorithm = HPAlgorithm {
9898} ;
9999
100100fn init_hp_aes_cipher ( key : & [ u8 ] ) -> Result < Cipher , Error > {
101- let aes_cipher = AesCipher :: default ( ) ;
101+ let mut aes_cipher = AesCipher :: default ( ) ;
102102 aes_cipher. set_key ( key) ?;
103103 Ok ( Cipher :: Aes ( aes_cipher) )
104104}
@@ -323,7 +323,7 @@ pub static AES_256_GCM: AeadAlgorithm = AeadAlgorithm {
323323} ;
324324
325325fn init_aes_gcm_cipher ( key : & [ u8 ] ) -> Result < Cipher , Error > {
326- let aes_cipher = AesCipher :: default ( ) ;
326+ let mut aes_cipher = AesCipher :: default ( ) ;
327327 aes_cipher. set_key ( key) ?;
328328 Ok ( Cipher :: Aes ( aes_cipher) )
329329}
@@ -527,7 +527,8 @@ impl quic::Algorithm for KeyFactory {
527527}
528528
529529pub struct AesCipher {
530- aes_cipher : AesObject ,
530+ aes_object : AesObject ,
531+ key : Vec < u8 > ,
531532}
532533
533534impl Default for AesCipher {
@@ -539,25 +540,27 @@ impl Default for AesCipher {
539540impl AesCipher {
540541 pub fn new ( ) -> Self {
541542 Self {
542- aes_cipher : new_aes_cipher ( ) . unwrap ( ) ,
543+ aes_object : new_aes_object ( ) . unwrap ( ) ,
544+ key : Vec :: new ( ) ,
543545 }
544546 }
545547
546- /// It initializes an AES object with the given key.
547- pub fn set_key ( & self , key : & [ u8 ] ) -> Result < ( ) , Error > {
548+ /// It initializes an AES cipher with the given key.
549+ pub fn set_key ( & mut self , key : & [ u8 ] ) -> Result < ( ) , Error > {
548550 if key. len ( ) != AES_256_KEY_LEN && key. len ( ) != AES_128_KEY_LEN {
549551 return Err ( Error :: General ( "Invalid key length" . into ( ) ) ) ;
550552 }
551553 let ret = unsafe {
552554 wc_AesSetKey (
553- self . aes_cipher . as_ptr ( ) ,
555+ self . aes_object . as_ptr ( ) ,
554556 key. as_ptr ( ) ,
555557 key. len ( ) as word32 ,
556558 ptr:: null_mut ( ) ,
557559 0 ,
558560 )
559561 } ;
560562 check_if_zero ( ret) . unwrap ( ) ;
563+ self . key = key. to_vec ( ) ;
561564 Ok ( ( ) )
562565 }
563566
@@ -566,7 +569,7 @@ impl AesCipher {
566569
567570 let ret = unsafe {
568571 wc_AesEncryptDirect (
569- self . aes_cipher . as_ptr ( ) ,
572+ self . aes_object . as_ptr ( ) ,
570573 out_block. as_mut_ptr ( ) ,
571574 sample. as_ptr ( ) ,
572575 )
@@ -583,16 +586,29 @@ impl AesCipher {
583586 payload : & mut [ u8 ] ,
584587 ) -> Result < Tag , Error > {
585588 let mut auth_tag = vec ! [ 0u8 ; TAG_LEN ] ;
589+ let mut ret;
590+
591+ // Prepare aes_object for encryption
592+ ret = unsafe {
593+ wc_AesSetKey (
594+ self . aes_object . as_ptr ( ) ,
595+ self . key . as_ptr ( ) ,
596+ self . key . len ( ) as word32 ,
597+ ptr:: null_mut ( ) ,
598+ 0 ,
599+ )
600+ } ;
601+ check_if_zero ( ret) . unwrap ( ) ;
586602
587603 // This function encrypts the input message, held in the buffer in,
588604 // and stores the resulting cipher text in the output buffer out.
589605 // It requires a new iv (initialization vector) for each call to encrypt.
590606 // It also encodes the input authentication vector,
591607 // authIn, into the authentication tag, authTag.
592608
593- let ret = unsafe {
609+ ret = unsafe {
594610 wc_AesGcmEncrypt (
595- self . aes_cipher . as_ptr ( ) ,
611+ self . aes_object . as_ptr ( ) ,
596612 payload. as_mut_ptr ( ) ,
597613 payload. as_ptr ( ) ,
598614 payload. as_ref ( ) . len ( ) as word32 ,
@@ -613,11 +629,25 @@ impl AesCipher {
613629 let message_len = payload. len ( ) - TAG_LEN ;
614630 auth_tag. copy_from_slice ( & payload[ message_len..] ) ;
615631
632+ let mut ret;
633+
634+ // Prepare aes_object for decryption
635+ ret = unsafe {
636+ wc_AesSetKey (
637+ self . aes_object . as_ptr ( ) ,
638+ self . key . as_ptr ( ) ,
639+ self . key . len ( ) as word32 ,
640+ ptr:: null_mut ( ) ,
641+ 0 ,
642+ )
643+ } ;
644+ check_if_zero ( ret) . unwrap ( ) ;
645+
616646 // Finally, we have everything to decrypt the message
617647 // from the payload.
618- let ret = unsafe {
648+ ret = unsafe {
619649 wc_AesGcmDecrypt (
620- self . aes_cipher . as_ptr ( ) ,
650+ self . aes_object . as_ptr ( ) ,
621651 payload[ ..message_len] . as_mut_ptr ( ) ,
622652 payload[ ..message_len] . as_ptr ( ) ,
623653 payload[ ..message_len] . len ( ) . try_into ( ) . unwrap ( ) ,
@@ -644,7 +674,7 @@ impl ChaChaCipher {
644674 pub fn new ( key : Option < [ u8 ; CHACHA_KEY_LEN ] > ) -> Self {
645675 match key {
646676 None => Self {
647- chacha_cipher : Some ( new_chacha_cipher ( ) . unwrap ( ) ) ,
677+ chacha_cipher : Some ( new_chacha_object ( ) . unwrap ( ) ) ,
648678 key : None ,
649679 } ,
650680 Some ( key_bytes) => Self {
@@ -771,7 +801,7 @@ impl ChaChaCipher {
771801 }
772802}
773803
774- fn new_aes_cipher ( ) -> Result < AesObject , Error > {
804+ fn new_aes_object ( ) -> Result < AesObject , Error > {
775805 let aes_c_type_box = Box :: new ( unsafe { mem:: zeroed :: < Aes > ( ) } ) ;
776806 let aes_c_type_ptr = Box :: into_raw ( aes_c_type_box) ;
777807 let aes_object = unsafe { AesObject :: from_ptr ( aes_c_type_ptr) } ;
@@ -782,7 +812,7 @@ fn new_aes_cipher() -> Result<AesObject, Error> {
782812 Ok ( aes_object)
783813}
784814
785- fn new_chacha_cipher ( ) -> Result < ChaChaObject , Error > {
815+ fn new_chacha_object ( ) -> Result < ChaChaObject , Error > {
786816 //Create ChaCha object
787817 let chacha_c_typ_box = Box :: new ( unsafe { mem:: zeroed :: < ChaCha > ( ) } ) ;
788818 let chacha_c_typ_ptr = Box :: into_raw ( chacha_c_typ_box) ;
@@ -803,7 +833,7 @@ mod tests {
803833 use rustls:: crypto:: cipher:: { AeadKey , Iv , NONCE_LEN } ;
804834 use rustls:: quic:: * ;
805835
806- use crate :: provider ;
836+ use crate :: default_provider ;
807837 use crate :: { TLS13_AES_128_GCM_SHA256 , TLS13_CHACHA20_POLY1305_SHA256 } ;
808838 use rustls:: crypto:: tls13:: OkmBlock ;
809839 use rustls:: internal:: msgs:: codec:: Codec ;
@@ -836,7 +866,7 @@ mod tests {
836866 let root_store =
837867 rustls:: RootCertStore :: from_iter ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ;
838868
839- let config = rustls:: ClientConfig :: builder_with_provider ( provider ( ) . into ( ) )
869+ let config = rustls:: ClientConfig :: builder_with_provider ( default_provider ( ) . into ( ) )
840870 . with_safe_default_protocol_versions ( )
841871 . unwrap ( )
842872 . with_root_certificates ( root_store)
@@ -871,7 +901,7 @@ mod tests {
871901 . signed_by ( & server_key, & ca_cert, & ca_key)
872902 . unwrap ( ) ;
873903
874- let mut server_config = ServerConfig :: builder_with_provider ( provider ( ) . into ( ) )
904+ let mut server_config = ServerConfig :: builder_with_provider ( default_provider ( ) . into ( ) )
875905 . with_safe_default_protocol_versions ( )
876906 . unwrap ( )
877907 . with_no_client_auth ( )
@@ -1563,7 +1593,7 @@ mod tests {
15631593 } ,
15641594 ] ;
15651595
1566- let aes_cipher = crate :: aead:: quic:: AesCipher :: default ( ) ;
1596+ let mut aes_cipher = crate :: aead:: quic:: AesCipher :: default ( ) ;
15671597 let mut mask = [ 0u8 ; 5 ] ;
15681598
15691599 for v in & vectors {
0 commit comments