@@ -38,15 +38,16 @@ mod _ssl {
3838 } ,
3939 socket:: { self , PySocket } ,
4040 vm:: {
41- PyObjectRef , PyPayload , PyRef , PyResult , VirtualMachine ,
41+ Py , PyObjectRef , PyPayload , PyRef , PyResult , VirtualMachine ,
4242 builtins:: { PyBaseExceptionRef , PyStrRef , PyType , PyTypeRef , PyWeak } ,
43+ class_or_notimplemented,
4344 convert:: { ToPyException , ToPyObject } ,
4445 exceptions,
4546 function:: {
4647 ArgBytesLike , ArgCallable , ArgMemoryBuffer , ArgStrOrBytesLike , Either , FsPath ,
47- OptionalArg ,
48+ OptionalArg , PyComparisonValue ,
4849 } ,
49- types:: Constructor ,
50+ types:: { Comparable , Constructor , PyComparisonOp } ,
5051 utils:: ToCString ,
5152 } ,
5253 } ;
@@ -816,16 +817,22 @@ mod _ssl {
816817 let stream = ssl:: SslStream :: new ( ssl, SocketStream ( args. sock . clone ( ) ) )
817818 . map_err ( |e| convert_openssl_error ( vm, e) ) ?;
818819
819- // TODO: use this
820- let _ = args. session ;
821-
822- Ok ( PySslSocket {
820+ let py_ssl_socket = PySslSocket {
823821 ctx : zelf,
824822 stream : PyRwLock :: new ( stream) ,
825823 socket_type,
826824 server_hostname : args. server_hostname ,
827825 owner : PyRwLock :: new ( args. owner . map ( |o| o. downgrade ( None , vm) ) . transpose ( ) ?) ,
828- } )
826+ } ;
827+
828+ // Set session if provided
829+ if let Some ( session) = args. session {
830+ if !vm. is_none ( & session) {
831+ py_ssl_socket. set_session ( session, vm) ?;
832+ }
833+ }
834+
835+ Ok ( py_ssl_socket)
829836 }
830837 }
831838
@@ -1103,6 +1110,73 @@ mod _ssl {
11031110 }
11041111 }
11051112
1113+ #[ pygetset]
1114+ fn session ( & self , _vm : & VirtualMachine ) -> PyResult < Option < PySslSession > > {
1115+ let stream = self . stream . read ( ) ;
1116+ unsafe {
1117+ let session_ptr = sys:: SSL_get_session ( stream. ssl ( ) . as_ptr ( ) ) ;
1118+ if session_ptr. is_null ( ) {
1119+ Ok ( None )
1120+ } else {
1121+ // Increment reference count since SSL_get_session returns a borrowed reference
1122+ #[ cfg( ossl110) ]
1123+ let _session = sys:: SSL_SESSION_up_ref ( session_ptr) ;
1124+
1125+ Ok ( Some ( PySslSession {
1126+ session : session_ptr,
1127+ ctx : self . ctx . clone ( ) ,
1128+ } ) )
1129+ }
1130+ }
1131+ }
1132+
1133+ #[ pygetset( setter) ]
1134+ fn set_session ( & self , value : PyObjectRef , vm : & VirtualMachine ) -> PyResult < ( ) > {
1135+ // Check if value is SSLSession type
1136+ let session = value
1137+ . downcast_ref :: < PySslSession > ( )
1138+ . ok_or_else ( || vm. new_type_error ( "Value is not a SSLSession." . to_owned ( ) ) ) ?;
1139+
1140+ // Check if session refers to the same SSLContext
1141+ if !std:: ptr:: eq (
1142+ self . ctx . ctx . read ( ) . as_ptr ( ) ,
1143+ session. ctx . ctx . read ( ) . as_ptr ( ) ,
1144+ ) {
1145+ return Err (
1146+ vm. new_value_error ( "Session refers to a different SSLContext." . to_owned ( ) )
1147+ ) ;
1148+ }
1149+
1150+ // Check if this is a client socket
1151+ if self . socket_type != SslServerOrClient :: Client {
1152+ return Err (
1153+ vm. new_value_error ( "Cannot set session for server-side SSLSocket." . to_owned ( ) )
1154+ ) ;
1155+ }
1156+
1157+ // Check if handshake is not finished
1158+ let stream = self . stream . read ( ) ;
1159+ unsafe {
1160+ if sys:: SSL_is_init_finished ( stream. ssl ( ) . as_ptr ( ) ) != 0 {
1161+ return Err (
1162+ vm. new_value_error ( "Cannot set session after handshake." . to_owned ( ) )
1163+ ) ;
1164+ }
1165+
1166+ if sys:: SSL_set_session ( stream. ssl ( ) . as_ptr ( ) , session. session ) == 0 {
1167+ return Err ( convert_openssl_error ( vm, ErrorStack :: get ( ) ) ) ;
1168+ }
1169+ }
1170+
1171+ Ok ( ( ) )
1172+ }
1173+
1174+ #[ pygetset]
1175+ fn session_reused ( & self ) -> bool {
1176+ let stream = self . stream . read ( ) ;
1177+ unsafe { sys:: SSL_session_reused ( stream. ssl ( ) . as_ptr ( ) ) != 0 }
1178+ }
1179+
11061180 #[ pymethod]
11071181 fn read (
11081182 & self ,
@@ -1164,6 +1238,132 @@ mod _ssl {
11641238 }
11651239 }
11661240
1241+ #[ pyattr]
1242+ #[ pyclass( module = "ssl" , name = "SSLSession" ) ]
1243+ #[ derive( PyPayload ) ]
1244+ struct PySslSession {
1245+ session : * mut sys:: SSL_SESSION ,
1246+ ctx : PyRef < PySslContext > ,
1247+ }
1248+
1249+ impl fmt:: Debug for PySslSession {
1250+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
1251+ f. pad ( "SSLSession" )
1252+ }
1253+ }
1254+
1255+ impl Drop for PySslSession {
1256+ fn drop ( & mut self ) {
1257+ if !self . session . is_null ( ) {
1258+ unsafe {
1259+ sys:: SSL_SESSION_free ( self . session ) ;
1260+ }
1261+ }
1262+ }
1263+ }
1264+
1265+ unsafe impl Send for PySslSession { }
1266+ unsafe impl Sync for PySslSession { }
1267+
1268+ impl Comparable for PySslSession {
1269+ fn cmp (
1270+ zelf : & Py < Self > ,
1271+ other : & crate :: vm:: PyObject ,
1272+ op : PyComparisonOp ,
1273+ _vm : & VirtualMachine ,
1274+ ) -> PyResult < PyComparisonValue > {
1275+ let other = class_or_notimplemented ! ( Self , other) ;
1276+
1277+ if !matches ! ( op, PyComparisonOp :: Eq | PyComparisonOp :: Ne ) {
1278+ return Ok ( PyComparisonValue :: NotImplemented ) ;
1279+ }
1280+ let mut eq = unsafe {
1281+ let mut self_len: libc:: c_uint = 0 ;
1282+ let mut other_len: libc:: c_uint = 0 ;
1283+ let self_id = sys:: SSL_SESSION_get_id ( zelf. session , & mut self_len) ;
1284+ let other_id = sys:: SSL_SESSION_get_id ( other. session , & mut other_len) ;
1285+
1286+ if self_len != other_len {
1287+ false
1288+ } else {
1289+ let self_slice = std:: slice:: from_raw_parts ( self_id, self_len as usize ) ;
1290+ let other_slice = std:: slice:: from_raw_parts ( other_id, other_len as usize ) ;
1291+ self_slice == other_slice
1292+ }
1293+ } ;
1294+ if matches ! ( op, PyComparisonOp :: Ne ) {
1295+ eq = !eq;
1296+ }
1297+ Ok ( PyComparisonValue :: Implemented ( eq) )
1298+ }
1299+ }
1300+
1301+ #[ pyclass( with( Comparable ) ) ]
1302+ impl PySslSession {
1303+ #[ pygetset]
1304+ fn time ( & self ) -> i64 {
1305+ unsafe {
1306+ #[ cfg( ossl330) ]
1307+ {
1308+ sys:: SSL_SESSION_get_time ( self . session ) as i64
1309+ }
1310+ #[ cfg( not( ossl330) ) ]
1311+ {
1312+ sys:: SSL_SESSION_get_time ( self . session ) as i64
1313+ }
1314+ }
1315+ }
1316+
1317+ #[ pygetset]
1318+ fn timeout ( & self ) -> i64 {
1319+ unsafe { sys:: SSL_SESSION_get_timeout ( self . session ) as i64 }
1320+ }
1321+
1322+ #[ pygetset]
1323+ fn ticket_lifetime_hint ( & self ) -> u64 {
1324+ // SSL_SESSION_get_ticket_lifetime_hint may not be available in older OpenSSL
1325+ // Return 0 as default if not available
1326+ #[ cfg( ossl110) ]
1327+ {
1328+ // For now, return 0 as this function may not be in openssl-sys
1329+ let _ = self . session ;
1330+ 0
1331+ }
1332+ #[ cfg( not( ossl110) ) ]
1333+ {
1334+ let _ = self . session ;
1335+ 0
1336+ }
1337+ }
1338+
1339+ #[ pygetset]
1340+ fn id ( & self , vm : & VirtualMachine ) -> PyObjectRef {
1341+ unsafe {
1342+ let mut len: libc:: c_uint = 0 ;
1343+ let id_ptr = sys:: SSL_SESSION_get_id ( self . session , & mut len) ;
1344+ let id_slice = std:: slice:: from_raw_parts ( id_ptr, len as usize ) ;
1345+ vm. ctx . new_bytes ( id_slice. to_vec ( ) ) . into ( )
1346+ }
1347+ }
1348+
1349+ #[ pygetset]
1350+ fn has_ticket ( & self ) -> bool {
1351+ // SSL_SESSION_has_ticket may not be available in older OpenSSL
1352+ // Return false as default
1353+ #[ cfg( ossl110) ]
1354+ {
1355+ // For now, return false as this function may not be in openssl-sys
1356+ let _ = self . session ;
1357+ false
1358+ }
1359+ #[ cfg( not( ossl110) ) ]
1360+ {
1361+ let _ = self . session ;
1362+ false
1363+ }
1364+ }
1365+ }
1366+
11671367 #[ track_caller]
11681368 fn convert_openssl_error ( vm : & VirtualMachine , err : ErrorStack ) -> PyBaseExceptionRef {
11691369 let cls = ssl_error ( vm) ;
0 commit comments