Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Lib/test/test_httplib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,8 +1780,6 @@ def test_networked_bad_cert(self):
h.request('GET', '/')
self.assertEqual(exc_info.exception.reason, 'CERTIFICATE_VERIFY_FAILED')

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.skipIf(sys.platform == 'darwin', 'Occasionally success on macOS')
def test_local_unknown_cert(self):
# The custom cert isn't known to the default trust bundle
Expand Down
185 changes: 133 additions & 52 deletions stdlib/src/ssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ mod _ssl {
socket::{self, PySocket},
vm::{
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{
PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef, PyWeak,
},
builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyTypeRef, PyWeak},
class_or_notimplemented,
convert::{ToPyException, ToPyObject},
exceptions,
Expand All @@ -68,7 +66,8 @@ mod _ssl {
ffi::CStr,
fmt,
io::{Read, Write},
path::Path,
path::{Path, PathBuf},
sync::LazyLock,
time::Instant,
};

Expand All @@ -91,6 +90,7 @@ mod _ssl {
// X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT,
SSL_ERROR_ZERO_RETURN,
SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE,
SSL_OP_ENABLE_MIDDLEBOX_COMPAT as OP_ENABLE_MIDDLEBOX_COMPAT,
SSL_OP_LEGACY_SERVER_CONNECT as OP_LEGACY_SERVER_CONNECT,
SSL_OP_NO_SSLv2 as OP_NO_SSLv2,
SSL_OP_NO_SSLv3 as OP_NO_SSLv3,
Expand Down Expand Up @@ -193,7 +193,8 @@ mod _ssl {

#[pyattr(name = "_OPENSSL_API_VERSION")]
fn _openssl_api_version(_vm: &VirtualMachine) -> OpensslVersionInfo {
let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16).unwrap();
let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16)
.expect("OPENSSL_API_VERSION is malformed");
parse_version_info(openssl_api_version)
}

Expand Down Expand Up @@ -251,7 +252,8 @@ mod _ssl {
/// SSL/TLS connection terminated abruptly.
#[pyattr(name = "SSLEOFError", once)]
fn ssl_eof_error(vm: &VirtualMachine) -> PyTypeRef {
PyType::new_simple_heap("ssl.SSLEOFError", &ssl_error(vm), &vm.ctx).unwrap()
vm.ctx
.new_exception_type("ssl", "SSLEOFError", Some(vec![ssl_error(vm)]))
}

type OpensslVersionInfo = (u8, u8, u8, u8, u8);
Expand Down Expand Up @@ -352,14 +354,17 @@ mod _ssl {
}

type PyNid = (libc::c_int, String, String, Option<String>);
fn obj2py(obj: &Asn1ObjectRef) -> PyNid {
fn obj2py(obj: &Asn1ObjectRef, vm: &VirtualMachine) -> PyResult<PyNid> {
let nid = obj.nid();
(
nid.as_raw(),
nid.short_name().unwrap().to_owned(),
nid.long_name().unwrap().to_owned(),
obj2txt(obj, true),
)
let short_name = nid
.short_name()
.map_err(|_| vm.new_value_error("NID has no short name".to_owned()))?
.to_owned();
let long_name = nid
.long_name()
.map_err(|_| vm.new_value_error("NID has no long name".to_owned()))?
.to_owned();
Ok((nid.as_raw(), short_name, long_name, obj2txt(obj, true)))
}

#[derive(FromArgs)]
Expand All @@ -373,55 +378,81 @@ mod _ssl {
fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult<PyNid> {
_txt2obj(&args.txt.to_cstring(vm)?, !args.name)
.as_deref()
.map(obj2py)
.ok_or_else(|| vm.new_value_error(format!("unknown object '{}'", args.txt)))
.and_then(|obj| obj2py(obj, vm))
}

#[pyfunction]
fn nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult<PyNid> {
_nid2obj(Nid::from_raw(nid))
.as_deref()
.map(obj2py)
.ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}")))
.and_then(|obj| obj2py(obj, vm))
}

fn get_cert_file_dir() -> (&'static Path, &'static Path) {
let probe = probe();
// on windows, these should be utf8 strings
fn path_from_bytes(c: &CStr) -> &Path {
// Lazily compute and cache cert file/dir paths
static CERT_PATHS: LazyLock<(PathBuf, PathBuf)> = LazyLock::new(|| {
fn path_from_cstr(c: &CStr) -> PathBuf {
#[cfg(unix)]
{
use std::os::unix::ffi::OsStrExt;
std::ffi::OsStr::from_bytes(c.to_bytes()).as_ref()
std::ffi::OsStr::from_bytes(c.to_bytes()).into()
}
#[cfg(windows)]
{
c.to_str().unwrap().as_ref()
// Use lossy conversion for potential non-UTF8
PathBuf::from(c.to_string_lossy().as_ref())
}
}
let cert_file = probe.cert_file.as_deref().unwrap_or_else(|| {
path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) })
});
let cert_dir = probe.cert_dir.as_deref().unwrap_or_else(|| {
path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) })
});

let probe = probe();
let cert_file = probe
.cert_file
.as_ref()
.map(PathBuf::from)
.unwrap_or_else(|| {
path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) })
});
let cert_dir = probe
.cert_dir
.as_ref()
.map(PathBuf::from)
.unwrap_or_else(|| {
path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) })
});
(cert_file, cert_dir)
});

