Skip to content

Commit 4bfd8c6

Browse files
committed
SSLSession
1 parent 153d0ee commit 4bfd8c6

File tree

1 file changed

+208
-8
lines changed

1 file changed

+208
-8
lines changed

stdlib/src/ssl.rs

Lines changed: 208 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,16 @@ mod _ssl {
3838
},
3939
socket::{self, PySocket},
4040
vm::{
41-
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
41+
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
4242
builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak},
43+
class_or_notimplemented,
4344
convert::{ToPyException, ToPyObject},
4445
exceptions,
4546
function::{
4647
ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath,
47-
OptionalArg,
48+
OptionalArg, PyComparisonValue,
4849
},
49-
types::Constructor,
50+
types::{Comparable, Constructor, PyComparisonOp},
5051
utils::ToCString,
5152
},
5253
};
@@ -816,16 +817,22 @@ mod _ssl {
816817
let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone()))
817818
.map_err(|e| convert_openssl_error(vm, e))?;
818819

819-
// TODO: use this
820-
let _ = args.session;
821-
822-
Ok(PySslSocket {
820+
let py_ssl_socket = PySslSocket {
823821
ctx: zelf,
824822
stream: PyRwLock::new(stream),
825823
socket_type,
826824
server_hostname: args.server_hostname,
827825
owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?),
828-
})
826+
};
827+
828+
// Set session if provided
829+
if let Some(session) = args.session {
830+
if !vm.is_none(&session) {
831+
py_ssl_socket.set_session(session, vm)?;
832+
}
833+
}
834+
835+
Ok(py_ssl_socket)
829836
}
830837
}
831838

@@ -1103,6 +1110,73 @@ mod _ssl {
11031110
}
11041111
}
11051112

1113+
#[pygetset]
1114+
fn session(&self, _vm: &VirtualMachine) -> PyResult<Option<PySslSession>> {
1115+
let stream = self.stream.read();
1116+
unsafe {
1117+
let session_ptr = sys::SSL_get_session(stream.ssl().as_ptr());
1118+
if session_ptr.is_null() {
1119+
Ok(None)
1120+
} else {
1121+
// Increment reference count since SSL_get_session returns a borrowed reference
1122+
#[cfg(ossl110)]
1123+
let _session = sys::SSL_SESSION_up_ref(session_ptr);
1124+
1125+
Ok(Some(PySslSession {
1126+
session: session_ptr,
1127+
ctx: self.ctx.clone(),
1128+
}))
1129+
}
1130+
}
1131+
}
1132+
1133+
#[pygetset(setter)]
1134+
fn set_session(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
1135+
// Check if value is SSLSession type
1136+
let session = value
1137+
.downcast_ref::<PySslSession>()
1138+
.ok_or_else(|| vm.new_type_error("Value is not a SSLSession.".to_owned()))?;
1139+
1140+
// Check if session refers to the same SSLContext
1141+
if !std::ptr::eq(
1142+
self.ctx.ctx.read().as_ptr(),
1143+
session.ctx.ctx.read().as_ptr(),
1144+
) {
1145+
return Err(
1146+
vm.new_value_error("Session refers to a different SSLContext.".to_owned())
1147+
);
1148+
}
1149+
1150+
// Check if this is a client socket
1151+
if self.socket_type != SslServerOrClient::Client {
1152+
return Err(
1153+
vm.new_value_error("Cannot set session for server-side SSLSocket.".to_owned())
1154+
);
1155+
}
1156+
1157+
// Check if handshake is not finished
1158+
let stream = self.stream.read();
1159+
unsafe {
1160+
if sys::SSL_is_init_finished(stream.ssl().as_ptr()) != 0 {
1161+
return Err(
1162+
vm.new_value_error("Cannot set session after handshake.".to_owned())
1163+
);
1164+
}
1165+
1166+
if sys::SSL_set_session(stream.ssl().as_ptr(), session.session) == 0 {
1167+
return Err(convert_openssl_error(vm, ErrorStack::get()));
1168+
}
1169+
}
1170+
1171+
Ok(())
1172+
}
1173+
1174+
#[pygetset]
1175+
fn session_reused(&self) -> bool {
1176+
let stream = self.stream.read();
1177+
unsafe { sys::SSL_session_reused(stream.ssl().as_ptr()) != 0 }
1178+
}
1179+
11061180
#[pymethod]
11071181
fn read(
11081182
&self,
@@ -1164,6 +1238,132 @@ mod _ssl {
11641238
}
11651239
}
11661240

