@@ -368,7 +368,9 @@ pub static CHACHA20_POLY1305: AeadAlgorithm = AeadAlgorithm {
368368} ;
369369
370370fn init_chacha20_poly1305_cipher ( key : & [ u8 ] ) -> Result < Cipher , Error > {
371- let chacha_cipher = ChaChaCipher :: new ( Some ( <[ u8 ; 32 ] >:: try_from ( key) . unwrap ( ) ) ) ;
371+ let key_array = <[ u8 ; 32 ] >:: try_from ( key)
372+ . map_err ( |_| Error :: General ( "Invalid key length for ChaCha20-Poly1305" . into ( ) ) ) ?;
373+ let chacha_cipher = ChaChaCipher :: new ( Some ( key_array) ) ;
372374 Ok ( Cipher :: ChaCha20 ( chacha_cipher) )
373375}
374376
@@ -423,7 +425,7 @@ impl PacketKey {
423425 return Err ( Error :: General ( "Invalid key length" . into ( ) ) ) ;
424426 }
425427 Ok ( Self {
426- packet_cipher : ( algorithm. init ) ( key. as_ref ( ) ) . unwrap ( ) ,
428+ packet_cipher : ( algorithm. init ) ( key. as_ref ( ) ) ? ,
427429 iv,
428430 confidentiality_limit,
429431 integrity_limit,
@@ -493,19 +495,26 @@ pub(crate) struct KeyFactory {
493495impl quic:: Algorithm for KeyFactory {
494496 fn packet_key ( & self , key : AeadKey , iv : Iv ) -> Box < dyn quic:: PacketKey > {
495497 Box :: new (
496- PacketKey :: new (
498+ match PacketKey :: new (
497499 key,
498500 iv,
499501 self . confidentiality_limit ,
500502 self . integrity_limit ,
501503 self . packet_algo ,
502- )
503- . unwrap ( ) ,
504+ ) {
505+ Ok ( packet_key) => packet_key,
506+ Err ( e) => panic ! ( "PacketKey object creation failed: {:?}" , e) ,
507+ } ,
504508 )
505509 }
506510
507511 fn header_protection_key ( & self , key : AeadKey ) -> Box < dyn quic:: HeaderProtectionKey > {
508- Box :: new ( HeaderProtectionKey :: new ( key. as_ref ( ) . to_vec ( ) , self . header_algo ) . unwrap ( ) )
512+ Box :: new (
513+ match HeaderProtectionKey :: new ( key. as_ref ( ) . to_vec ( ) , self . header_algo ) {
514+ Ok ( header_key) => header_key,
515+ Err ( e) => panic ! ( "HeaderProtection Key object creation failed: {:?}" , e) ,
516+ } ,
517+ )
509518 }
510519
511520 fn aead_key_len ( & self ) -> usize {
@@ -649,6 +658,12 @@ impl ChaChaCipher {
649658 if key. len ( ) != CHACHA_KEY_LEN {
650659 return Err ( Error :: General ( "Invalid key length" . into ( ) ) ) ;
651660 }
661+
662+ if self . chacha_cipher . is_none ( ) {
663+ return Err ( Error :: General (
664+ "Cipher is none. Create a cipher object before setting key" . into ( ) ,
665+ ) ) ;
666+ }
652667 //Set key for ChaCha object
653668 let ret = unsafe {
654669 wc_Chacha_SetKey (
@@ -666,6 +681,12 @@ impl ChaChaCipher {
666681 }
667682
668683 pub fn encrypt_sample ( & self , sample : & [ u8 ] ) -> Result < Vec < u8 > , Error > {
684+ if self . chacha_cipher . is_none ( ) {
685+ return Err ( Error :: General (
686+ "Cipher is none. Create a cipher object before encryption" . into ( ) ,
687+ ) ) ;
688+ }
689+
669690 let mut out = vec ! [ 0 ; TAG_LEN ] ;
670691
671692 let ( ctr, nonce) = sample. split_at ( 4 ) ;
@@ -772,23 +793,23 @@ fn new_chacha_cipher() -> Result<ChaChaObject, Error> {
772793
773794#[ cfg( test) ]
774795mod tests {
775- use std:: prelude:: rust_2015:: ToString ;
776796 use hex_literal:: hex;
777797 use rustls:: crypto:: tls13:: HkdfExpander ;
798+ use std:: prelude:: rust_2015:: ToString ;
778799 use std:: prelude:: v1:: Vec ;
779800 use std:: vec;
780801
781802 use crate :: aead;
782803 use rustls:: crypto:: cipher:: { AeadKey , Iv , NONCE_LEN } ;
783804 use rustls:: quic:: * ;
784805
806+ use crate :: provider;
785807 use crate :: { TLS13_AES_128_GCM_SHA256 , TLS13_CHACHA20_POLY1305_SHA256 } ;
786808 use rustls:: crypto:: tls13:: OkmBlock ;
787- use rustls:: { ClientConfig , ServerConfig , Side , SideData , Error } ;
788- use std:: sync:: Arc ;
789809 use rustls:: internal:: msgs:: codec:: Codec ;
810+ use rustls:: { ClientConfig , Error , ServerConfig , Side , SideData } ;
790811 use rustls_pki_types:: PrivatePkcs8KeyDer ;
791- use crate :: provider ;
812+ use std :: sync :: Arc ;
792813
793814 // Returns the sender's next secrets to use, or the receiver's error.
794815 fn step < L : SideData , R : SideData > (
@@ -854,15 +875,18 @@ mod tests {
854875 . with_safe_default_protocol_versions ( )
855876 . unwrap ( )
856877 . with_no_client_auth ( )
857- . with_single_cert ( vec ! [ server_cert. into( ) ] , PrivatePkcs8KeyDer :: from ( server_key. serialize_der ( ) ) . into ( ) )
878+ . with_single_cert (
879+ vec ! [ server_cert. into( ) ] ,
880+ PrivatePkcs8KeyDer :: from ( server_key. serialize_der ( ) ) . into ( ) ,
881+ )
858882 . unwrap ( ) ;
859883
860884 server_config. key_log = Arc :: new ( rustls:: KeyLogFile :: new ( ) ) ;
861885
862886 server_config
863887 }
864888 /// Encode each of `items`
865- pub fn iter_to_vec_of_bytes < ' a , T : Codec < ' a > > ( items : impl Iterator < Item = T > ) -> Vec < u8 > {
889+ pub fn iter_to_vec_of_bytes < ' a , T : Codec < ' a > > ( items : impl Iterator < Item = T > ) -> Vec < u8 > {
866890 let mut body = Vec :: new ( ) ;
867891
868892 for i in items {
@@ -874,15 +898,19 @@ mod tests {
874898 ///Encode length as prefix
875899 pub fn prefix_len ( mut body : Vec < u8 > , len : usize ) -> Vec < u8 > {
876900 match len {
877- 8 => { body. splice ( 0 ..0 , [ body. len ( ) as u8 ] ) ; }
878- 16 => { body. splice ( 0 ..0 , ( body. len ( ) as u16 ) . to_be_bytes ( ) ) ; }
901+ 8 => {
902+ body. splice ( 0 ..0 , [ body. len ( ) as u8 ] ) ;
903+ }
904+ 16 => {
905+ body. splice ( 0 ..0 , ( body. len ( ) as u16 ) . to_be_bytes ( ) ) ;
906+ }
879907 24 => {
880908 let len = ( body. len ( ) as u32 ) . to_be_bytes ( ) ;
881909 body. insert ( 0 , len[ 1 ] ) ;
882910 body. insert ( 1 , len[ 2 ] ) ;
883911 body. insert ( 2 , len[ 3 ] ) ;
884912 }
885- _ => panic ! ( "wrong length!" )
913+ _ => panic ! ( "wrong length!" ) ,
886914 } ;
887915 body
888916 }
@@ -893,7 +921,10 @@ mod tests {
893921 // kx group
894922 extensions. push ( Extension {
895923 typ : 0x000a , // EllipticCurves
896- body : prefix_len ( iter_to_vec_of_bytes ( [ rustls:: NamedGroup :: secp256r1] . into_iter ( ) ) , 16 ) ,
924+ body : prefix_len (
925+ iter_to_vec_of_bytes ( [ rustls:: NamedGroup :: secp256r1] . into_iter ( ) ) ,
926+ 16 ,
927+ ) ,
897928 } ) ;
898929 // Sig algs
899930 extensions. push ( Extension {
@@ -909,23 +940,29 @@ mod tests {
909940 // Supported Versions,
910941 extensions. push ( Extension {
911942 typ : 0x002b , // Supported Versions
912- body : prefix_len ( iter_to_vec_of_bytes (
913- [ rustls:: ProtocolVersion :: TLSv1_3 , rustls:: ProtocolVersion :: TLSv1_2 ] . into_iter ( ) ,
914- ) , 8 ) ,
943+ body : prefix_len (
944+ iter_to_vec_of_bytes (
945+ [
946+ rustls:: ProtocolVersion :: TLSv1_3 ,
947+ rustls:: ProtocolVersion :: TLSv1_2 ,
948+ ]
949+ . into_iter ( ) ,
950+ ) ,
951+ 8 ,
952+ ) ,
915953 } ) ;
916954
917955 // Key share
918956 const SOME_POINT_ON_P256 : & [ u8 ] = & [
919- 4 , 41 , 39 , 177 , 5 , 18 , 186 , 227 , 237 , 220 , 254 , 70 , 120 , 40 , 18 , 139 , 173 , 41 , 3 ,
920- 38 , 153 , 25 , 247 , 8 , 96 , 105 , 200 , 196 , 223 , 108 , 115 , 40 , 56 , 199 , 120 , 121 , 100 ,
921- 234 , 172 , 0 , 229 , 146 , 31 , 177 , 73 , 138 , 96 , 244 , 96 , 103 , 102 , 179 , 217 , 104 , 80 ,
922- 1 , 85 , 141 , 26 , 151 , 78 , 115 , 65 , 81 , 62 ,
957+ 4 , 41 , 39 , 177 , 5 , 18 , 186 , 227 , 237 , 220 , 254 , 70 , 120 , 40 , 18 , 139 , 173 , 41 , 3 , 38 ,
958+ 153 , 25 , 247 , 8 , 96 , 105 , 200 , 196 , 223 , 108 , 115 , 40 , 56 , 199 , 120 , 121 , 100 , 234 ,
959+ 172 , 0 , 229 , 146 , 31 , 177 , 73 , 138 , 96 , 244 , 96 , 103 , 102 , 179 , 217 , 104 , 80 , 1 , 85 ,
960+ 141 , 26 , 151 , 78 , 115 , 65 , 81 , 62 ,
923961 ] ;
924962
925963 let mut share = prefix_len ( SOME_POINT_ON_P256 . to_vec ( ) , 16 ) ;
926964 share. splice ( 0 ..0 , rustls:: NamedGroup :: secp256r1. to_array ( ) ) ;
927965
928-
929966 extensions. push ( Extension {
930967 typ : 0x0033 , // Key share
931968 body : prefix_len ( share, 16 ) ,
@@ -940,7 +977,9 @@ mod tests {
940977 vec ! [
941978 rustls:: CipherSuite :: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ,
942979 rustls:: CipherSuite :: TLS13_AES_128_GCM_SHA256 ,
943- ] . to_vec ( ) . encode ( & mut ch) ; // Encode cypher suites
980+ ]
981+ . to_vec ( )
982+ . encode ( & mut ch) ; // Encode cypher suites
944983 ch. extend_from_slice ( & [ 0x01 , 0x00 ] ) ; // only null compression
945984
946985 //Generate ch extensions
@@ -1104,7 +1143,7 @@ mod tests {
11041143 let ( first, rest) = header. split_at_mut ( 1 ) ;
11051144 let sample = & sample[ ..header_protection_key. sample_len ( ) ] ;
11061145 header_protection_key
1107- . encrypt_in_place ( sample, & mut first[ 0 ] , dbg ! ( rest) )
1146+ . encrypt_in_place ( sample, & mut first[ 0 ] , rest)
11081147 . unwrap ( ) ;
11091148
11101149 assert_eq ! ( & buf, expected) ;
@@ -1209,7 +1248,6 @@ mod tests {
12091248 assert_eq ! ( server_packet[ ..] , expected_server_packet[ ..] ) ;
12101249 }
12111250
1212-
12131251 #[ test]
12141252 fn test_quic_rejects_missing_alpn ( ) {
12151253 //Code taken from rustls with modification
@@ -1228,15 +1266,16 @@ mod tests {
12281266 "localhost" . try_into ( ) . unwrap ( ) ,
12291267 client_params. into ( ) ,
12301268 )
1231- . unwrap ( ) ;
1232- let mut server =
1233- rustls:: quic:: ServerConnection :: new ( server_config, rustls:: quic:: Version :: V1 , server_params. into ( ) )
1234- . unwrap ( ) ;
1269+ . unwrap ( ) ;
1270+ let mut server = rustls:: quic:: ServerConnection :: new (
1271+ server_config,
1272+ rustls:: quic:: Version :: V1 ,
1273+ server_params. into ( ) ,
1274+ )
1275+ . unwrap ( ) ;
12351276
12361277 assert_eq ! (
1237- step( & mut client, & mut server)
1238- . err( )
1239- . unwrap( ) ,
1278+ step( & mut client, & mut server) . err( ) . unwrap( ) ,
12401279 rustls:: Error :: NoApplicationProtocol
12411280 ) ;
12421281
@@ -1267,8 +1306,12 @@ mod tests {
12671306
12681307 let wrapped = Arc :: new ( server_config. clone ( ) ) ;
12691308 assert_eq ! (
1270- rustls:: quic:: ServerConnection :: new( wrapped, rustls:: quic:: Version :: V1 , b"server params" . to_vec( ) , )
1271- . is_ok( ) ,
1309+ rustls:: quic:: ServerConnection :: new(
1310+ wrapped,
1311+ rustls:: quic:: Version :: V1 ,
1312+ b"server params" . to_vec( ) ,
1313+ )
1314+ . is_ok( ) ,
12721315 ok
12731316 ) ;
12741317 }
@@ -1286,7 +1329,7 @@ mod tests {
12861329 rustls:: quic:: Version :: V1 ,
12871330 b"server params" . to_vec ( ) ,
12881331 )
1289- . unwrap ( ) ;
1332+ . unwrap ( ) ;
12901333
12911334 //Make a basic client hello
12921335 let ch = make_client_hello ( ) ;
@@ -1301,8 +1344,8 @@ mod tests {
13011344 #[ test]
13021345 fn packet_key_api ( ) {
13031346 //Code taken from rustls
1304- use rustls:: Side ;
13051347 use rustls:: quic:: { Keys , Version } ;
1348+ use rustls:: Side ;
13061349
13071350 // Test vectors: https://www.rfc-editor.org/rfc/rfc9001.html#name-client-initial
13081351 const CONNECTION_ID : & [ u8 ] = & [ 0x83 , 0x94 , 0xc8 , 0xf0 , 0x3e , 0x51 , 0x57 , 0x08 ] ;
@@ -1335,14 +1378,8 @@ mod tests {
13351378
13361379 let client_keys = Keys :: initial (
13371380 Version :: V1 ,
1338- TLS13_AES_128_GCM_SHA256
1339- . tls13 ( )
1340- . unwrap ( ) ,
1341- TLS13_AES_128_GCM_SHA256
1342- . tls13 ( )
1343- . unwrap ( )
1344- . quic
1345- . unwrap ( ) ,
1381+ TLS13_AES_128_GCM_SHA256 . tls13 ( ) . unwrap ( ) ,
1382+ TLS13_AES_128_GCM_SHA256 . tls13 ( ) . unwrap ( ) . quic . unwrap ( ) ,
13461383 CONNECTION_ID ,
13471384 Side :: Client ,
13481385 ) ;
@@ -1469,14 +1506,8 @@ mod tests {
14691506
14701507 let server_keys = Keys :: initial (
14711508 Version :: V1 ,
1472- TLS13_AES_128_GCM_SHA256
1473- . tls13 ( )
1474- . unwrap ( ) ,
1475- TLS13_AES_128_GCM_SHA256
1476- . tls13 ( )
1477- . unwrap ( )
1478- . quic
1479- . unwrap ( ) ,
1509+ TLS13_AES_128_GCM_SHA256 . tls13 ( ) . unwrap ( ) ,
1510+ TLS13_AES_128_GCM_SHA256 . tls13 ( ) . unwrap ( ) . quic . unwrap ( ) ,
14801511 CONNECTION_ID ,
14811512 Side :: Server ,
14821513 ) ;
0 commit comments