Skip to content

Commit db95946

Browse files
authored
Fix SSL deferred error (RustPython#6371)
* Fix SSL to return deferred error at the right time * release conn_guard * SslError::Io * is_connection_closed
1 parent a99164f commit db95946

File tree

3 files changed

+170
-134
lines changed

3 files changed

+170
-134
lines changed

crates/stdlib/src/ssl.rs

Lines changed: 117 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -2598,24 +2598,26 @@ mod _ssl {
25982598
fn complete_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
25992599
*self.handshake_done.lock() = true;
26002600

2601-
// Check if session was resumed before creating session object
2602-
let conn_guard = self.connection.lock();
2603-
if let Some(ref conn) = *conn_guard {
2604-
let was_resumed = conn.is_session_resumed();
2605-
*self.session_was_reused.lock() = was_resumed;
2601+
// Check if session was resumed - get value and release lock immediately
2602+
let was_resumed = self
2603+
.connection
2604+
.lock()
2605+
.as_ref()
2606+
.map(|conn| conn.is_session_resumed())
2607+
.unwrap_or(false);
26062608

2607-
// Update context session statistics if server-side
2608-
if self.server_side {
2609-
let context = self.context.read();
2610-
// Increment accept count for every successful server handshake
2611-
context.accept_count.fetch_add(1, Ordering::SeqCst);
2612-
// Increment hits count if session was resumed
2613-
if was_resumed {
2614-
context.session_hits.fetch_add(1, Ordering::SeqCst);
2615-
}
2609+
*self.session_was_reused.lock() = was_resumed;
2610+
2611+
// Update context session statistics if server-side
2612+
if self.server_side {
2613+
let context = self.context.read();
2614+
// Increment accept count for every successful server handshake
2615+
context.accept_count.fetch_add(1, Ordering::SeqCst);
2616+
// Increment hits count if session was resumed
2617+
if was_resumed {
2618+
context.session_hits.fetch_add(1, Ordering::SeqCst);
26162619
}
26172620
}
2618-
drop(conn_guard);
26192621

26202622
// Track CA certificate used during handshake (client-side only)
26212623
// This simulates lazy loading behavior for capath certificates
@@ -3209,62 +3211,46 @@ mod _ssl {
32093211
}
32103212

32113213
// Perform the actual handshake by exchanging data with the socket/BIO
3212-
match conn_guard.as_mut() {
3213-
Some(TlsConnection::Client(_conn)) => {
3214-
// CLIENT is simple - no SNI callback handling needed
3215-
ssl_do_handshake(conn_guard.as_mut().unwrap(), self, vm)
3216-
.map_err(|e| e.into_py_err(vm))?;
32173214

3218-
drop(conn_guard);
3219-
self.complete_handshake(vm)?;
3220-
Ok(())
3221-
}
3222-
Some(TlsConnection::Server(_conn)) => {
3223-
// Use OpenSSL-compatible handshake for server
3224-
// Handle SNI callback restart
3225-
match ssl_do_handshake(conn_guard.as_mut().unwrap(), self, vm) {
3226-
Ok(()) => {
3227-
// Handshake completed successfully
3228-
drop(conn_guard);
3229-
self.complete_handshake(vm)?;
3230-
Ok(())
3231-
}
3232-
Err(SslError::SniCallbackRestart) => {
3233-
// SNI detected - need to call callback and recreate connection
3234-
3235-
// CRITICAL: Drop connection lock BEFORE calling Python callback to avoid deadlock
3236-
//
3237-
// Deadlock scenario if we keep the lock:
3238-
// 1. This thread holds self.connection.lock()
3239-
// 2. Python callback invokes other SSL methods (e.g., getpeercert(), cipher())
3240-
// 3. Those methods try to acquire self.connection.lock() again
3241-
// 4. PyMutex (parking_lot::Mutex) is not reentrant -> DEADLOCK
3242-
//
3243-
// Trade-off: By dropping the lock, we lose the ability to send TLS alerts
3244-
// because Rustls doesn't provide a send_fatal_alert() API. See detailed
3245-
// explanation in invoke_sni_callback() where we set _reason attribute.
3246-
drop(conn_guard);
3247-
3248-
// Get the SNI name that was extracted (may be None if client didn't send SNI)
3249-
let sni_name = self.get_extracted_sni_name();
3250-
3251-
// Now safe to call Python callback (no locks held)
3252-
self.invoke_sni_callback(sni_name.as_deref(), vm)?;
3253-
3254-
// Clear connection to trigger recreation
3255-
*self.connection.lock() = None;
3256-
3257-
// Recursively call do_handshake to recreate with new context
3258-
self.do_handshake(vm)
3259-
}
3260-
Err(e) => {
3261-
// Other errors - convert to Python exception
3262-
drop(conn_guard);
3263-
Err(e.into_py_err(vm))
3264-
}
3215+
let conn = conn_guard.as_mut().expect("unreachable");
3216+
let is_client = matches!(conn, TlsConnection::Client(_));
3217+
let handshake_result = ssl_do_handshake(conn, self, vm);
3218+
drop(conn_guard);
3219+
3220+
if is_client {
3221+
// CLIENT is simple - no SNI callback handling needed
3222+
handshake_result.map_err(|e| e.into_py_err(vm))?;
3223+
self.complete_handshake(vm)?;
3224+
Ok(())
3225+
} else {
3226+
// Use OpenSSL-compatible handshake for server
3227+
// Handle SNI callback restart
3228+
match handshake_result {
3229+
Ok(()) => {
3230+
// Handshake completed successfully
3231+
self.complete_handshake(vm)?;
3232+
Ok(())
3233+
}
3234+
Err(SslError::SniCallbackRestart) => {
3235+
// SNI detected - need to call callback and recreate connection
3236+
3237+
// Get the SNI name that was extracted (may be None if client didn't send SNI)
3238+
let sni_name = self.get_extracted_sni_name();
3239+
3240+
// Now safe to call Python callback (no locks held)
3241+
self.invoke_sni_callback(sni_name.as_deref(), vm)?;
3242+
3243+
// Clear connection to trigger recreation
3244+
*self.connection.lock() = None;
3245+
3246+
// Recursively call do_handshake to recreate with new context
3247+
self.do_handshake(vm)
3248+
}
3249+
Err(e) => {
3250+
// Other errors - convert to Python exception
3251+
Err(e.into_py_err(vm))
32653252
}
32663253
}
3267-
None => unreachable!(),
32683254
}
32693255
}
32703256

@@ -3323,9 +3309,6 @@ mod _ssl {
33233309
));
33243310
}
33253311

3326-
// Check for deferred certificate verification errors (TLS 1.3)
3327-
self.check_deferred_cert_error(vm)?;
3328-
33293312
// Helper function to handle return value based on buffer presence
33303313
let return_data = |data: Vec<u8>,
33313314
buffer_arg: &OptionalArg<ArgMemoryBuffer>,
@@ -3350,17 +3333,21 @@ mod _ssl {
33503333
}
33513334
};
33523335

3353-
let mut conn_guard = self.connection.lock();
3354-
let conn = conn_guard
3355-
.as_mut()
3356-
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
3357-
33583336
// Use compat layer for unified read logic with proper EOF handling
33593337
// This matches CPython's SSL_read_ex() approach
33603338
let mut buf = vec![0u8; len];
3361-
3362-
match crate::ssl::compat::ssl_read(conn, &mut buf, self, vm) {
3339+
let read_result = {
3340+
let mut conn_guard = self.connection.lock();
3341+
let conn = conn_guard
3342+
.as_mut()
3343+
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
3344+
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
3345+
};
3346+
match read_result {
33633347
Ok(n) => {
3348+
// Check for deferred certificate verification errors (TLS 1.3)
3349+
// Must be checked AFTER ssl_read, as the error is set during I/O
3350+
self.check_deferred_cert_error(vm)?;
33643351
buf.truncate(n);
33653352
return_data(buf, &buffer, vm)
33663353
}
@@ -3445,62 +3432,62 @@ mod _ssl {
34453432
));
34463433
}
34473434

3448-
// Check for deferred certificate verification errors (TLS 1.3)
3449-
self.check_deferred_cert_error(vm)?;
3450-
3451-
let mut conn_guard = self.connection.lock();
3452-
let conn = conn_guard
3453-
.as_mut()
3454-
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
3435+
{
3436+
let mut conn_guard = self.connection.lock();
3437+
let conn = conn_guard
3438+
.as_mut()
3439+
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
34553440

3456-
let is_bio = self.is_bio_mode();
3457-
let data: &[u8] = data_bytes.as_ref();
3441+
let is_bio = self.is_bio_mode();
3442+
let data: &[u8] = data_bytes.as_ref();
34583443

3459-
// Write data in chunks to avoid filling the internal TLS buffer
3460-
// rustls has a limited internal buffer, so we need to flush periodically
3461-
const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size)
3462-
let mut written = 0;
3444+
// Write data in chunks to avoid filling the internal TLS buffer
3445+
// rustls has a limited internal buffer, so we need to flush periodically
3446+
const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size)
3447+
let mut written = 0;
34633448

3464-
while written < data.len() {
3465-
let chunk_end = std::cmp::min(written + CHUNK_SIZE, data.len());
3466-
let chunk = &data[written..chunk_end];
3449+
while written < data.len() {
3450+
let chunk_end = std::cmp::min(written + CHUNK_SIZE, data.len());
3451+
let chunk = &data[written..chunk_end];
34673452

3468-
// Write chunk to TLS layer
3469-
{
3470-
let mut writer = conn.writer();
3471-
use std::io::Write;
3472-
writer
3473-
.write_all(chunk)
3474-
.map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?;
3475-
}
3453+
// Write chunk to TLS layer
3454+
{
3455+
let mut writer = conn.writer();
3456+
use std::io::Write;
3457+
writer
3458+
.write_all(chunk)
3459+
.map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?;
3460+
}
34763461

3477-
written = chunk_end;
3462+
written = chunk_end;
34783463

3479-
// Flush TLS data to socket after each chunk
3480-
if conn.wants_write() {
3481-
if is_bio {
3482-
self.write_pending_tls(conn, vm)?;
3483-
} else {
3484-
// Socket mode: flush all pending TLS data
3485-
while conn.wants_write() {
3486-
let mut buf = Vec::new();
3487-
conn.write_tls(&mut buf)
3488-
.map_err(|e| vm.new_os_error(format!("TLS write failed: {e}")))?;
3489-
3490-
if !buf.is_empty() {
3491-
let timed_out =
3492-
self.sock_wait_for_io_impl(SelectKind::Write, vm)?;
3493-
if timed_out {
3494-
return Err(vm.new_os_error("Write operation timed out"));
3495-
}
3464+
// Flush TLS data to socket after each chunk
3465+
if conn.wants_write() {
3466+
if is_bio {
3467+
self.write_pending_tls(conn, vm)?;
3468+
} else {
3469+
// Socket mode: flush all pending TLS data
3470+
while conn.wants_write() {
3471+
let mut buf = Vec::new();
3472+
conn.write_tls(&mut buf).map_err(|e| {
3473+
vm.new_os_error(format!("TLS write failed: {e}"))
3474+
})?;
3475+
3476+
if !buf.is_empty() {
3477+
let timed_out =
3478+
self.sock_wait_for_io_impl(SelectKind::Write, vm)?;
3479+
if timed_out {
3480+
return Err(vm.new_os_error("Write operation timed out"));
3481+
}
34963482

3497-
match self.sock_send(buf, vm) {
3498-
Ok(_) => {}
3499-
Err(e) => {
3500-
if is_blocking_io_error(&e, vm) {
3501-
return Err(create_ssl_want_write_error(vm));
3483+
match self.sock_send(buf, vm) {
3484+
Ok(_) => {}
3485+
Err(e) => {
3486+
if is_blocking_io_error(&e, vm) {
3487+
return Err(create_ssl_want_write_error(vm));
3488+
}
3489+
return Err(e);
35023490
}
3503-
return Err(e);
35043491
}
35053492
}
35063493
}
@@ -3509,6 +3496,10 @@ mod _ssl {
35093496
}
35103497
}
35113498

3499+
// Check for deferred certificate verification errors (TLS 1.3)
3500+
// Must be checked AFTER write completes, as the error may be set during I/O
3501+
self.check_deferred_cert_error(vm)?;
3502+
35123503
Ok(data_len)
35133504
}
35143505

crates/stdlib/src/ssl/cert.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,15 +1067,16 @@ impl ClientCertVerifier for DeferredClientCertVerifier {
10671067
.inner
10681068
.verify_client_cert(end_entity, intermediates, now);
10691069

1070-
// If verification failed, store the error for later
1071-
if result.is_err() {
1072-
let error_msg = "TLS handshake failed: received fatal alert: UnknownCA".to_string();
1070+
// If verification failed, store the error for the server's Python code
1071+
// AND return the error so rustls sends the appropriate TLS alert
1072+
if let Err(ref e) = result {
1073+
let error_msg = format!("certificate verify failed: {e}");
10731074
*self.deferred_error.write() = Some(error_msg);
1075+
// Return the error to rustls so it sends the alert to the client
1076+
return result;
10741077
}
10751078

1076-
// Always return success to allow handshake to complete
1077-
// The error will be raised during the first I/O operation
1078-
Ok(ClientCertVerified::assertion())
1079+
result
10791080
}
10801081

10811082
fn verify_tls12_signature(

0 commit comments

Comments
 (0)