Skip to content

Commit a9a5ef2

Browse files
committed
runtime: introduce cache for tokio runtimes
As @Lorak-mmk noted in the review, runtimes are currently not shared between `CassCluster` instances, which leads to possible many tokio runtimes being created in the application, with possibly a lot of threads. This commit introduces a cache for tokio runtimes, which is encapsulated in the global `Runtimes` struct. 1. `CassCluster` now does not store a `tokio::runtime::Runtime` directly, but rather an optional number of threads in the runtime. 2. The `Runtimes` struct is a global cache for tokio runtimes. It allows to get a default runtime or a runtime with a specified number of threads. If a runtime is not created yet, it will create a new one and cache it for future use. The handling of the cache is fully transparent to the user of the abstraction. 3. Once all `CassCluster` instances that reference a runtime are dropped, the runtime is also dropped. This is done by storing weak pointers to runtimes in the `Runtimes` struct. Interesting to note: as Weak pointers keep the Arc allocation alive, a workflow that for consecutive `i`s creates a `CassCluster` with a runtime with `i` threads and then drops it, will lead to space leaks. This is an artificial case, though.
1 parent b3d748d commit a9a5ef2

File tree

4 files changed

+106
-31
lines changed

4 files changed

+106
-31
lines changed

scylla-rust-wrapper/src/cluster.rs

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::load_balancing::{
77
CassHostFilter, DcRestriction, LoadBalancingConfig, LoadBalancingKind,
88
};
99
use crate::retry_policy::CassRetryPolicy;
10+
use crate::runtime::RUNTIMES;
1011
use crate::ssl::CassSsl;
1112
use crate::timestamp_generator::CassTimestampGen;
1213
use crate::types::*;
@@ -82,7 +83,11 @@ const DRIVER_NAME: &str = "ScyllaDB Cpp-Rust Driver";
8283
const DRIVER_VERSION: &str = env!("CARGO_PKG_VERSION");
8384

