Skip to content

Commit b3f4a04

Browse files
piodulwprzytula
andcommitted
caching_session: make generic over session APIs
In a similar fashion to Session, CachingSession was also made generic over the session kind. Co-authored-by: Wojciech Przytuła <[email protected]>
1 parent 2ec2885 commit b3f4a04

File tree

3 files changed

+111
-25
lines changed

3 files changed

+111
-25
lines changed

scylla/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ pub use statement::query;
256256
pub use frame::response::cql_to_rust;
257257
pub use frame::response::cql_to_rust::FromRow;
258258

259-
pub use transport::caching_session::CachingSession;
259+
pub use transport::caching_session::{CachingSession, GenericCachingSession, LegacyCachingSession};
260260
pub use transport::execution_profile::ExecutionProfile;
261261
pub use transport::legacy_query_result::LegacyQueryResult;
262262
pub use transport::query_result::{QueryResult, QueryRowsResult};

scylla/src/transport/caching_session.rs

Lines changed: 107 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::statement::{PagingState, PagingStateResponse};
55
use crate::transport::errors::QueryError;
66
use crate::transport::iterator::LegacyRowIterator;
77
use crate::transport::partitioner::PartitionerName;
8-
use crate::{LegacyQueryResult, LegacySession};
8+
use crate::{LegacyQueryResult, QueryResult};
99
use bytes::Bytes;
1010
use dashmap::DashMap;
1111
use futures::future::try_join_all;
@@ -16,6 +16,11 @@ use std::collections::hash_map::RandomState;
1616
use std::hash::BuildHasher;
1717
use std::sync::Arc;
1818

