Skip to content

Commit fa54752

Browse files
committed
clean up ssl
1 parent 9ddd43c commit fa54752

File tree

1 file changed

+68
-39
lines changed

1 file changed

+68
-39
lines changed

stdlib/src/ssl.rs

Lines changed: 68 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ mod _ssl {
6868
ffi::CStr,
6969
fmt,
7070
io::{Read, Write},
71-
path::Path,
71+
path::{Path, PathBuf},
72+
sync::LazyLock,
7273
time::Instant,
7374
};
7475

@@ -193,7 +194,8 @@ mod _ssl {
193194

194195
#[pyattr(name = "_OPENSSL_API_VERSION")]
195196
fn _openssl_api_version(_vm: &VirtualMachine) -> OpensslVersionInfo {
196-
let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16).unwrap();
197+
let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16)
198+
.expect("OPENSSL_API_VERSION is malformed");
197199
parse_version_info(openssl_api_version)
198200
}
199201

@@ -251,7 +253,8 @@ mod _ssl {
251253
/// SSL/TLS connection terminated abruptly.
252254
#[pyattr(name = "SSLEOFError", once)]
253255
fn ssl_eof_error(vm: &VirtualMachine) -> PyTypeRef {
254-
PyType::new_simple_heap("ssl.SSLEOFError", &ssl_error(vm), &vm.ctx).unwrap()
256+
vm.ctx
257+
.new_exception_type("ssl", "SSLEOFError", Some(vec![ssl_error(vm)]))
255258
}
256259

257260
type OpensslVersionInfo = (u8, u8, u8, u8, u8);
@@ -352,14 +355,17 @@ mod _ssl {
352355
}
353356

354357
type PyNid = (libc::c_int, String, String, Option<String>);
355-
fn obj2py(obj: &Asn1ObjectRef) -> PyNid {
358+
fn obj2py(obj: &Asn1ObjectRef, vm: &VirtualMachine) -> PyResult<PyNid> {
356359
let nid = obj.nid();
357-
(
358-
nid.as_raw(),
359-
nid.short_name().unwrap().to_owned(),
360-
nid.long_name().unwrap().to_owned(),
361-
obj2txt(obj, true),
362-
)
360+
let short_name = nid
361+
.short_name()
362+
.map_err(|_| vm.new_value_error("NID has no short name".to_owned()))?
363+
.to_owned();
364+
let long_name = nid
365+
.long_name()
366+
.map_err(|_| vm.new_value_error("NID has no long name".to_owned()))?
367+
.to_owned();
368+
Ok((nid.as_raw(), short_name, long_name, obj2txt(obj, true)))
363369
}
364370

365371
#[derive(FromArgs)]
@@ -373,55 +379,81 @@ mod _ssl {
373379
fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult<PyNid> {
374380
_txt2obj(&args.txt.to_cstring(vm)?, !args.name)
375381
.as_deref()
376-
.map(obj2py)
377382
.ok_or_else(|| vm.new_value_error(format!("unknown object '{}'", args.txt)))
383+
.and_then(|obj| obj2py(obj, vm))
378384
}
379385

380386
#[pyfunction]
381387
fn nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult<PyNid> {
382388
_nid2obj(Nid::from_raw(nid))
383389
.as_deref()
384-
.map(obj2py)
385390
.ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}")))
391+
.and_then(|obj| obj2py(obj, vm))
386392
}
387393

388-
fn get_cert_file_dir() -> (&'static Path, &'static Path) {
389-
let probe = probe();
390-
// on windows, these should be utf8 strings
391-
fn path_from_bytes(c: &CStr) -> &Path {
394+
// Lazily compute and cache cert file/dir paths
395+
static CERT_PATHS: LazyLock<(PathBuf, PathBuf)> = LazyLock::new(|| {
396+
fn path_from_cstr(c: &CStr) -> PathBuf {
392397
#[cfg(unix)]
393398
{
394399
use std::os::unix::ffi::OsStrExt;
395-
std::ffi::OsStr::from_bytes(c.to_bytes()).as_ref()
400+
std::ffi::OsStr::from_bytes(c.to_bytes()).into()
396401
}
397402
#[cfg(windows)]
398403
{
399-
c.to_str().unwrap().as_ref()
404+
// Use lossy conversion for potential non-UTF8
405+
PathBuf::from(c.to_string_lossy().as_ref())
400406
}
401407
}
402-
let cert_file = probe.cert_file.as_deref().unwrap_or_else(|| {
403-
path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) })
404-
});
405-
let cert_dir = probe.cert_dir.as_deref().unwrap_or_else(|| {
406-
path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) })
407-
});
408+
409+
let probe = probe();
410+
let cert_file = probe
411+
.cert_file
412+
.as_ref()
413+
.map(PathBuf::from)
414+
.unwrap_or_else(|| {
415+
path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) })
416+
});
417+
let cert_dir = probe
418+
.cert_dir
419+
.as_ref()
420+
.map(PathBuf::from)
421+
.unwrap_or_else(|| {
422+
path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) })
423+
});
408424
(cert_file, cert_dir)
425+
});
426+
427+
fn get_cert_file_dir() -> (&'static Path, &'static Path) {
428+
let (cert_file, cert_dir) = &*CERT_PATHS;
429+
(cert_file.as_path(), cert_dir.as_path())
409430
}
410431

