Skip to content

Commit d8eb6c5

Browse files
committed
Code fixes and formatting
1 parent 3b6baae commit d8eb6c5

File tree

5 files changed

+86
-61
lines changed

5 files changed

+86
-61
lines changed

rustls-wolfcrypt-provider/src/aead/quic.rs

Lines changed: 85 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,9 @@ pub static CHACHA20_POLY1305: AeadAlgorithm = AeadAlgorithm {
368368
};
369369

370370
fn init_chacha20_poly1305_cipher(key: &[u8]) -> Result<Cipher, Error> {
371-
let chacha_cipher = ChaChaCipher::new(Some(<[u8; 32]>::try_from(key).unwrap()));
371+
let key_array = <[u8; 32]>::try_from(key)
372+
.map_err(|_| Error::General("Invalid key length for ChaCha20-Poly1305".into()))?;
373+
let chacha_cipher = ChaChaCipher::new(Some(key_array));
372374
Ok(Cipher::ChaCha20(chacha_cipher))
373375
}
374376

@@ -423,7 +425,7 @@ impl PacketKey {
423425
return Err(Error::General("Invalid key length".into()));
424426
}
425427
Ok(Self {
426-
packet_cipher: (algorithm.init)(key.as_ref()).unwrap(),
428+
packet_cipher: (algorithm.init)(key.as_ref())?,
427429
iv,
428430
confidentiality_limit,
429431
integrity_limit,
@@ -493,19 +495,26 @@ pub(crate) struct KeyFactory {
493495
impl quic::Algorithm for KeyFactory {
494496
fn packet_key(&self, key: AeadKey, iv: Iv) -> Box<dyn quic::PacketKey> {
495497
Box::new(
496-
PacketKey::new(
498+
match PacketKey::new(
497499
key,
498500
iv,
499501
self.confidentiality_limit,
500502
self.integrity_limit,
501503
self.packet_algo,
502-
)
503-
.unwrap(),
504+
) {
505+
Ok(packet_key) => packet_key,
506+
Err(e) => panic!("PacketKey object creation failed: {:?}", e),
507+
},
504508
)
505509
}
506510

507511
fn header_protection_key(&self, key: AeadKey) -> Box<dyn quic::HeaderProtectionKey> {
508-
Box::new(HeaderProtectionKey::new(key.as_ref().to_vec(), self.header_algo).unwrap())
512+
Box::new(
513+
match HeaderProtectionKey::new(key.as_ref().to_vec(), self.header_algo) {
514+
Ok(header_key) => header_key,
515+
Err(e) => panic!("HeaderProtection Key object creation failed: {:?}", e),
516+
},
517+
)
509518
}
510519

511520
fn aead_key_len(&self) -> usize {
@@ -649,6 +658,12 @@ impl ChaChaCipher {
649658
if key.len() != CHACHA_KEY_LEN {
650659
return Err(Error::General("Invalid key length".into()));
651660
}
661+
662+
if self.chacha_cipher.is_none() {
663+
return Err(Error::General(
664+
"Cipher is none. Create a cipher object before setting key".into(),
665+
));
666+
}
652667
//Set key for ChaCha object
653668
let ret = unsafe {
654669
wc_Chacha_SetKey(
@@ -666,6 +681,12 @@ impl ChaChaCipher {
666681
}
667682

668683
pub fn encrypt_sample(&self, sample: &[u8]) -> Result<Vec<u8>, Error> {
684+
if self.chacha_cipher.is_none() {
685+
return Err(Error::General(
686+
"Cipher is none. Create a cipher object before encryption".into(),
687+
));
688+
}
689+
669690
let mut out = vec![0; TAG_LEN];
670691

671692
let (ctr, nonce) = sample.split_at(4);
@@ -772,23 +793,23 @@ fn new_chacha_cipher() -> Result<ChaChaObject, Error> {
772793

773794
#[cfg(test)]
774795
mod tests {
775-
use std::prelude::rust_2015::ToString;
776796
use hex_literal::hex;
777797
use rustls::crypto::tls13::HkdfExpander;
798+
use std::prelude::rust_2015::ToString;
778799
use std::prelude::v1::Vec;
779800
use std::vec;
780801

781802
use crate::aead;
782803
use rustls::crypto::cipher::{AeadKey, Iv, NONCE_LEN};
783804
use rustls::quic::*;
784805

806+
use crate::provider;
785807
use crate::{TLS13_AES_128_GCM_SHA256, TLS13_CHACHA20_POLY1305_SHA256};
786808
use rustls::crypto::tls13::OkmBlock;
787-
use rustls::{ClientConfig, ServerConfig, Side, SideData, Error};
788-
use std::sync::Arc;
789809
use rustls::internal::msgs::codec::Codec;
810+
use rustls::{ClientConfig, Error, ServerConfig, Side, SideData};
790811
use rustls_pki_types::PrivatePkcs8KeyDer;
791-
use crate::provider;
812+
use std::sync::Arc;
792813

793814
// Returns the sender's next secrets to use, or the receiver's error.
794815
fn step<L: SideData, R: SideData>(
@@ -854,15 +875,18 @@ mod tests {
854875
.with_safe_default_protocol_versions()
855876
.unwrap()
856877
.with_no_client_auth()
857-
.with_single_cert(vec![server_cert.into()], PrivatePkcs8KeyDer::from(server_key.serialize_der()).into())
878+
.with_single_cert(
879+
vec![server_cert.into()],
880+
PrivatePkcs8KeyDer::from(server_key.serialize_der()).into(),
881+
)
858882
.unwrap();
859883

860884
server_config.key_log = Arc::new(rustls::KeyLogFile::new());
861885

862886
server_config
863887
}
864888
/// Encode each of `items`
865-
pub fn iter_to_vec_of_bytes<'a, T: Codec<'a>>(items: impl Iterator<Item=T>) -> Vec<u8> {
889+
pub fn iter_to_vec_of_bytes<'a, T: Codec<'a>>(items: impl Iterator<Item = T>) -> Vec<u8> {
866890
let mut body = Vec::new();
867891

868892
for i in items {
@@ -874,15 +898,19 @@ mod tests {
874898
///Encode length as prefix
875899
pub fn prefix_len(mut body: Vec<u8>, len: usize) -> Vec<u8> {
876900
match len {
877-
8 => { body.splice(0..0, [body.len() as u8]); }
878-
16 => { body.splice(0..0, (body.len() as u16).to_be_bytes()); }
901+
8 => {
902+
body.splice(0..0, [body.len() as u8]);
903+
}
904+
16 => {
905+
body.splice(0..0, (body.len() as u16).to_be_bytes());
906+
}
879907
24 => {
880908
let len = (body.len() as u32).to_be_bytes();
881909
body.insert(0, len[1]);
882910
body.insert(1, len[2]);
883911
body.insert(2, len[3]);
884912
}
885-
_ => panic!("wrong length!")
913+
_ => panic!("wrong length!"),
886914
};
887915
body
888916
}
@@ -893,7 +921,10 @@ mod tests {
893921
// kx group
894922
extensions.push(Extension {
895923
typ: 0x000a, // EllipticCurves
896-
body: prefix_len(iter_to_vec_of_bytes([rustls::NamedGroup::secp256r1].into_iter()), 16),
924+
body: prefix_len(
925+
iter_to_vec_of_bytes([rustls::NamedGroup::secp256r1].into_iter()),
926+
16,
927+
),
897928
});
898929
// Sig algs
899930
extensions.push(Extension {
@@ -909,23 +940,29 @@ mod tests {
909940
// Supported Versions,
910941
extensions.push(Extension {
911942
typ: 0x002b, // Supported Versions
912-
body: prefix_len(iter_to_vec_of_bytes(
913-
[rustls::ProtocolVersion::TLSv1_3, rustls::ProtocolVersion::TLSv1_2].into_iter(),
914-
), 8),
943+
body: prefix_len(
944+
iter_to_vec_of_bytes(
945+
[
946+
rustls::ProtocolVersion::TLSv1_3,
947+
rustls::ProtocolVersion::TLSv1_2,
948+
]
949+
.into_iter(),
950+
),
951+
8,
952+
),
915953
});
916954

917955
// Key share
918956
const SOME_POINT_ON_P256: &[u8] = &[
919-
4, 41, 39, 177, 5, 18, 186, 227, 237, 220, 254, 70, 120, 40, 18, 139, 173, 41, 3,
920-
38, 153, 25, 247, 8, 96, 105, 200, 196, 223, 108, 115, 40, 56, 199, 120, 121, 100,
921-
234, 172, 0, 229, 146, 31, 177, 73, 138, 96, 244, 96, 103, 102, 179, 217, 104, 80,
922-
1, 85, 141, 26, 151, 78, 115, 65, 81, 62,
957+
4, 41, 39, 177, 5, 18, 186, 227, 237, 220, 254, 70, 120, 40, 18, 139, 173, 41, 3, 38,
958+
153, 25, 247, 8, 96, 105, 200, 196, 223, 108, 115, 40, 56, 199, 120, 121, 100, 234,
959+
172, 0, 229, 146, 31, 177, 73, 138, 96, 244, 96, 103, 102, 179, 217, 104, 80, 1, 85,
960+
141, 26, 151, 78, 115, 65, 81, 62,
923961
];
924962

925963
let mut share = prefix_len(SOME_POINT_ON_P256.to_vec(), 16);
926964
share.splice(0..0, rustls::NamedGroup::secp256r1.to_array());
927965

928-
929966
extensions.push(Extension {
930967
typ: 0x0033, // Key share
931968
body: prefix_len(share, 16),
@@ -940,7 +977,9 @@ mod tests {
940977
vec![
941978
rustls::CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
942979
rustls::CipherSuite::TLS13_AES_128_GCM_SHA256,
943-
].to_vec().encode(&mut ch); // Encode cypher suites
980+
]
981+
.to_vec()
982+
.encode(&mut ch); // Encode cypher suites
944983
ch.extend_from_slice(&[0x01, 0x00]); // only null compression
945984

946985
//Generate ch extensions
@@ -1104,7 +1143,7 @@ mod tests {
11041143
let (first, rest) = header.split_at_mut(1);
11051144
let sample = &sample[..header_protection_key.sample_len()];
11061145
header_protection_key
1107-
.encrypt_in_place(sample, &mut first[0], dbg!(rest))
1146+
.encrypt_in_place(sample, &mut first[0], rest)
11081147
.unwrap();
11091148

11101149
assert_eq!(&buf, expected);
@@ -1209,7 +1248,6 @@ mod tests {
12091248
assert_eq!(server_packet[..], expected_server_packet[..]);
12101249
}
12111250

1212-
12131251
#[test]
12141252
fn test_quic_rejects_missing_alpn() {
12151253
//Code taken from rustls with modification
@@ -1228,15 +1266,16 @@ mod tests {
12281266
"localhost".try_into().unwrap(),
12291267
client_params.into(),
12301268
)
1231-
.unwrap();
1232-
let mut server =
1233-
rustls::quic::ServerConnection::new(server_config, rustls::quic::Version::V1, server_params.into())
1234-
.unwrap();
1269+
.unwrap();
1270+
let mut server = rustls::quic::ServerConnection::new(
1271+
server_config,
1272+
rustls::quic::Version::V1,
1273+
server_params.into(),
1274+
)
1275+
.unwrap();
12351276

12361277
assert_eq!(
1237-
step(&mut client, &mut server)
1238-
.err()
1239-
.unwrap(),
1278+
step(&mut client, &mut server).err().unwrap(),
12401279
rustls::Error::NoApplicationProtocol
12411280
);
12421281

@@ -1267,8 +1306,12 @@ mod tests {
12671306

12681307
let wrapped = Arc::new(server_config.clone());
12691308
assert_eq!(
1270-
rustls::quic::ServerConnection::new(wrapped, rustls::quic::Version::V1, b"server params".to_vec(), )
1271-
.is_ok(),
1309+
rustls::quic::ServerConnection::new(
1310+
wrapped,
1311+
rustls::quic::Version::V1,
1312+
b"server params".to_vec(),
1313+
)
1314+
.is_ok(),
12721315
ok
12731316
);
12741317
}
@@ -1286,7 +1329,7 @@ mod tests {
12861329
rustls::quic::Version::V1,
12871330
b"server params".to_vec(),
12881331
)
1289-
.unwrap();
1332+
.unwrap();
12901333

12911334
//Make a basic client hello
12921335
let ch = make_client_hello();
@@ -1301,8 +1344,8 @@ mod tests {
13011344
#[test]
13021345
fn packet_key_api() {
13031346
//Code taken from rustls
1304-
use rustls::Side;
13051347
use rustls::quic::{Keys, Version};
1348+
use rustls::Side;
13061349

13071350
// Test vectors: https://www.rfc-editor.org/rfc/rfc9001.html#name-client-initial
13081351
const CONNECTION_ID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
@@ -1335,14 +1378,8 @@ mod tests {
13351378

13361379
let client_keys = Keys::initial(
13371380
Version::V1,
1338-
TLS13_AES_128_GCM_SHA256
1339-
.tls13()
1340-
.unwrap(),
1341-
TLS13_AES_128_GCM_SHA256
1342-
.tls13()
1343-
.unwrap()
1344-
.quic
1345-
.unwrap(),
1381+
TLS13_AES_128_GCM_SHA256.tls13().unwrap(),
1382+
TLS13_AES_128_GCM_SHA256.tls13().unwrap().quic.unwrap(),
13461383
CONNECTION_ID,
13471384
Side::Client,
13481385
);
@@ -1469,14 +1506,8 @@ mod tests {
14691506

14701507
let server_keys = Keys::initial(
14711508
Version::V1,
1472-
TLS13_AES_128_GCM_SHA256
1473-
.tls13()
1474-
.unwrap(),
1475-
TLS13_AES_128_GCM_SHA256
1476-
.tls13()
1477-
.unwrap()
1478-
.quic
1479-
.unwrap(),
1509+
TLS13_AES_128_GCM_SHA256.tls13().unwrap(),
1510+
TLS13_AES_128_GCM_SHA256.tls13().unwrap().quic.unwrap(),
14801511
CONNECTION_ID,
14811512
Side::Server,
14821513
);

rustls-wolfcrypt-provider/src/error.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,6 @@ pub fn check_if_greater_than_zero(ret: i32) -> WCResult {
110110
}
111111
}
112112

113-
114-
115113
#[cfg(test)]
116114
mod tests {
117115
use super::*;

rustls-wolfcrypt-provider/src/hkdf.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ impl tls13::HkdfExpander for WolfHkdfExpander {
147147
}
148148
}
149149

150-
151-
152150
#[cfg(test)]
153151
mod tests {
154152
use super::*;

rustls-wolfcrypt-provider/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ pub static TLS13_CHACHA20_POLY1305_SHA256: rustls::SupportedCipherSuite =
154154
#[cfg(not(feature = "quic"))]
155155
quic: None,
156156
#[cfg(feature = "quic")]
157-
quic:Some(&KeyFactory {
157+
quic: Some(&KeyFactory {
158158
packet_algo: &aead::quic::CHACHA20_POLY1305,
159159
header_algo: &aead::quic::CHACHA20,
160160
// ref: <https://datatracker.ietf.org/doc/html/rfc9001#section-6.6>

rustls-wolfcrypt-provider/src/types/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ macro_rules! define_foreign_type_with_copy {
123123
};
124124
}
125125

126-
127-
128126
define_foreign_type!(
129127
WCRngObject,
130128
WCRngObjectRef,

0 commit comments

Comments
 (0)