8485
pub struct CassCluster {
85-
runtime: Arc<tokio::runtime::Runtime>,
86+
/// Number of threads in the tokio runtime thread pool.
87+
///
88+
/// Specified with `cass_cluster_set_num_threads_io`.
89+
/// If not set, the default tokio runtime is used.
90+
num_threads_io: Option<usize>,
8691

8792
session_builder: SessionBuilder,
8893
default_execution_profile_builder: ExecutionProfileBuilder,
@@ -101,8 +106,20 @@ pub struct CassCluster {
101106
}
102107

103108
impl CassCluster {
104-
pub(crate) fn get_runtime(&self) -> &Arc<tokio::runtime::Runtime> {
105-
&self.runtime
109+
/// Gets the runtime that has been set for the cluster.
110+
/// If no runtime has been set yet, it creates a default runtime
111+
/// and makes it cached in the global `Runtimes` instance.
112+
pub(crate) fn get_runtime(&self) -> Arc<tokio::runtime::Runtime> {
113+
let mut runtimes = RUNTIMES.lock().unwrap();
114+
115+
if let Some(num_threads_io) = self.num_threads_io {
116+
// If the number of threads is set, we create a runtime with that number of threads.
117+
runtimes.n_thread_runtime(num_threads_io)
118+
} else {
119+
// Otherwise, we use the default runtime.
120+
runtimes.default_runtime()
121+
}
122+
.expect("Failed to create an async runtime")
106123
}
107124

108125
pub(crate) fn execution_profile_map(&self) -> &HashMap<ExecProfileName, CassExecProfile> {
@@ -184,12 +201,6 @@ impl CassCluster {
184201

185202
#[unsafe(no_mangle)]
186203
pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster, CMut> {
187-
let Ok(default_runtime) = tokio::runtime::Runtime::new()
188-
.inspect_err(|e| tracing::error!("Failed to create async runtime: {}", e))
189-
else {
190-
return CassPtr::null_mut();
191-
};
192-
193204
let default_execution_profile_builder = ExecutionProfileBuilder::default()
194205
.consistency(DEFAULT_CONSISTENCY)
195206
.serial_consistency(DEFAULT_SERIAL_CONSISTENCY)
@@ -321,7 +332,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster
321332
};
322333

323334
BoxFFI::into_ptr(Box::new(CassCluster {
324-
runtime: Arc::new(default_runtime),
335+
num_threads_io: None,
325336

326337
session_builder: default_session_builder,
327338
port: 9042,
@@ -1554,25 +1565,7 @@ pub unsafe extern "C" fn cass_cluster_set_num_threads_io(
15541565
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
15551566
};
15561567

1557-
let runtime_res = match num_threads {
1558-
0 => tokio::runtime::Builder::new_current_thread()
1559-
.enable_all()
1560-
.build(),
1561-
n => tokio::runtime::Builder::new_multi_thread()
1562-
.worker_threads(n as usize)
1563-
.enable_all()
1564-
.build(),
1565-
};
1566-
1567-
let runtime = match runtime_res {
1568-
Ok(runtime) => runtime,
1569-
Err(err) => {
1570-
tracing::error!("Failed to create async runtime: {}", err);
1571-
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
1572-
}
1573-
};
1574-
1575-
cluster.runtime = Arc::new(runtime);
1568+
cluster.num_threads_io = Some(num_threads as usize);
15761569

15771570
CassError::CASS_OK
15781571
}

scylla-rust-wrapper/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pub(crate) mod misc;
3030
pub(crate) mod prepared;
3131
pub(crate) mod query_result;
3232
pub(crate) mod retry_policy;
33+
pub(crate) mod runtime;
3334
#[cfg(test)]
3435
mod ser_de_tests;
3536
pub(crate) mod session;

scylla-rust-wrapper/src/runtime.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//! Manages tokio runtimes for the application.
2+
//!
3+
//! Runtime is per-cluster and can be changed with `cass_cluster_set_num_threads_io`.
4+
5+
use std::{
6+
collections::HashMap,
7+
sync::{Arc, Weak},
8+
};
9+
10+
use tokio::runtime::Runtime;
11+
12+
/// Manages tokio runtimes for the application.
13+
///
14+
/// Runtime is per-cluster and can be changed with `cass_cluster_set_num_threads_io`.
15+
/// Once a runtime is created, it is cached for future use.
16+
/// Once all `CassCluster` instances that reference a runtime are dropped,
17+
/// the runtime is also dropped.
18+
pub(crate) struct Runtimes {
19+
// Weak pointers are used to make runtimes dropped once all CassCluster instances
20+
// that reference them are freed.
21+
default_runtime: Option<Weak<Runtime>>,
22+
// This is Option to allow creating a static instance of Runtimes.
23+
// (`HashMap::new` is not `const`).
24+
n_thread_runtimes: Option<HashMap<usize, Weak<Runtime>>>,
25+
}
26+
27+
pub(crate) static RUNTIMES: std::sync::Mutex<Runtimes> = {
28+
std::sync::Mutex::new(Runtimes {
29+
default_runtime: None,
30+
n_thread_runtimes: None,
31+
})
32+
};
33+
34+
impl Runtimes {
35+
fn cached_or_new_runtime(
36+
weak_runtime: &mut Weak<Runtime>,
37+
create_runtime: impl FnOnce() -> Result<Arc<Runtime>, std::io::Error>,
38+
) -> Result<Arc<Runtime>, std::io::Error> {
39+
match weak_runtime.upgrade() {
40+
Some(cached_runtime) => Ok(cached_runtime),
41+
None => {
42+
let runtime = create_runtime()?;
43+
*weak_runtime = Arc::downgrade(&runtime);
44+
Ok(runtime)
45+
}
46+
}
47+
}
48+
49+
/// Returns a default tokio runtime.
50+
///
51+
/// If it's not created yet, it will create a new one with the default configuration
52+
/// and cache it for future use.
53+
pub(crate) fn default_runtime(&mut self) -> Result<Arc<Runtime>, std::io::Error> {
54+
let default_runtime_slot = self.default_runtime.get_or_insert_with(Weak::new);
55+
Self::cached_or_new_runtime(default_runtime_slot, || Runtime::new().map(Arc::new))
56+
}
57+
58+
/// Returns a tokio runtime with `n_threads` worker threads.
59+
///
60+
/// If it's not created yet, it will create a new one and cache it for future use.
61+
pub(crate) fn n_thread_runtime(
62+
&mut self,
63+
n_threads: usize,
64+
) -> Result<Arc<Runtime>, std::io::Error> {
65+
let n_thread_runtimes = self.n_thread_runtimes.get_or_insert_with(HashMap::new);
66+
let n_thread_runtime_slot = n_thread_runtimes.entry(n_threads).or_default();
67+
68+
Self::cached_or_new_runtime(n_thread_runtime_slot, || {
69+
match n_threads {
70+
0 => tokio::runtime::Builder::new_current_thread()
71+
.enable_all()
72+
.build(),
73+
n => tokio::runtime::Builder::new_multi_thread()
74+
.worker_threads(n)
75+
.enable_all()
76+
.build(),
77+
}
78+
.map(Arc::new)
79+
})
80+
}
81+
}

scylla-rust-wrapper/src/session.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ impl CassConnectedSession {
7979
let cluster_client_id = cluster.get_client_id();
8080

8181
let fut = Self::connect_fut(
82-
Arc::clone(cluster.get_runtime()),
82+
cluster.get_runtime(),
8383
session,
8484
session_builder,
8585
cluster_client_id,
@@ -89,7 +89,7 @@ impl CassConnectedSession {
8989
);
9090

9191
CassFuture::make_raw(
92-
Arc::clone(cluster.get_runtime()),
92+
cluster.get_runtime(),
9393
fut,
9494
#[cfg(cpp_integration_testing)]
9595
None,

0 commit comments

Comments
 (0)