Skip to content

Commit 63d2949

Browse files
committed
session,future: store & use cluster's tokio runtime
This commit finishes the changes required to support configuration of the tokio runtime used by the Scylla Rust wrapper. `cass_cluster_set_num_threads_io` now properly configures the number of threads in the runtime, and the runtime is passed to all futures created by the wrapper. This allows the user to configure the runtime used by the wrapper for each `CassCluster` independently, granting full flexibility.
1 parent 2919e14 commit 63d2949

File tree

3 files changed

+64
-8
lines changed

3 files changed

+64
-8
lines changed

scylla-rust-wrapper/src/future.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::RUNTIME;
21
use crate::argconv::*;
32
use crate::cass_error::CassError;
43
use crate::cass_error::CassErrorMessage;
@@ -68,6 +67,9 @@ enum FutureKind {
6867
}
6968

7069
struct ResolvableFuture {
70+
/// Runtime used to spawn and execute the future.
71+
runtime: Arc<tokio::runtime::Runtime>,
72+
7173
/// Mutable state of the future that requires synchronized exclusive access
7274
/// in order to ensure thread safety of the future execution.
7375
state: Mutex<CassFutureState>,
@@ -111,12 +113,14 @@ impl CassFuture {
111113
}
112114

113115
pub(crate) fn make_raw(
116+
runtime: Arc<tokio::runtime::Runtime>,
114117
fut: impl Future<Output = CassFutureResult> + Send + 'static,
115118
#[cfg(cpp_integration_testing)] recording_listener: Option<
116119
Arc<crate::integration_testing::RecordingHistoryListener>,
117120
>,
118121
) -> CassOwnedSharedPtr<CassFuture, CMut> {
119122
Self::new_from_future(
123+
runtime,
120124
fut,
121125
#[cfg(cpp_integration_testing)]
122126
recording_listener,
@@ -125,6 +129,7 @@ impl CassFuture {
125129
}
126130

127131
pub(crate) fn new_from_future(
132+
runtime: Arc<tokio::runtime::Runtime>,
128133
fut: impl Future<Output = CassFutureResult> + Send + 'static,
129134
#[cfg(cpp_integration_testing)] recording_listener: Option<
130135
Arc<crate::integration_testing::RecordingHistoryListener>,
@@ -133,6 +138,7 @@ impl CassFuture {
133138
let cass_fut = Arc::new(CassFuture {
134139
err_string: OnceLock::new(),
135140
kind: FutureKind::Resolvable(ResolvableFuture {
141+
runtime: Arc::clone(&runtime),
136142
state: Mutex::new(Default::default()),
137143
result: OnceLock::new(),
138144
wait_for_value: Condvar::new(),
@@ -141,7 +147,7 @@ impl CassFuture {
141147
}),
142148
});
143149
let cass_fut_clone = Arc::clone(&cass_fut);
144-
let join_handle = RUNTIME.spawn(async move {
150+
let join_handle = runtime.spawn(async move {
145151
let resolvable_fut = match cass_fut_clone.kind {
146152
FutureKind::Resolvable(ref resolvable) => resolvable,
147153
_ => unreachable!("CassFuture has been created as Resolvable"),
@@ -223,7 +229,7 @@ impl CassFuture {
223229
// the future.
224230
mem::drop(guard);
225231
// unwrap: JoinError appears only when future either panic'ed or canceled.
226-
RUNTIME.block_on(handle).unwrap();
232+
resolvable_fut.runtime.block_on(handle).unwrap();
227233

228234
// Once we are here, the future is resolved.
229235
// The result is guaranteed to be set.
@@ -313,7 +319,7 @@ impl CassFuture {
313319
future::Either::Right((_, handle)) => Err(JoinHandleTimeout(handle)),
314320
}
315321
};
316-
match RUNTIME.block_on(timed) {
322+
match resolvable_fut.runtime.block_on(timed) {
317323
Err(JoinHandleTimeout(returned_handle)) => {
318324
// We timed out. so we can't finish waiting for the future.
319325
// The problem is that if current thread executor is used,
@@ -638,6 +644,15 @@ mod tests {
638644
time::Duration,
639645
};
640646

647+
fn runtime_for_test() -> Arc<tokio::runtime::Runtime> {
648+
Arc::new(
649+
tokio::runtime::Builder::new_current_thread()
650+
.enable_all()
651+
.build()
652+
.unwrap(),
653+
)
654+
}
655+
641656
// This is not a particularly smart test, but if some thread is granted access the value
642657
// before it is truly computed, then weird things should happen, even a segfault.
643658
// In the incorrect implementation that inspired this test to be written, this test
@@ -646,11 +661,13 @@ mod tests {
646661
#[ntest::timeout(100)]
647662
fn cass_future_thread_safety() {
648663
const ERROR_MSG: &str = "NOBODY EXPECTED SPANISH INQUISITION";
664+
let runtime = runtime_for_test();
649665
let fut = async {
650666
tokio::time::sleep(Duration::from_millis(10)).await;
651667
Err((CassError::CASS_OK, ERROR_MSG.into()))
652668
};
653669
let cass_fut = CassFuture::make_raw(
670+
runtime,
654671
fut,
655672
#[cfg(cpp_integration_testing)]
656673
None,
@@ -685,11 +702,13 @@ mod tests {
685702
fn cass_future_resolves_after_timeout() {
686703
const ERROR_MSG: &str = "NOBODY EXPECTED SPANISH INQUISITION";
687704
const HUNDRED_MILLIS_IN_MICROS: u64 = 100 * 1000;
705+
let runtime = runtime_for_test();
688706
let fut = async move {
689707
tokio::time::sleep(Duration::from_micros(HUNDRED_MILLIS_IN_MICROS)).await;
690708
Err((CassError::CASS_OK, ERROR_MSG.into()))
691709
};
692710
let cass_fut = CassFuture::make_raw(
711+
runtime,
693712
fut,
694713
#[cfg(cpp_integration_testing)]
695714
None,
@@ -725,6 +744,8 @@ mod tests {
725744
const ERROR_MSG: &str = "NOBODY EXPECTED SPANISH INQUISITION";
726745
const HUNDRED_MILLIS_IN_MICROS: u64 = 100 * 1000;
727746

747+
let runtime = runtime_for_test();
748+
728749
let create_future_and_flag = || {
729750
unsafe extern "C" fn mark_flag_cb(
730751
_fut: CassBorrowedSharedPtr<CassFuture, CMut>,
@@ -741,6 +762,7 @@ mod tests {
741762
Err((CassError::CASS_OK, ERROR_MSG.into()))
742763
};
743764
let cass_fut = CassFuture::make_raw(
765+
Arc::clone(&runtime),
744766
fut,
745767
#[cfg(cpp_integration_testing)]
746768
None,
@@ -816,7 +838,7 @@ mod tests {
816838
{
817839
let (cass_fut, flag_ptr) = create_future_and_flag();
818840

819-
RUNTIME.block_on(async {
841+
runtime.block_on(async {
820842
tokio::time::sleep(Duration::from_micros(HUNDRED_MILLIS_IN_MICROS + 10 * 1000))
821843
.await
822844
});

scylla-rust-wrapper/src/lib.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use crate::logging::Logger;
44
use crate::logging::stderr_log_callback;
55
use std::sync::LazyLock;
66
use std::sync::RwLock;
7-
use tokio::runtime::Runtime;
87

98
#[macro_use]
109
mod binding;
@@ -191,7 +190,6 @@ pub(crate) mod cass_version_types {
191190
include_bindgen_generated!("cppdriver_version_types.rs");
192191
}
193192

194-
pub(crate) static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| Runtime::new().unwrap());
195193
pub(crate) static LOGGER: LazyLock<RwLock<Logger>> = LazyLock::new(|| {
196194
RwLock::new(Logger {
197195
cb: Some(stderr_log_callback),

scylla-rust-wrapper/src/session.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use std::sync::Arc;
3131
use tokio::sync::RwLock;
3232

3333
pub(crate) struct CassConnectedSession {
34+
runtime: Arc<tokio::runtime::Runtime>,
3435
session: Session,
3536
exec_profile_map: HashMap<ExecProfileName, ExecutionProfileHandle>,
3637
}
@@ -78,6 +79,7 @@ impl CassConnectedSession {
7879
let cluster_client_id = cluster.get_client_id();
7980

8081
let fut = Self::connect_fut(
82+
Arc::clone(cluster.get_runtime()),
8183
session,
8284
session_builder,
8385
cluster_client_id,
@@ -87,13 +89,15 @@ impl CassConnectedSession {
8789
);
8890

8991
CassFuture::make_raw(
92+
Arc::clone(cluster.get_runtime()),
9093
fut,
9194
#[cfg(cpp_integration_testing)]
9295
None,
9396
)
9497
}
9598

9699
async fn connect_fut(
100+
runtime: Arc<tokio::runtime::Runtime>,
97101
session: Arc<CassSession>,
98102
session_builder_fut: impl Future<Output = SessionBuilder>,
99103
cluster_client_id: Option<uuid::Uuid>,
@@ -152,20 +156,40 @@ impl CassConnectedSession {
152156
.map_err(|err| (err.to_cass_error(), err.msg()))?;
153157

154158
session_guard.connected = Some(CassConnectedSession {
159+
runtime,
155160
session,
156161
exec_profile_map,
157162
});
158163
Ok(CassResultValue::Empty)
159164
}
160165

161166
fn close_fut(session_opt: Arc<RwLock<CassSessionInner>>) -> Arc<CassFuture> {
167+
let runtime = {
168+
let Ok(session_guard) = session_opt.try_read() else {
169+
return CassFuture::new_ready(Err((
170+
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
171+
"Still connecting or already closing".msg(),
172+
)));
173+
};
174+
175+
let Some(connected_session) = session_guard.connected.as_ref() else {
176+
return CassFuture::new_ready(Err((
177+
CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
178+
"Session is not connected".msg(),
179+
)));
180+
};
181+
182+
Arc::clone(&connected_session.runtime)
183+
};
184+
162185
CassFuture::new_from_future(
186+
runtime,
163187
async move {
164188
let mut session_guard = session_opt.write().await;
165189
if session_guard.connected.is_none() {
166190
return Err((
167191
CassError::CASS_ERROR_LIB_UNABLE_TO_CLOSE,
168-
"Already closing or closed".msg(),
192+
"Session is not connected".msg(),
169193
));
170194
}
171195

@@ -286,6 +310,8 @@ pub unsafe extern "C" fn cass_session_execute_batch(
286310
)));
287311
};
288312

313+
let runtime = Arc::clone(&connected_session.runtime);
314+
289315
let future = async move {
290316
let connected_session = session_guard
291317
.connected
@@ -315,6 +341,7 @@ pub unsafe extern "C" fn cass_session_execute_batch(
315341
};
316342

317343
CassFuture::make_raw(
344+
runtime,
318345
future,
319346
#[cfg(cpp_integration_testing)]
320347
None,
@@ -350,6 +377,8 @@ pub unsafe extern "C" fn cass_session_execute(
350377
)));
351378
};
352379

380+
let runtime = Arc::clone(&connected_session.runtime);
381+
353382
let paging_state = statement_opt.paging_state.clone();
354383
let paging_enabled = statement_opt.paging_enabled;
355384
let mut statement = statement_opt.statement.clone();
@@ -483,6 +512,7 @@ pub unsafe extern "C" fn cass_session_execute(
483512
};
484513

485514
CassFuture::make_raw(
515+
runtime,
486516
future,
487517
#[cfg(cpp_integration_testing)]
488518
recording_listener,
@@ -517,9 +547,12 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing(
517547
)));
518548
};
519549

550+
let runtime = Arc::clone(&connected_session.runtime);
551+
520552
let statement = cass_statement.statement.clone();
521553

522554
CassFuture::make_raw(
555+
runtime,
523556
async move {
524557
let query = match &statement {
525558
BoundStatement::Simple(q) => q,
@@ -588,6 +621,8 @@ pub unsafe extern "C" fn cass_session_prepare_n(
588621
)));
589622
};
590623

624+
let runtime = Arc::clone(&connected_session.runtime);
625+
591626
let query = Statement::new(query_str.to_string());
592627

593628
let fut = async move {
@@ -608,6 +643,7 @@ pub unsafe extern "C" fn cass_session_prepare_n(
608643
};
609644

610645
CassFuture::make_raw(
646+
runtime,
611647
fut,
612648
#[cfg(cpp_integration_testing)]
613649
None,

0 commit comments

Comments
 (0)