Skip to content

Commit 2dbe343

Browse files
committed
fix set_ecdh_curve
1 parent 69cedf6 commit 2dbe343

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

stdlib/src/ssl.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -679,19 +679,29 @@ mod _ssl {
679679
}
680680

681681
#[pymethod]
682-
fn set_ecdh_curve(&self, name: PyStrRef, vm: &VirtualMachine) -> PyResult<()> {
682+
fn set_ecdh_curve(
683+
&self,
684+
name: Either<PyStrRef, ArgBytesLike>,
685+
vm: &VirtualMachine,
686+
) -> PyResult<()> {
683687
use openssl::ec::{EcGroup, EcKey};
684688

685-
let curve_name = name.as_str();
686-
if curve_name.contains('\0') {
687-
return Err(exceptions::cstring_error(vm));
688-
}
689+
// Convert name to CString, supporting both str and bytes
690+
let name_cstr = match name {
691+
Either::A(s) => {
692+
if s.as_str().contains('\0') {
693+
return Err(exceptions::cstring_error(vm));
694+
}
695+
s.to_cstring(vm)?
696+
}
697+
Either::B(b) => std::ffi::CString::new(b.borrow_buf().to_vec())
698+
.map_err(|_| exceptions::cstring_error(vm))?,
699+
};
689700

690701
// Find the NID for the curve name using OBJ_sn2nid
691-
let name_cstr = name.to_cstring(vm)?;
692702
let nid_raw = unsafe { sys::OBJ_sn2nid(name_cstr.as_ptr()) };
693703
if nid_raw == 0 {
694-
return Err(vm.new_value_error(format!("unknown curve name: {}", curve_name)));
704+
return Err(vm.new_value_error("unknown curve name"));
695705
}
696706
let nid = Nid::from_raw(nid_raw);
697707

0 commit comments

Comments
 (0)