432+
// Lazily compute and cache cert environment variable names
433+
static CERT_ENV_NAMES: LazyLock<(String, String)> = LazyLock::new(|| {
434+
let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) }
435+
.to_string_lossy()
436+
.into_owned();
437+
let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) }
438+
.to_string_lossy()
439+
.into_owned();
440+
(cert_file_env, cert_dir_env)
441+
});
442+
411443
#[pyfunction]
412444
fn get_default_verify_paths(
413445
vm: &VirtualMachine,
414446
) -> PyResult<(&'static str, PyObjectRef, &'static str, PyObjectRef)> {
415-
let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) }
416-
.to_str()
417-
.unwrap();
418-
let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) }
419-
.to_str()
420-
.unwrap();
447+
let (cert_file_env, cert_dir_env) = &*CERT_ENV_NAMES;
421448
let (cert_file, cert_dir) = get_cert_file_dir();
422449
let cert_file = OsPath::new_str(cert_file).filename(vm);
423450
let cert_dir = OsPath::new_str(cert_dir).filename(vm);
424-
Ok((cert_file_env, cert_file, cert_dir_env, cert_dir))
451+
Ok((
452+
cert_file_env.as_str(),
453+
cert_file,
454+
cert_dir_env.as_str(),
455+
cert_dir,
456+
))
425457
}
426458

427459
#[pyfunction(name = "RAND_status")]
@@ -1871,12 +1903,12 @@ mod _ssl {
18711903
}
18721904

18731905
#[pygetset]
1874-
fn id(&self, vm: &VirtualMachine) -> PyObjectRef {
1906+
fn id(&self, vm: &VirtualMachine) -> PyBytesRef {
18751907
unsafe {
18761908
let mut len: libc::c_uint = 0;
18771909
let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len);
18781910
let id_slice = std::slice::from_raw_parts(id_ptr, len as usize);
1879-
vm.ctx.new_bytes(id_slice.to_vec()).into()
1911+
vm.ctx.new_bytes(id_slice.to_vec())
18801912
}
18811913
}
18821914

@@ -2256,21 +2288,18 @@ mod windows {
22562288
Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")),
22572289
other => vm.new_pyobj(other),
22582290
};
2259-
let usage: PyObjectRef = match c.valid_uses()? {
2291+
let usage: PyObjectRef = match c.valid_uses().map_err(|e| e.to_pyexception(vm))? {
22602292
ValidUses::All => vm.ctx.new_bool(true).into(),
22612293
ValidUses::Oids(oids) => PyFrozenSet::from_iter(
22622294
vm,
22632295
oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()),
2264-
)
2265-
.unwrap()
2296+
)?
22662297
.into_ref(&vm.ctx)
22672298
.into(),
22682299
};
22692300
Ok(vm.new_tuple((cert, enc_type, usage)).into())
22702301
});
2271-
let certs = certs
2272-
.collect::<Result<Vec<_>, _>>()
2273-
.map_err(|e: std::io::Error| e.to_pyexception(vm))?;
2302+
let certs: Vec<PyObjectRef> = certs.collect::<PyResult<Vec<_>>>()?;
22742303
Ok(certs)
22752304
}
22762305
}

0 commit comments

Comments
 (0)