19+
use super::iterator::QueryPager;
20+
use super::session::{
21+
CurrentDeserializationApi, DeserializationApiKind, GenericSession, LegacyDeserializationApi,
22+
};
23+
1924
/// Contains just the parts of a prepared statement that were returned
2025
/// from the database. All remaining parts (query string, page size,
2126
/// consistency, etc.) are taken from the Query passed
@@ -31,23 +36,28 @@ struct RawPreparedStatementData {
3136

3237
/// Provides auto caching while executing queries
3338
#[derive(Debug)]
34-
pub struct CachingSession<S = RandomState>
39+
pub struct GenericCachingSession<DeserializationApi, S = RandomState>
3540
where
3641
S: Clone + BuildHasher,
42+
DeserializationApi: DeserializationApiKind,
3743
{
38-
session: LegacySession,
44+
session: GenericSession<DeserializationApi>,
3945
/// The prepared statement cache size
4046
/// If a prepared statement is added while the limit is reached, the oldest prepared statement
4147
/// is removed from the cache
4248
max_capacity: usize,
4349
cache: DashMap<String, RawPreparedStatementData, S>,
4450
}
4551

46-
impl<S> CachingSession<S>
52+
pub type CachingSession<S = RandomState> = GenericCachingSession<CurrentDeserializationApi, S>;
53+
pub type LegacyCachingSession<S = RandomState> = GenericCachingSession<LegacyDeserializationApi, S>;
54+
55+
impl<DeserApi, S> GenericCachingSession<DeserApi, S>
4756
where
4857
S: Default + BuildHasher + Clone,
58+
DeserApi: DeserializationApiKind,
4959
{
50-
pub fn from(session: LegacySession, cache_size: usize) -> Self {
60+
pub fn from(session: GenericSession<DeserApi>, cache_size: usize) -> Self {
5161
Self {
5262
session,
5363
max_capacity: cache_size,
@@ -56,20 +66,88 @@ where
5666
}
5767
}
5868

59-
impl<S> CachingSession<S>
69+
impl<DeserApi, S> GenericCachingSession<DeserApi, S>
6070
where
6171
S: BuildHasher + Clone,
72+
DeserApi: DeserializationApiKind,
6273
{
6374
/// Builds a [`CachingSession`] from a [`Session`], a cache size, and a [`BuildHasher`].,
6475
/// using a customer hasher.
65-
pub fn with_hasher(session: LegacySession, cache_size: usize, hasher: S) -> Self {
76+
pub fn with_hasher(session: GenericSession<DeserApi>, cache_size: usize, hasher: S) -> Self {
6677
Self {
6778
session,
6879
max_capacity: cache_size,
6980
cache: DashMap::with_hasher(hasher),
7081
}
7182
}
83+
}
7284

85+
impl<S> GenericCachingSession<CurrentDeserializationApi, S>
86+
where
87+
S: BuildHasher + Clone,
88+
{
89+
/// Does the same thing as [`Session::execute_unpaged`] but uses the prepared statement cache
90+
pub async fn execute_unpaged(
91+
&self,
92+
query: impl Into<Query>,
93+
values: impl SerializeRow,
94+
) -> Result<QueryResult, QueryError> {
95+
let query = query.into();
96+
let prepared = self.add_prepared_statement_owned(query).await?;
97+
self.session.execute_unpaged(&prepared, values).await
98+
}
99+
100+
/// Does the same thing as [`Session::execute_iter`] but uses the prepared statement cache
101+
pub async fn execute_iter(
102+
&self,
103+
query: impl Into<Query>,
104+
values: impl SerializeRow,
105+
) -> Result<QueryPager, QueryError> {
106+
let query = query.into();
107+
let prepared = self.add_prepared_statement_owned(query).await?;
108+
self.session.execute_iter(prepared, values).await
109+
}
110+
111+
/// Does the same thing as [`Session::execute_single_page`] but uses the prepared statement cache
112+
pub async fn execute_single_page(
113+
&self,
114+
query: impl Into<Query>,
115+
values: impl SerializeRow,
116+
paging_state: PagingState,
117+
) -> Result<(QueryResult, PagingStateResponse), QueryError> {
118+
let query = query.into();
119+
let prepared = self.add_prepared_statement_owned(query).await?;
120+
self.session
121+
.execute_single_page(&prepared, values, paging_state)
122+
.await
123+
}
124+
125+
/// Does the same thing as [`Session::batch`] but uses the prepared statement cache\
126+
/// Prepares batch using CachingSession::prepare_batch if needed and then executes it
127+
pub async fn batch(
128+
&self,
129+
batch: &Batch,
130+
values: impl BatchValues,
131+
) -> Result<QueryResult, QueryError> {
132+
let all_prepared: bool = batch
133+
.statements
134+
.iter()
135+
.all(|stmt| matches!(stmt, BatchStatement::PreparedStatement(_)));
136+
137+
if all_prepared {
138+
self.session.batch(batch, &values).await
139+
} else {
140+
let prepared_batch: Batch = self.prepare_batch(batch).await?;
141+
142+
self.session.batch(&prepared_batch, &values).await
143+
}
144+
}
145+
}
146+
147+
impl<S> GenericCachingSession<LegacyDeserializationApi, S>
148+
where
149+
S: BuildHasher + Clone,
150+
{
73151
/// Does the same thing as [`Session::execute_unpaged`] but uses the prepared statement cache
74152
pub async fn execute_unpaged(
75153
&self,
@@ -126,7 +204,13 @@ where
126204
self.session.batch(&prepared_batch, &values).await
127205
}
128206
}
207+
}
129208

209+
impl<DeserApi, S> GenericCachingSession<DeserApi, S>
210+
where
211+
S: BuildHasher + Clone,
212+
DeserApi: DeserializationApiKind,
213+
{
130214
/// Prepares all statements within the batch and returns a new batch where every
131215
/// statement is prepared.
132216
/// Uses the prepared statements cache.
@@ -212,7 +296,7 @@ where
212296
self.max_capacity
213297
}
214298

215-
pub fn get_session(&self) -> &LegacySession {
299+
pub fn get_session(&self) -> &GenericSession<DeserApi> {
216300
&self.session
217301
}
218302
}
@@ -229,7 +313,7 @@ mod tests {
229313
use crate::{
230314
batch::{Batch, BatchStatement},
231315
prepared_statement::PreparedStatement,
232-
CachingSession, LegacySession,
316+
LegacyCachingSession, LegacySession,
233317
};
234318
use futures::TryStreamExt;
235319
use std::collections::BTreeSet;
@@ -273,8 +357,8 @@ mod tests {
273357
session
274358
}
275359

276-
async fn create_caching_session() -> CachingSession {
277-
let session = CachingSession::from(new_for_test(true).await, 2);
360+
async fn create_caching_session() -> LegacyCachingSession {
361+
let session = LegacyCachingSession::from(new_for_test(true).await, 2);
278362

279363
// Add a row, this makes it easier to check if the caching works combined with the regular execute fn on Session
280364
session
@@ -385,7 +469,7 @@ mod tests {
385469
}
386470

387471
async fn assert_test_batch_table_rows_contain(
388-
sess: &CachingSession,
472+
sess: &LegacyCachingSession,
389473
expected_rows: &[(i32, i32)],
390474
) {
391475
let selected_rows: BTreeSet<(i32, i32)> = sess
@@ -431,18 +515,18 @@ mod tests {
431515
}
432516
}
433517

434-
let _session: CachingSession<std::collections::hash_map::RandomState> =
435-
CachingSession::from(new_for_test(true).await, 2);
436-
let _session: CachingSession<CustomBuildHasher> =
437-
CachingSession::from(new_for_test(true).await, 2);
438-
let _session: CachingSession<CustomBuildHasher> =
439-
CachingSession::with_hasher(new_for_test(true).await, 2, Default::default());
518+
let _session: LegacyCachingSession<std::collections::hash_map::RandomState> =
519+
LegacyCachingSession::from(new_for_test(true).await, 2);
520+
let _session: LegacyCachingSession<CustomBuildHasher> =
521+
LegacyCachingSession::from(new_for_test(true).await, 2);
522+
let _session: LegacyCachingSession<CustomBuildHasher> =
523+
LegacyCachingSession::with_hasher(new_for_test(true).await, 2, Default::default());
440524
}
441525

442526
#[tokio::test]
443527
async fn test_batch() {
444528
setup_tracing();
445-
let session: CachingSession = create_caching_session().await;
529+
let session: LegacyCachingSession = create_caching_session().await;
446530

447531
session
448532
.execute_unpaged(
@@ -565,7 +649,8 @@ mod tests {
565649
#[tokio::test]
566650
async fn test_parameters_caching() {
567651
setup_tracing();
568-
let session: CachingSession = CachingSession::from(new_for_test(true).await, 100);
652+
let session: LegacyCachingSession =
653+
LegacyCachingSession::from(new_for_test(true).await, 100);
569654

570655
session
571656
.execute_unpaged("CREATE TABLE tbl (a int PRIMARY KEY, b int)", ())
@@ -618,7 +703,8 @@ mod tests {
618703
}
619704

620705
// This test uses CDC which is not yet compatible with Scylla's tablets.
621-
let session: CachingSession = CachingSession::from(new_for_test(false).await, 100);
706+
let session: LegacyCachingSession =
707+
LegacyCachingSession::from(new_for_test(false).await, 100);
622708

623709
session
624710
.execute_unpaged(

scylla/src/transport/session_test.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ use crate::transport::topology::{
1919
use crate::utils::test_utils::{
2020
create_new_session_builder, supports_feature, unique_keyspace_name,
2121
};
22-
use crate::CachingSession;
2322
use crate::ExecutionProfile;
23+
use crate::LegacyCachingSession;
2424
use crate::LegacyQueryResult;
2525
use crate::{LegacySession, SessionBuilder};
2626
use assert_matches::assert_matches;
@@ -2012,7 +2012,7 @@ async fn rename(session: &LegacySession, rename_str: &str) {
20122012
.unwrap();
20132013
}
20142014

2015-
async fn rename_caching(session: &CachingSession, rename_str: &str) {
2015+
async fn rename_caching(session: &LegacyCachingSession, rename_str: &str) {
20162016
session
20172017
.execute_unpaged(format!("ALTER TABLE tab RENAME {}", rename_str), &())
20182018
.await
@@ -2230,7 +2230,7 @@ async fn test_unprepared_reprepare_in_caching_session_execute() {
22302230
session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
22312231
session.use_keyspace(ks, false).await.unwrap();
22322232

2233-
let caching_session: CachingSession = CachingSession::from(session, 64);
2233+
let caching_session: LegacyCachingSession = LegacyCachingSession::from(session, 64);
22342234

22352235
caching_session
22362236
.execute_unpaged(

0 commit comments

Comments
 (0)