1241+
#[pyattr]
1242+
#[pyclass(module = "ssl", name = "SSLSession")]
1243+
#[derive(PyPayload)]
1244+
struct PySslSession {
1245+
session: *mut sys::SSL_SESSION,
1246+
ctx: PyRef<PySslContext>,
1247+
}
1248+
1249+
impl fmt::Debug for PySslSession {
1250+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1251+
f.pad("SSLSession")
1252+
}
1253+
}
1254+
1255+
impl Drop for PySslSession {
1256+
fn drop(&mut self) {
1257+
if !self.session.is_null() {
1258+
unsafe {
1259+
sys::SSL_SESSION_free(self.session);
1260+
}
1261+
}
1262+
}
1263+
}
1264+
1265+
unsafe impl Send for PySslSession {}
1266+
unsafe impl Sync for PySslSession {}
1267+
1268+
impl Comparable for PySslSession {
1269+
fn cmp(
1270+
zelf: &Py<Self>,
1271+
other: &crate::vm::PyObject,
1272+
op: PyComparisonOp,
1273+
_vm: &VirtualMachine,
1274+
) -> PyResult<PyComparisonValue> {
1275+
let other = class_or_notimplemented!(Self, other);
1276+
1277+
if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) {
1278+
return Ok(PyComparisonValue::NotImplemented);
1279+
}
1280+
let mut eq = unsafe {
1281+
let mut self_len: libc::c_uint = 0;
1282+
let mut other_len: libc::c_uint = 0;
1283+
let self_id = sys::SSL_SESSION_get_id(zelf.session, &mut self_len);
1284+
let other_id = sys::SSL_SESSION_get_id(other.session, &mut other_len);
1285+
1286+
if self_len != other_len {
1287+
false
1288+
} else {
1289+
let self_slice = std::slice::from_raw_parts(self_id, self_len as usize);
1290+
let other_slice = std::slice::from_raw_parts(other_id, other_len as usize);
1291+
self_slice == other_slice
1292+
}
1293+
};
1294+
if matches!(op, PyComparisonOp::Ne) {
1295+
eq = !eq;
1296+
}
1297+
Ok(PyComparisonValue::Implemented(eq))
1298+
}
1299+
}
1300+
1301+
#[pyclass(with(Comparable))]
1302+
impl PySslSession {
1303+
#[pygetset]
1304+
fn time(&self) -> i64 {
1305+
unsafe {
1306+
#[cfg(ossl330)]
1307+
{
1308+
sys::SSL_SESSION_get_time(self.session) as i64
1309+
}
1310+
#[cfg(not(ossl330))]
1311+
{
1312+
sys::SSL_SESSION_get_time(self.session) as i64
1313+
}
1314+
}
1315+
}
1316+
1317+
#[pygetset]
1318+
fn timeout(&self) -> i64 {
1319+
unsafe { sys::SSL_SESSION_get_timeout(self.session) as i64 }
1320+
}
1321+
1322+
#[pygetset]
1323+
fn ticket_lifetime_hint(&self) -> u64 {
1324+
// SSL_SESSION_get_ticket_lifetime_hint may not be available in older OpenSSL
1325+
// Return 0 as default if not available
1326+
#[cfg(ossl110)]
1327+
{
1328+
// For now, return 0 as this function may not be in openssl-sys
1329+
let _ = self.session;
1330+
0
1331+
}
1332+
#[cfg(not(ossl110))]
1333+
{
1334+
let _ = self.session;
1335+
0
1336+
}
1337+
}
1338+
1339+
#[pygetset]
1340+
fn id(&self, vm: &VirtualMachine) -> PyObjectRef {
1341+
unsafe {
1342+
let mut len: libc::c_uint = 0;
1343+
let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len);
1344+
let id_slice = std::slice::from_raw_parts(id_ptr, len as usize);
1345+
vm.ctx.new_bytes(id_slice.to_vec()).into()
1346+
}
1347+
}
1348+
1349+
#[pygetset]
1350+
fn has_ticket(&self) -> bool {
1351+
// SSL_SESSION_has_ticket may not be available in older OpenSSL
1352+
// Return false as default
1353+
#[cfg(ossl110)]
1354+
{
1355+
// For now, return false as this function may not be in openssl-sys
1356+
let _ = self.session;
1357+
false
1358+
}
1359+
#[cfg(not(ossl110))]
1360+
{
1361+
let _ = self.session;
1362+
false
1363+
}
1364+
}
1365+
}
1366+
11671367
#[track_caller]
11681368
fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef {
11691369
let cls = ssl_error(vm);

0 commit comments

Comments
 (0)