Skip to content

Commit f52a779

Browse files
committed
session,future: store & use cluster's tokio runtime
This commit finishes the changes required to support configuration of the tokio runtime used by the Scylla Rust wrapper. `cass_cluster_set_num_threads_io` now properly configures the number of threads in the runtime, and the runtime is passed to all futures created by the wrapper. This allows the user to configure the runtime used by the wrapper for each `CassCluster` independently, granting full flexibility.
1 parent c76ad3c commit f52a779

File tree

4 files changed

+78
-9
lines changed

4 files changed

+78
-9
lines changed

scylla-rust-wrapper/src/future.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::RUNTIME;
21
use crate::argconv::*;
32
use crate::cass_error::{CassError, CassErrorMessage, CassErrorResult, ToCassError as _};
43
use crate::prepared::CassPrepared;
@@ -70,6 +69,9 @@ enum FutureKind {
7069
}
7170

7271
struct ResolvableFuture {
72+
/// Runtime used to spawn and execute the future.
73+
runtime: Arc<tokio::runtime::Runtime>,
74+
7375
/// Mutable state of the future that requires synchronized exclusive access
7476
/// in order to ensure thread safety of the future execution.
7577
state: Mutex<CassFutureState>,
@@ -113,12 +115,14 @@ impl CassFuture {
113115
}
114116

115117
pub(crate) fn make_raw(
118+
runtime: Arc<tokio::runtime::Runtime>,
116119
fut: impl Future<Output = CassFutureResult> + Send + 'static,
117120
#[cfg(cpp_integration_testing)] recording_listener: Option<
118121
Arc<crate::integration_testing::RecordingHistoryListener>,
119122
>,
120123
) -> CassOwnedSharedPtr<CassFuture, CMut> {
121124
Self::new_from_future(
125+
runtime,
122126
fut,
123127
#[cfg(cpp_integration_testing)]
124128
recording_listener,
@@ -127,6 +131,7 @@ impl CassFuture {
127131
}
128132

129133
pub(crate) fn new_from_future(
134+
runtime: Arc<tokio::runtime::Runtime>,
130135
fut: impl Future<Output = CassFutureResult> + Send + 'static,
131136
#[cfg(cpp_integration_testing)] recording_listener: Option<
132137
Arc<crate::integration_testing::RecordingHistoryListener>,
@@ -136,6 +141,7 @@ impl CassFuture {
136141
err_string: OnceLock::new(),
137142
kind: FutureKind::Resolvable {
138143
fut: ResolvableFuture {
144+
runtime: Arc::clone(&runtime),
139145
state: Mutex::new(Default::default()),
140146
result: OnceLock::new(),
141147
wait_for_value: Condvar::new(),
@@ -145,7 +151,7 @@ impl CassFuture {
145151
},
146152
});
147153
let cass_fut_clone = Arc::clone(&cass_fut);
148-
let join_handle = RUNTIME.spawn(async move {
154+
let join_handle = runtime.spawn(async move {
149155
let resolvable_fut = match cass_fut_clone.kind {
150156
FutureKind::Resolvable {
151157
fut: ref resolvable,
@@ -239,7 +245,7 @@ impl CassFuture {
239245
// the future.
240246
mem::drop(guard);
241247
// unwrap: JoinError appears only when future either panic'ed or canceled.
242-
RUNTIME.block_on(handle).unwrap();
248+
resolvable_fut.runtime.block_on(handle).unwrap();
243249

244250
// Once we are here, the future is resolved.
245251
// The result is guaranteed to be set.
@@ -331,7 +337,7 @@ impl CassFuture {
331337
future::Either::Right((_, handle)) => Err(JoinHandleTimeout(handle)),
332338
}
333339
};
334-
match RUNTIME.block_on(timed) {
340+
match resolvable_fut.runtime.block_on(timed) {
335341
Err(JoinHandleTimeout(returned_handle)) => {
336342
// We timed out. so we can't finish waiting for the future.
337343
// The problem is that if current thread executor is used,
@@ -677,6 +683,15 @@ mod tests {
677683
time::Duration,
678684
};
679685

686+
fn runtime_for_test() -> Arc<tokio::runtime::Runtime> {
687+
Arc::new(
688+
tokio::runtime::Builder::new_current_thread()
689+
.enable_all()
690+
.build()
691+
.unwrap(),
692+
)
693+
}
694+
680695
// This is not a particularly smart test, but if some thread is granted access the value
681696
// before it is truly computed, then weird things should happen, even a segfault.
682697
// In the incorrect implementation that inspired this test to be written, this test
@@ -685,11 +700,13 @@ mod tests {
685700
#[ntest::timeout(100)]
686701
fn cass_future_thread_safety() {
687702
const ERROR_MSG: &str = "NOBODY EXPECTED SPANISH INQUISITION";
703+
let runtime = runtime_for_test();
688704
let fut = async {
689705
tokio::time::sleep(Duration::from_millis(10)).await;
690706
Err((CassError::CASS_OK, ERROR_MSG.into()))
691707
};
692708
let cass_fut = CassFuture::make_raw(
709+
runtime,
693710
fut,
694711
#[cfg(cpp_integration_testing)]
695712
None,
@@ -724,11 +741,13 @@ mod tests {
724741
fn cass_future_resolves_after_timeout() {
725742
const ERROR_MSG: &str = "NOBODY EXPECTED SPANISH INQUISITION";
726743
const HUNDRED_MILLIS_IN_MICROS: u64 = 100 * 1000;
744+
let runtime = runtime_for_test();
727745
let fut = async move {
728746
tokio::time::sleep(Duration::from_micros(HUNDRED_MILLIS_IN_MICROS)).await;
729747
Err((CassError::CASS_OK, ERROR_MSG.into()))
730748
};
731749
let cass_fut = CassFuture::make_raw(
750+
runtime,
732751
fut,
733752
#[cfg(cpp_integration_testing)]
734753
None,
@@ -764,6 +783,8 @@ mod tests {
764783
const ERROR_MSG: &str = "NOBODY EXPECTED SPANISH INQUISITION";
765784
const HUNDRED_MILLIS_IN_MICROS: u64 = 100 * 1000;
766785

786+
let runtime = runtime_for_test();
787+
767788
let create_future_and_flag = || {
768789
unsafe extern "C" fn mark_flag_cb(
769790
_fut: CassBorrowedSharedPtr<CassFuture, CMut>,
@@ -780,6 +801,7 @@ mod tests {
780801
Err((CassError::CASS_OK, ERROR_MSG.into()))
781802
};
782803
let cass_fut = CassFuture::make_raw(
804+
Arc::clone(&runtime),
783805
fut,
784806
#[cfg(cpp_integration_testing)]
785807
None,
@@ -855,7 +877,7 @@ mod tests {
855877
{
856878
let (cass_fut, flag_ptr) = create_future_and_flag();
857879

858-
RUNTIME.block_on(async {
880+
runtime.block_on(async {
859881
tokio::time::sleep(Duration::from_micros(HUNDRED_MILLIS_IN_MICROS + 10 * 1000))
860882
.await
861883
});

scylla-rust-wrapper/src/integration_testing.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,14 +280,27 @@ pub unsafe extern "C" fn testing_future_get_attempted_hosts(
280280
unsafe { CString::from_vec_unchecked(concatenated_hosts.into_bytes()) }.into_raw()
281281
}
282282

283+
#[cfg(test)]
284+
fn runtime_for_test() -> Arc<tokio::runtime::Runtime> {
285+
Arc::new(
286+
tokio::runtime::Builder::new_current_thread()
287+
.enable_all()
288+
.build()
289+
.unwrap(),
290+
)
291+
}
292+
283293
/// Ensures that the `testing_future_get_attempted_hosts` function
284294
/// behaves correctly, i.e., it returns a list of attempted hosts as a concatenated string.
285295
#[test]
286296
fn test_future_get_attempted_hosts() {
287297
use scylla::observability::history::HistoryListener as _;
288298

299+
let runtime = runtime_for_test();
300+
289301
let listener = Arc::new(RecordingHistoryListener::new());
290-
let future = CassFuture::new_from_future(std::future::pending(), Some(listener.clone()));
302+
let future =
303+
CassFuture::new_from_future(runtime, std::future::pending(), Some(listener.clone()));
291304

292305
fn assert_attempted_hosts_eq(future: &Arc<CassFuture>, hosts: &[String]) {
293306
let hosts_str = unsafe { testing_future_get_attempted_hosts(ArcFFI::as_ptr(future)) };

scylla-rust-wrapper/src/lib.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use crate::logging::Logger;
44
use crate::logging::stderr_log_callback;
55
use std::sync::LazyLock;
66
use std::sync::RwLock;
7-
use tokio::runtime::Runtime;
87

98
#[macro_use]
109
mod binding;
@@ -190,7 +189,6 @@ pub(crate) mod cass_version_types {
190189
include_bindgen_generated!("cppdriver_version_types.rs");
191190
}
192191

193-
pub(crate) static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| Runtime::new().unwrap());
194192
pub(crate) static LOGGER: LazyLock<RwLock<Logger>> = LazyLock::new(|| {
195193
RwLock::new(Logger {
196194
cb: Some(stderr_log_callback),

scylla-rust-wrapper/src/session.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use std::sync::Arc;
3131
use tokio::sync::RwLock;
3232

3333
pub(crate) struct CassConnectedSession {
34+
runtime: Arc<tokio::runtime::Runtime>,
3435
session: Session,
3536
exec_profile_map: HashMap<ExecProfileName, ExecutionProfileHandle>,
3637
}
@@ -78,6 +79,7 @@ impl CassConnectedSession {
7879
let cluster_client_id = cluster.get_client_id();
7980

8081
let fut = Self::connect_fut(
82+
Arc::clone(cluster.get_runtime()),
8183
session,
8284
session_builder,
8385
cluster_client_id,
@@ -87,13 +89,15 @@ impl CassConnectedSession {
8789
);
8890

8991
CassFuture::make_raw(
92+
Arc::clone(cluster.get_runtime()),
9093
fut,
9194
#[cfg(cpp_integration_testing)]
9295
None,
9396
)
9497
}
9598

9699
async fn connect_fut(
100+
runtime: Arc<tokio::runtime::Runtime>,
97101
session: Arc<CassSession>,
98102
session_builder_fut: impl Future<Output = SessionBuilder>,
99103
cluster_client_id: Option<uuid::Uuid>,
@@ -152,20 +156,40 @@ impl CassConnectedSession {
152156
.map_err(|err| (err.to_cass_error(), err.msg()))?;
153157

154158
session_guard.connected = Some(CassConnectedSession {
159+
runtime,
155160
session,
156161
exec_profile_map,
157162
});
158163
Ok(CassResultValue::Empty)
159164
}
160165

161166
fn close_fut(session_opt: Arc<RwLock<CassSessionInner>>) -> Arc<CassFuture> {
167+
let runtime = {
168+
let Ok(session_guard) = session_opt.try_read() else {
169+
return CassFuture::new_ready(Err((
170+
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
171+
"Still connecting or already closing".msg(),
172+
)));
173+
};
174+
175+
let Some(connected_session) = session_guard.connected.as_ref() else {
176+
return CassFuture::new_ready(Err((
177+
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
178+
"Session is not connected".msg(),
179+
)));
180+
};
181+
182+
Arc::clone(&connected_session.runtime)
183+
};
184+
162185
CassFuture::new_from_future(
186+
runtime,
163187
async move {
164188
let mut session_guard = session_opt.write().await;
165189
if session_guard.connected.is_none() {
166190
return Err((
167191
CassError::CASS_ERROR_LIB_UNABLE_TO_CLOSE,
168-
"Already closing or closed".msg(),
192+
"Session is not connected".msg(),
169193
));
170194
}
171195

@@ -286,6 +310,8 @@ pub unsafe extern "C" fn cass_session_execute_batch(
286310
)));
287311
};
288312

313+
let runtime = Arc::clone(&connected_session.runtime);
314+
289315
let future = async move {
290316
let connected_session = session_guard
291317
.connected
@@ -315,6 +341,7 @@ pub unsafe extern "C" fn cass_session_execute_batch(
315341
};
316342

317343
CassFuture::make_raw(
344+
runtime,
318345
future,
319346
#[cfg(cpp_integration_testing)]
320347
None,
@@ -350,6 +377,8 @@ pub unsafe extern "C" fn cass_session_execute(
350377
)));
351378
};
352379

380+
let runtime = Arc::clone(&connected_session.runtime);
381+
353382
let paging_state = statement_opt.paging_state.clone();
354383
let paging_enabled = statement_opt.paging_enabled;
355384
let mut statement = statement_opt.statement.clone();
@@ -483,6 +512,7 @@ pub unsafe extern "C" fn cass_session_execute(
483512
};
484513

485514
CassFuture::make_raw(
515+
runtime,
486516
future,
487517
#[cfg(cpp_integration_testing)]
488518
recording_listener,
@@ -517,9 +547,12 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing(
517547
)));
518548
};
519549

550+
let runtime = Arc::clone(&connected_session.runtime);
551+
520552
let statement = cass_statement.statement.clone();
521553

522554
CassFuture::make_raw(
555+
runtime,
523556
async move {
524557
let query = match &statement {
525558
BoundStatement::Simple(q) => q,
@@ -588,6 +621,8 @@ pub unsafe extern "C" fn cass_session_prepare_n(
588621
)));
589622
};
590623

624+
let runtime = Arc::clone(&connected_session.runtime);
625+
591626
let query = Statement::new(query_str.to_string());
592627

593628
let fut = async move {
@@ -608,6 +643,7 @@ pub unsafe extern "C" fn cass_session_prepare_n(
608643
};
609644

610645
CassFuture::make_raw(
646+
runtime,
611647
fut,
612648
#[cfg(cpp_integration_testing)]
613649
None,

0 commit comments

Comments
 (0)