Skip to content

Commit e96ab68

Browse files
committed
- Complete refactor of the rsapss.rs and rsapkcs1.rs modules into one
single module called rsa.rs to better handle the different private key formats (with matching via SignatureScheme); - Completed support for PKCS8 for RSA-PKCS1 and PKCS1 for RSA-PSS;
1 parent 0776339 commit e96ab68

File tree

5 files changed

+302
-432
lines changed

5 files changed

+302
-432
lines changed

rustls-wolfcrypt-provider/src/lib.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ pub mod aead {
2727
pub mod sign {
2828
pub mod ecdsa;
2929
pub mod eddsa;
30-
pub mod rsapkcs1;
31-
pub mod rsapss;
30+
pub mod rsa;
3231
}
3332
use crate::aead::{aes128gcm, aes256gcm, chacha20};
3433

@@ -94,10 +93,7 @@ impl rustls::crypto::KeyProvider for Provider {
9493
// Define supported algorithms as closures
9594
let algorithms: SigningAlgorithms = vec![
9695
Box::new(|key| sign::ecdsa::EcdsaSigningKey::try_from(key).map(|x| Arc::new(x) as _)),
97-
Box::new(|key| sign::rsapss::RsaPssPrivateKey::try_from(key).map(|x| Arc::new(x) as _)),
98-
Box::new(|key| {
99-
sign::rsapkcs1::RsaPkcs1PrivateKey::try_from(key).map(|x| Arc::new(x) as _)
100-
}),
96+
Box::new(|key| sign::rsa::RsaPrivateKey::try_from(key).map(|x| Arc::new(x) as _)),
10197
Box::new(|key| sign::eddsa::Ed25519PrivateKey::try_from(key).map(|x| Arc::new(x) as _)),
10298
];
10399

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
use crate::error::*;
2+
use crate::types::*;
3+
use alloc::boxed::Box;
4+
use alloc::sync::Arc;
5+
use alloc::vec;
6+
use alloc::vec::Vec;
7+
use core::mem;
8+
use foreign_types::ForeignType;
9+
use rustls::pki_types::PrivateKeyDer;
10+
use rustls::sign::{Signer, SigningKey};
11+
use rustls::{SignatureAlgorithm, SignatureScheme};
12+
13+
use core::ptr;
14+
use wolfcrypt_rs::*;
15+
16+
const ALL_RSA_SCHEMES: &[SignatureScheme] = &[
17+
SignatureScheme::RSA_PSS_SHA256,
18+
SignatureScheme::RSA_PSS_SHA384,
19+
SignatureScheme::RSA_PSS_SHA512,
20+
SignatureScheme::RSA_PKCS1_SHA256,
21+
SignatureScheme::RSA_PKCS1_SHA384,
22+
SignatureScheme::RSA_PKCS1_SHA512,
23+
];
24+
25+
const MAX_RSA_SIG_SIZE: usize = 512;
26+
const HASH_TYPE_SHA256: u32 = wc_HashType_WC_HASH_TYPE_SHA256;
27+
const HASH_TYPE_SHA384: u32 = wc_HashType_WC_HASH_TYPE_SHA384;
28+
const HASH_TYPE_SHA512: u32 = wc_HashType_WC_HASH_TYPE_SHA512;
29+
30+
const MGF1_SHA256: u32 = WC_MGF1SHA256;
31+
const MGF1_SHA384: u32 = WC_MGF1SHA384;
32+
const MGF1_SHA512: u32 = WC_MGF1SHA512;
33+
34+
#[derive(Clone, Debug)]
35+
pub struct RsaPrivateKey {
36+
key: Arc<RsaKeyObject>,
37+
algo: SignatureAlgorithm,
38+
}
39+
40+
impl RsaPrivateKey {
41+
pub fn get_key(&self) -> Arc<RsaKeyObject> {
42+
Arc::clone(&self.key)
43+
}
44+
}
45+
46+
impl TryFrom<&PrivateKeyDer<'_>> for RsaPrivateKey {
47+
type Error = rustls::Error;
48+
49+
fn try_from(value: &PrivateKeyDer<'_>) -> Result<Self, Self::Error> {
50+
match value {
51+
PrivateKeyDer::Pkcs8(der) => {
52+
let pkcs8: &[u8] = der.secret_pkcs8_der();
53+
let pkcs8_sz: word32 = pkcs8.len() as word32;
54+
let mut ret;
55+
let rsa_key_box = Box::new(unsafe { mem::zeroed::<RsaKey>() });
56+
let rsa_key_ptr = Box::into_raw(rsa_key_box);
57+
let rsa_key_object = unsafe { RsaKeyObject::from_ptr(rsa_key_ptr) };
58+
59+
ret = unsafe { wc_InitRsaKey(rsa_key_object.as_ptr(), ptr::null_mut()) };
60+
check_if_zero(ret).unwrap();
61+
62+
let mut idx: u32 = 0;
63+
64+
ret = unsafe {
65+
wc_RsaPrivateKeyDecode(
66+
pkcs8.as_ptr() as *mut u8,
67+
&mut idx,
68+
rsa_key_object.as_ptr(),
69+
pkcs8_sz,
70+
)
71+
};
72+
check_if_zero(ret)
73+
.map_err(|_| rustls::Error::General("FFI function failed".into()))?;
74+
75+
Ok(Self {
76+
key: Arc::new(rsa_key_object),
77+
algo: SignatureAlgorithm::RSA,
78+
})
79+
}
80+
PrivateKeyDer::Pkcs1(der) => {
81+
let pkcs1: &[u8] = der.secret_pkcs1_der();
82+
let pkcs1_sz: word32 = pkcs1.len() as word32;
83+
let mut ret;
84+
let rsa_key_box = Box::new(unsafe { mem::zeroed::<RsaKey>() });
85+
let rsa_key_ptr = Box::into_raw(rsa_key_box);
86+
let rsa_key_object = unsafe { RsaKeyObject::from_ptr(rsa_key_ptr) };
87+
88+
ret = unsafe { wc_InitRsaKey(rsa_key_object.as_ptr(), ptr::null_mut()) };
89+
check_if_zero(ret).unwrap();
90+
91+
let mut idx: u32 = 0;
92+
93+
ret = unsafe {
94+
wc_RsaPrivateKeyDecode(
95+
pkcs1.as_ptr() as *mut u8,
96+
&mut idx,
97+
rsa_key_object.as_ptr(),
98+
pkcs1_sz,
99+
)
100+
};
101+
check_if_zero(ret)
102+
.map_err(|_| rustls::Error::General("FFI function failed".into()))?;
103+
104+
Ok(Self {
105+
key: Arc::new(rsa_key_object),
106+
algo: SignatureAlgorithm::RSA,
107+
})
108+
}
109+
_ => Err(rustls::Error::General(
110+
"Unsupported private key format".into(),
111+
)),
112+
}
113+
}
114+
}
115+
116+
impl SigningKey for RsaPrivateKey {
117+
fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option<Box<dyn Signer>> {
118+
// Iterate through all RSA schemes and check if any is in the offered list
119+
ALL_RSA_SCHEMES.iter().find_map(|&scheme| {
120+
if offered.contains(&scheme) {
121+
Some(Box::new(RsaSigner {
122+
key: self.get_key(),
123+
scheme,
124+
}) as Box<dyn Signer>)
125+
} else {
126+
None
127+
}
128+
})
129+
}
130+
131+
fn algorithm(&self) -> SignatureAlgorithm {
132+
self.algo
133+
}
134+
}
135+
136+
#[derive(Clone, Debug)]
137+
pub struct RsaSigner {
138+
key: Arc<RsaKeyObject>,
139+
scheme: SignatureScheme,
140+
}
141+
142+
impl RsaSigner {
143+
pub fn new(key: Arc<RsaKeyObject>, scheme: SignatureScheme) -> Self {
144+
Self { key, scheme }
145+
}
146+
147+
fn get_key(&self) -> Arc<RsaKeyObject> {
148+
Arc::clone(&self.key)
149+
}
150+
}
151+
152+
impl Signer for RsaSigner {
153+
fn sign(&self, message: &[u8]) -> Result<Vec<u8>, rustls::Error> {
154+
let rsa_key_arc = self.get_key();
155+
let rsa_key_object = rsa_key_arc.as_ref();
156+
157+
// Prepare a random generator
158+
let mut rng: WC_RNG = unsafe { mem::zeroed() };
159+
let rng_object = WCRngObject::new(&mut rng);
160+
rng_object.init();
161+
162+
// Allocate enough space for the signature
163+
let mut sig_buf = [0u8; MAX_RSA_SIG_SIZE];
164+
165+
match self.scheme {
166+
// ------------------------------------------------
167+
// RSA-PSS branch:
168+
// ------------------------------------------------
169+
SignatureScheme::RSA_PSS_SHA256
170+
| SignatureScheme::RSA_PSS_SHA384
171+
| SignatureScheme::RSA_PSS_SHA512 => {
172+
// We'll do explicit hashing plus wc_RsaPSS_Sign.
173+
174+
// 1) Determine hash algorithm & MGF
175+
let (hash_ty, mgf_ty, digest_len) = match self.scheme {
176+
SignatureScheme::RSA_PSS_SHA256 => {
177+
(HASH_TYPE_SHA256, MGF1_SHA256, WC_SHA256_DIGEST_SIZE)
178+
}
179+
SignatureScheme::RSA_PSS_SHA384 => {
180+
(HASH_TYPE_SHA384, MGF1_SHA384, WC_SHA384_DIGEST_SIZE)
181+
}
182+
SignatureScheme::RSA_PSS_SHA512 => {
183+
(HASH_TYPE_SHA512, MGF1_SHA512, WC_SHA512_DIGEST_SIZE)
184+
}
185+
_ => unreachable!(),
186+
};
187+
188+
// 2) Hash the message ourselves
189+
let mut digest = vec![0u8; digest_len as usize];
190+
let ret = unsafe {
191+
match hash_ty {
192+
HASH_TYPE_SHA256 => wc_Sha256Hash(
193+
message.as_ptr(),
194+
message.len() as u32,
195+
digest.as_mut_ptr(),
196+
),
197+
HASH_TYPE_SHA384 => wc_Sha384Hash(
198+
message.as_ptr(),
199+
message.len() as u32,
200+
digest.as_mut_ptr(),
201+
),
202+
HASH_TYPE_SHA512 => wc_Sha512Hash(
203+
message.as_ptr(),
204+
message.len() as u32,
205+
digest.as_mut_ptr(),
206+
),
207+
_ => -1,
208+
}
209+
};
210+
check_if_zero(ret)
211+
.map_err(|_| rustls::Error::General("Failed to hash for PSS".into()))?;
212+
213+
// 3) Sign with wc_RsaPSS_Sign
214+
let ret = unsafe {
215+
wc_RsaPSS_Sign(
216+
digest.as_ptr(),
217+
digest_len,
218+
sig_buf.as_mut_ptr(),
219+
sig_buf.len() as u32,
220+
hash_ty,
221+
mgf_ty.try_into().unwrap(),
222+
rsa_key_object.as_ptr(),
223+
rng_object.as_ptr(),
224+
)
225+
};
226+
check_if_greater_than_zero(ret)
227+
.map_err(|_| rustls::Error::General("wc_RsaPSS_Sign failed".into()))?;
228+
229+
let sig_len = ret as usize;
230+
let mut sig_vec = sig_buf.to_vec();
231+
sig_vec.truncate(sig_len);
232+
Ok(sig_vec)
233+
}
234+
235+
// ------------------------------------------------
236+
// RSA-PKCS#1 branch:
237+
// ------------------------------------------------
238+
SignatureScheme::RSA_PKCS1_SHA256
239+
| SignatureScheme::RSA_PKCS1_SHA384
240+
| SignatureScheme::RSA_PKCS1_SHA512 => {
241+
// We'll let wc_SignatureGenerate do the hashing & PKCS#1.
242+
let hash_ty = match self.scheme {
243+
SignatureScheme::RSA_PKCS1_SHA256 => HASH_TYPE_SHA256,
244+
SignatureScheme::RSA_PKCS1_SHA384 => HASH_TYPE_SHA384,
245+
SignatureScheme::RSA_PKCS1_SHA512 => HASH_TYPE_SHA512,
246+
_ => unreachable!(),
247+
};
248+
249+
let mut sig_len: u32 = sig_buf.len() as u32;
250+
251+
// wc_SignatureGenerate will produce a PKCS#1 signature, including hashing.
252+
let deref_rsa_key_c_type = unsafe { *(rsa_key_object.as_ptr()) };
253+
let ret = unsafe {
254+
wc_SignatureGenerate(
255+
hash_ty,
256+
wc_SignatureType_WC_SIGNATURE_TYPE_RSA_W_ENC,
257+
message.as_ptr(),
258+
message.len() as u32,
259+
sig_buf.as_mut_ptr(),
260+
&mut sig_len,
261+
rsa_key_object.as_ptr() as *const core::ffi::c_void,
262+
mem::size_of_val(&deref_rsa_key_c_type).try_into().unwrap(),
263+
rng_object.as_ptr(),
264+
)
265+
};
266+
check_if_zero(ret)
267+
.map_err(|_| rustls::Error::General("wc_SignatureGenerate failed".into()))?;
268+
269+
// Check how big the actual signature is
270+
let actual_sig_size = unsafe {
271+
wc_SignatureGetSize(
272+
wc_SignatureType_WC_SIGNATURE_TYPE_RSA_W_ENC,
273+
rsa_key_object.as_ptr() as *const core::ffi::c_void,
274+
mem::size_of_val(&deref_rsa_key_c_type).try_into().unwrap(),
275+
)
276+
};
277+
278+
let mut sig_vec = sig_buf.to_vec();
279+
// Truncate to the size returned by wc_SignatureGetSize or the updated `sig_len`.
280+
let min_len = core::cmp::min(actual_sig_size as usize, sig_len as usize);
281+
sig_vec.truncate(min_len);
282+
283+
Ok(sig_vec)
284+
}
285+
286+
// If someone tries a scheme that isn't RSA...
287+
_ => Err(rustls::Error::General("Unsupported RSA scheme".into())),
288+
}
289+
}
290+
291+
fn scheme(&self) -> SignatureScheme {
292+
self.scheme
293+
}
294+
}

0 commit comments

Comments
 (0)