Skip to content

Commit b9b57ee

Browse files
committed
session: lock session for reading nonblockingly
The mechanism around locking the session has been problematic in the number of ways. **1. Late locking**: The primary issue was that requests would not lock lock the session for reading until their futures were polled, which meant that the following code could lead to request failure due to session having been closed: ```c CassFuture *exec_fut = cass_session_execute(cass_session, cass_statement); CassFuture *close_fut = cass_session_close(cass_session); cass_future_wait(exec_fut); cass_future_wait(close_fut); ``` This is because locking the session for reading is done asynchronously to the code that follows `cass_session_execute` invocation. The same issue was present in all request-making functions, that is, `cass_session_execute`, `cass_session_execute_batch`, and all of the `cass_session_prepare*` family of functions. **2. Thread-blocking locking**: The second issue was that, in some cases, the session was locked for reading by blocking the current thread idly. This was the case for `cass_session_get_metrics`, `cass_session_get_schema_meta`, and `cass_session_get_client_id`. This could lead to deadlocks, especially when the number of threads in the thread pool were low (with `current_thread` tokio executor being the most vulnerable case). **3. Runtime-blocking locking**: The third issue was that, in the case of `cass_session_connect*` functions, the session was locked for writing by blocking the current thread as the executor thread for the awaited future. While this was designed with the `current_thread` executor in mind and worked perfectly for its case, it showed to cause panic when called by a tokio executor thread. **Solution**: This commit addresses all of the above issues by adopting an asymmetric locking mechanism for the session. The session is now locked for reading in advance yet fallibly, by calling `try_read(_owned)` on the session rwlock. This is done in the request-making functions, so that the session is guaranteed to be locked for reading when the request future is returned. The session is still locked for writing (upon connecting or closing) asynchronously, by calling and awaiting `write(_owned)` on the rwlock. This is done in the `cass_session_connect*` and `cass_session_close`. A downside of this approach is that the session is not guaranteed to be locked for writing when the `cass_session_connect*` or `cass_session_close` futures are returned, but this is not a problem because closing and connecting are considered to be "long-running", complex operations and thus are not expected to have conducted a specific part of their logic by the time their future is returned. **Results**: All enabled tests still pass, while the `callbacks` example now passes, too! This best part is that the number and complexity of the required changes is minimal, and the code is now much more robust. I hope @Lorak-mmk will be happy with this solution, as compared to the complex requests pending mechanism and atomics. Note that the test for `cass_session_get_client_id` had to be adjusted. This is because the session has the client ID set only in the connect future instead of in the connect function synchronously, so the test (which did not await the connect future) would fail after the changes.
1 parent 6c5b2a2 commit b9b57ee

File tree

1 file changed

+118
-48
lines changed

1 file changed

+118
-48
lines changed

scylla-rust-wrapper/src/session.rs

Lines changed: 118 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::RUNTIME;
21
use crate::argconv::*;
32
use crate::batch::CassBatch;
43
use crate::cass_error::*;
@@ -29,7 +28,7 @@ use std::future::Future;
2928
use std::ops::Deref;
3029
use std::os::raw::c_char;
3130
use std::sync::Arc;
32-
use tokio::sync::{OwnedRwLockWriteGuard, RwLock};
31+
use tokio::sync::RwLock;
3332