fn get_cert_file_dir() -> (&'static Path, &'static Path) {
let (cert_file, cert_dir) = &*CERT_PATHS;
(cert_file.as_path(), cert_dir.as_path())
}

// Lazily compute and cache cert environment variable names
static CERT_ENV_NAMES: LazyLock<(String, String)> = LazyLock::new(|| {
let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) }
.to_string_lossy()
.into_owned();
let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) }
.to_string_lossy()
.into_owned();
(cert_file_env, cert_dir_env)
});

#[pyfunction]
fn get_default_verify_paths(
vm: &VirtualMachine,
) -> PyResult<(&'static str, PyObjectRef, &'static str, PyObjectRef)> {
let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) }
.to_str()
.unwrap();
let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) }
.to_str()
.unwrap();
let (cert_file_env, cert_dir_env) = &*CERT_ENV_NAMES;
let (cert_file, cert_dir) = get_cert_file_dir();
let cert_file = OsPath::new_str(cert_file).filename(vm);
let cert_dir = OsPath::new_str(cert_dir).filename(vm);
Ok((cert_file_env, cert_file, cert_dir_env, cert_dir))
Ok((
cert_file_env.as_str(),
cert_file,
cert_dir_env.as_str(),
cert_dir,
))
}

#[pyfunction(name = "RAND_status")]
Expand Down Expand Up @@ -522,6 +553,7 @@ mod _ssl {
options |= SslOptions::CIPHER_SERVER_PREFERENCE;
options |= SslOptions::SINGLE_DH_USE;
options |= SslOptions::SINGLE_ECDH_USE;
options |= SslOptions::ENABLE_MIDDLEBOX_COMPAT;
builder.set_options(options);

let mode = ssl::SslMode::ACCEPT_MOVING_WRITE_BUFFER | ssl::SslMode::AUTO_RETRY;
Expand All @@ -536,6 +568,13 @@ mod _ssl {
.set_session_id_context(b"Python")
.map_err(|e| convert_openssl_error(vm, e))?;

// Set default verify flags: VERIFY_X509_TRUSTED_FIRST
unsafe {
let ctx_ptr = builder.as_ptr();
let param = sys::SSL_CTX_get0_param(ctx_ptr);
sys::X509_VERIFY_PARAM_set_flags(param, sys::X509_V_FLAG_TRUSTED_FIRST);
}

PySslContext {
ctx: PyRwLock::new(builder),
check_hostname: AtomicCell::new(check_hostname),
Expand Down Expand Up @@ -846,8 +885,16 @@ mod _ssl {
let certs = ctx.cert_store().all_certificates();
#[cfg(not(ossl300))]
let certs = ctx.cert_store().objects().iter().filter_map(|x| x.x509());

// Filter to only include CA certificates (Basic Constraints: CA=TRUE)
let certs = certs
.into_iter()
.filter(|cert| {
unsafe {
// X509_check_ca() returns 1 for CA certificates
X509_check_ca(cert.as_ptr()) == 1
}
})
.map(|ref cert| cert_to_py(vm, cert, binary_form))
.collect::<Result<Vec<_>, _>>()?;
Ok(certs)
Expand Down Expand Up @@ -884,6 +931,20 @@ mod _ssl {
args: WrapSocketArgs,
vm: &VirtualMachine,
) -> PyResult<PySslSocket> {
// validate socket type and context protocol
if !args.server_side && zelf.protocol == SslVersion::TlsServer {
return Err(vm.new_exception_msg(
ssl_error(vm),
"Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(),
));
}
if args.server_side && zelf.protocol == SslVersion::TlsClient {
return Err(vm.new_exception_msg(
ssl_error(vm),
"Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(),
));
}

let mut ssl = ssl::Ssl::new(&zelf.ctx()).map_err(|e| convert_openssl_error(vm, e))?;

let socket_type = if args.server_side {
Expand Down Expand Up @@ -1681,6 +1742,12 @@ mod _ssl {
unsafe impl Sync for PySslMemoryBio {}

// OpenSSL functions not in openssl-sys

unsafe extern "C" {
// X509_check_ca returns 1 for CA certificates, 0 otherwise
fn X509_check_ca(x: *const sys::X509) -> libc::c_int;
}

unsafe extern "C" {
fn SSL_get_ciphers(ssl: *const sys::SSL) -> *const sys::stack_st_SSL_CIPHER;
}
Expand Down Expand Up @@ -1857,12 +1924,12 @@ mod _ssl {
}

#[pygetset]
fn id(&self, vm: &VirtualMachine) -> PyObjectRef {
fn id(&self, vm: &VirtualMachine) -> PyBytesRef {
unsafe {
let mut len: libc::c_uint = 0;
let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len);
let id_slice = std::slice::from_raw_parts(id_ptr, len as usize);
vm.ctx.new_bytes(id_slice.to_vec()).into()
vm.ctx.new_bytes(id_slice.to_vec())
}
}

Expand Down Expand Up @@ -1900,23 +1967,39 @@ mod _ssl {
"certificate verify failed" => "CERTIFICATE_VERIFY_FAILED",
_ => default_errstr,
};
let msg = if let Some(lib) = e.library() {
// add `library` attribute
let attr_name = vm.ctx.as_ref().intern_str("library");
cls.set_attr(attr_name, vm.ctx.new_str(lib).into());

// Build message
let lib_obj = e.library();
let msg = if let Some(lib) = lib_obj {
format!("[{lib}] {errstr} ({file}:{line})")
} else {
format!("{errstr} ({file}:{line})")
};
// add `reason` attribute
let attr_name = vm.ctx.as_ref().intern_str("reason");
cls.set_attr(attr_name, vm.ctx.new_str(errstr).into());

// Create exception instance
let reason = sys::ERR_GET_REASON(e.code());
vm.new_exception(
let exc = vm.new_exception(
cls,
vec![vm.ctx.new_int(reason).into(), vm.ctx.new_str(msg).into()],
)
);

// Set attributes on instance, not class
let exc_obj: PyObjectRef = exc.into();

// Set reason attribute (always set, even if just the error string)
let reason_value = vm.ctx.new_str(errstr);
let _ = exc_obj.set_attr("reason", reason_value, vm);

// Set library attribute (None if not available)
let library_value: PyObjectRef = if let Some(lib) = lib_obj {
vm.ctx.new_str(lib).into()
} else {
vm.ctx.none()
};
let _ = exc_obj.set_attr("library", library_value, vm);

// Convert back to PyBaseExceptionRef
exc_obj.downcast().unwrap()
}
None => vm.new_exception_empty(cls),
}
Expand Down Expand Up @@ -2013,7 +2096,8 @@ mod _ssl {

dict.set_item("subject", name_to_py(cert.subject_name())?, vm)?;
dict.set_item("issuer", name_to_py(cert.issuer_name())?, vm)?;
dict.set_item("version", vm.new_pyobj(cert.version()), vm)?;
// X.509 version: OpenSSL uses 0-based (0=v1, 1=v2, 2=v3) but Python uses 1-based (1=v1, 2=v2, 3=v3)
dict.set_item("version", vm.new_pyobj(cert.version() + 1), vm)?;

let serial_num = cert
.serial_number()
Expand Down Expand Up @@ -2226,21 +2310,18 @@ mod windows {
Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")),
other => vm.new_pyobj(other),
};
let usage: PyObjectRef = match c.valid_uses()? {
let usage: PyObjectRef = match c.valid_uses().map_err(|e| e.to_pyexception(vm))? {
ValidUses::All => vm.ctx.new_bool(true).into(),
ValidUses::Oids(oids) => PyFrozenSet::from_iter(
vm,
oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()),
)
.unwrap()
)?
.into_ref(&vm.ctx)
.into(),
};
Ok(vm.new_tuple((cert, enc_type, usage)).into())
});
let certs = certs
.collect::<Result<Vec<_>, _>>()
.map_err(|e: std::io::Error| e.to_pyexception(vm))?;
let certs: Vec<PyObjectRef> = certs.collect::<PyResult<Vec<_>>>()?;
Ok(certs)
}
}
Expand Down
Loading