3433
pub(crate) struct CassConnectedSession {
3534
session: Session,
@@ -70,24 +69,19 @@ impl CassConnectedSession {
7069
}
7170

7271
fn connect(
73-
session: Arc<RwLock<CassSessionInner>>,
72+
session: Arc<CassSession>,
7473
cluster: &CassCluster,
7574
keyspace: Option<String>,
7675
) -> CassOwnedSharedPtr<CassFuture, CMut> {
7776
let session_builder = cluster.build_session_builder();
7877
let exec_profile_map = cluster.execution_profile_map().clone();
7978
let host_filter = cluster.build_host_filter();
80-
81-
let mut session_guard = RUNTIME.block_on(session.write_owned());
82-
83-
if let Some(cluster_client_id) = cluster.get_client_id() {
84-
// If the user set a client id, use it instead of the random one.
85-
session_guard.client_id = cluster_client_id;
86-
}
79+
let cluster_client_id = cluster.get_client_id();
8780

8881
let fut = Self::connect_fut(
89-
session_guard,
82+
session,
9083
session_builder,
84+
cluster_client_id,
9185
exec_profile_map,
9286
host_filter,
9387
keyspace,
@@ -101,21 +95,29 @@ impl CassConnectedSession {
10195
}
10296

10397
async fn connect_fut(
104-
mut session_guard: OwnedRwLockWriteGuard<CassSessionInner>,
98+
session: Arc<CassSession>,
10599
session_builder_fut: impl Future<Output = SessionBuilder>,
100+
cluster_client_id: Option<uuid::Uuid>,
106101
exec_profile_builder_map: HashMap<ExecProfileName, CassExecProfile>,
107102
host_filter: Arc<dyn HostFilter>,
108103
keyspace: Option<String>,
109104
) -> CassFutureResult {
110105
// This can sleep for a long time, but only if someone connects/closes session
111106
// from more than 1 thread concurrently, which is inherently stupid thing to do.
107+
let mut session_guard = session.write().await;
108+
112109
if session_guard.connected.is_some() {
113110
return Err((
114111
CassError::CASS_ERROR_LIB_UNABLE_TO_CONNECT,
115112
"Already connecting, closing, or connected".msg(),
116113
));
117114
}
118115

116+
if let Some(cluster_client_id) = cluster_client_id {
117+
// If the user set a client id, use it instead of the random one.
118+
session_guard.client_id = cluster_client_id;
119+
}
120+
119121
let mut session_builder = session_builder_fut.await;
120122
let default_profile = session_builder
121123
.config
@@ -268,6 +270,13 @@ pub unsafe extern "C" fn cass_session_execute_batch(
268270
return ArcFFI::null();
269271
};
270272

273+
let Ok(session_guard) = session_opt.try_read_owned() else {
274+
return CassFuture::make_ready_raw(Err((
275+
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
276+
"Session is not connected".msg(),
277+
)));
278+
};
279+
271280
let mut state = batch_from_raw.state.clone();
272281

273282
// DO NOT refer to `batch_from_raw` inside the async block, as I've done just to face a segfault.
@@ -276,7 +285,6 @@ pub unsafe extern "C" fn cass_session_execute_batch(
276285
let batch_from_raw = (); // Hardening shadow to avoid use-after-free.
277286

278287
let future = async move {
279-
let session_guard = session_opt.read().await;
280288
if session_guard.connected.is_none() {
281289
return Err((
282290
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
@@ -330,6 +338,13 @@ pub unsafe extern "C" fn cass_session_execute(
330338
return ArcFFI::null();
331339
};
332340

341+
let Ok(session_guard) = session_opt.try_read_owned() else {
342+
return CassFuture::make_ready_raw(Err((
343+
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
344+
"Session is not connected".msg(),
345+
)));
346+
};
347+
333348
let paging_state = statement_opt.paging_state.clone();
334349
let paging_enabled = statement_opt.paging_enabled;
335350
let mut statement = statement_opt.statement.clone();
@@ -360,7 +375,6 @@ pub unsafe extern "C" fn cass_session_execute(
360375
let statement_opt = (); // Hardening shadow to avoid use-after-free.
361376

362377
let future = async move {
363-
let session_guard = session_opt.read().await;
364378
let Some(cass_connected_session) = session_guard.connected.as_ref() else {
365379
return Err((
366380
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
@@ -488,6 +502,13 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing(
488502
return ArcFFI::null();
489503
};
490504

505+
let Ok(session_guard) = session.try_read_owned() else {
506+
return CassFuture::make_ready_raw(Err((
507+
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
508+
"Session is not connected".msg(),
509+
)));
510+
};
511+
491512
let statement = cass_statement.statement.clone();
492513

493514
CassFuture::make_raw(
@@ -499,7 +520,6 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing(
499520
}
500521
};
501522

502-
let session_guard = session.read().await;
503523
if session_guard.connected.is_none() {
504524
return Err((
505525
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
@@ -546,10 +566,17 @@ pub unsafe extern "C" fn cass_session_prepare_n(
546566
// to receive a server error in such case (CASS_ERROR_SERVER_SYNTAX_ERROR).
547567
// There is a test for this: `NullStringApiArgsTest.Integration_Cassandra_PrepareNullQuery`.
548568
.unwrap_or_default();
569+
570+
let Ok(session_guard) = cass_session.try_read_owned() else {
571+
return CassFuture::make_ready_raw(Err((
572+
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
573+
"Session is not connected".msg(),
574+
)));
575+
};
576+
549577
let query = Statement::new(query_str.to_string());
550578

551579
let fut = async move {
552-
let session_guard = cass_session.read().await;
553580
if session_guard.connected.is_none() {
554581
return Err((
555582
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
@@ -610,22 +637,43 @@ pub unsafe extern "C" fn cass_session_get_client_id(
610637
return uuid::Uuid::nil().into();
611638
};
612639

613-
let client_id: uuid::Uuid = cass_session.blocking_read().client_id;
640+
let Ok(session_guard) = cass_session.try_read() else {
641+
tracing::error!(
642+
"Called cass_session_get_client_id on a connecting/disconnecting session!\
643+
Wait for it to finish first."
644+
);
645+
return uuid::Uuid::nil().into();
646+
};
647+
648+
let client_id: uuid::Uuid = session_guard.client_id;
614649
client_id.into()
615650
}
616651

617652
#[unsafe(no_mangle)]
618653
pub unsafe extern "C" fn cass_session_get_schema_meta(
619654
session: CassBorrowedSharedPtr<CassSession, CConst>,
620655
) -> CassOwnedExclusivePtr<CassSchemaMeta, CConst> {
621-
let cass_session = ArcFFI::as_ref(session).unwrap();
656+
let Some(cass_session) = ArcFFI::as_ref(session) else {
657+
tracing::error!("Provided null session pointer to cass_session_get_schema_meta!");
658+
return CassPtr::null();
659+
};
660+
661+
let Ok(session_guard) = cass_session.try_read() else {
662+
tracing::error!(
663+
"Called cass_session_get_schema_meta on a connecting/disconnecting session!\
664+
Wait for it to finish first."
665+
);
666+
return CassPtr::null();
667+
};
668+
669+
let Some(cass_connected_session) = session_guard.connected.as_ref() else {
670+
tracing::error!("Called cass_session_get_schema_meta on a disconnected session!");
671+
return CassPtr::null();
672+
};
673+
622674
let mut keyspaces: HashMap<String, CassKeyspaceMeta> = HashMap::new();
623675

624-
for (keyspace_name, keyspace) in cass_session
625-
.blocking_read()
626-
.connected
627-
.as_ref()
628-
.unwrap()
676+
for (keyspace_name, keyspace) in cass_connected_session
629677
.session
630678
.get_cluster_state()
631679
.keyspaces_iter()
@@ -699,9 +747,14 @@ pub unsafe extern "C" fn cass_session_get_metrics(
699747
return;
700748
}
701749

702-
let maybe_session_guard = maybe_session_lock.blocking_read();
703-
let maybe_session = maybe_session_guard.connected.as_ref();
704-
let Some(session) = maybe_session else {
750+
let Ok(session_guard) = maybe_session_lock.try_read() else {
751+
tracing::error!(
752+
"Called cass_session_get_metrics on a connecting/disconnecting session!\
753+
Wait for it to finish first."
754+
);
755+
return;
756+
};
757+
let Some(session) = session_guard.connected.as_ref() else {
705758
tracing::warn!("Attempted to get metrics before connecting session object");
706759
return;
707760
};
@@ -1785,34 +1838,51 @@ mod tests {
17851838
}
17861839
}
17871840

1788-
#[test]
1841+
#[tokio::test]
17891842
#[ntest::timeout(5000)]
1790-
fn test_cass_session_get_client_id_on_disconnected_session() {
1843+
async fn test_cass_session_get_client_id_on_disconnected_session() {
17911844
init_logger();
1792-
unsafe {
1793-
let session_raw = cass_session_new();
1845+
test_with_one_proxy(
1846+
|node_addr: SocketAddr, proxy: RunningProxy| unsafe {
1847+
let session_raw = cass_session_new();
17941848

1795-
// Check that we can get a client ID from a disconnected session.
1796-
let _random_client_id = cass_session_get_client_id(session_raw.borrow());
1849+
// Check that we can get a client ID from a disconnected session.
1850+
let _random_client_id = cass_session_get_client_id(session_raw.borrow());
17971851

1798-
let mut cluster_raw = cass_cluster_new();
1799-
let cluster_client_id = CassUuid {
1800-
time_and_version: 2137,
1801-
clock_seq_and_node: 7312,
1802-
};
1803-
cass_cluster_set_client_id(cluster_raw.borrow_mut(), cluster_client_id);
1852+
let mut cluster_raw = cass_cluster_new();
1853+
let ip = node_addr.ip().to_string();
1854+
let (c_ip, c_ip_len) = str_to_c_str_n(ip.as_str());
1855+
assert_cass_error_eq!(
1856+
cass_cluster_set_contact_points_n(cluster_raw.borrow_mut(), c_ip, c_ip_len),
1857+
CassError::CASS_OK
1858+
);
1859+
1860+
let cluster_client_id = CassUuid {
1861+
time_and_version: 2137,
1862+
clock_seq_and_node: 7312,
1863+
};
1864+
cass_cluster_set_client_id(cluster_raw.borrow_mut(), cluster_client_id);
18041865

1805-
cass_session_connect(session_raw.borrow(), cluster_raw.borrow().into_c_const());
1806-
// Verify that the session inherits the client ID from the cluster.
1807-
let session_client_id = cass_session_get_client_id(session_raw.borrow());
1808-
assert_eq!(session_client_id, cluster_client_id);
1866+
let connect_fut =
1867+
cass_session_connect(session_raw.borrow(), cluster_raw.borrow().into_c_const());
1868+
assert_cass_error_eq!(cass_future_error_code(connect_fut), CassError::CASS_OK);
18091869

1810-
// Verify that we can still get a client ID after disconnecting.
1811-
let session_client_id = cass_session_get_client_id(session_raw.borrow());
1812-
assert_eq!(session_client_id, cluster_client_id);
1870+
// Verify that the session inherits the client ID from the cluster.
1871+
let session_client_id = cass_session_get_client_id(session_raw.borrow());
1872+
assert_eq!(session_client_id, cluster_client_id);
18131873

1814-
cass_session_free(session_raw);
1815-
cass_cluster_free(cluster_raw)
1816-
}
1874+
// Verify that we can still get a client ID after disconnecting.
1875+
let session_client_id = cass_session_get_client_id(session_raw.borrow());
1876+
assert_eq!(session_client_id, cluster_client_id);
1877+
1878+
cass_session_free(session_raw);
1879+
cass_cluster_free(cluster_raw);
1880+
1881+
proxy
1882+
},
1883+
mock_init_rules(),
1884+
)
1885+
.with_current_subscriber()
1886+
.await;
18171887
}
18181888
}

0 commit comments

Comments
 (0)