@@ -5,7 +5,7 @@ use crate::statement::{PagingState, PagingStateResponse};
55use crate :: transport:: errors:: QueryError ;
66use crate :: transport:: iterator:: LegacyRowIterator ;
77use crate :: transport:: partitioner:: PartitionerName ;
8- use crate :: { LegacyQueryResult , LegacySession } ;
8+ use crate :: { LegacyQueryResult , QueryResult } ;
99use bytes:: Bytes ;
1010use dashmap:: DashMap ;
1111use futures:: future:: try_join_all;
@@ -16,6 +16,11 @@ use std::collections::hash_map::RandomState;
1616use std:: hash:: BuildHasher ;
1717use std:: sync:: Arc ;
1818
19+ use super :: iterator:: QueryPager ;
20+ use super :: session:: {
21+ CurrentDeserializationApi , DeserializationApiKind , GenericSession , LegacyDeserializationApi ,
22+ } ;
23+
1924/// Contains just the parts of a prepared statement that were returned
2025/// from the database. All remaining parts (query string, page size,
2126/// consistency, etc.) are taken from the Query passed
@@ -31,23 +36,28 @@ struct RawPreparedStatementData {
3136
3237/// Provides auto caching while executing queries
3338#[ derive( Debug ) ]
34- pub struct CachingSession < S = RandomState >
39+ pub struct GenericCachingSession < DeserializationApi , S = RandomState >
3540where
3641 S : Clone + BuildHasher ,
42+ DeserializationApi : DeserializationApiKind ,
3743{
38- session : LegacySession ,
44+ session : GenericSession < DeserializationApi > ,
3945 /// The prepared statement cache size
4046 /// If a prepared statement is added while the limit is reached, the oldest prepared statement
4147 /// is removed from the cache
4248 max_capacity : usize ,
4349 cache : DashMap < String , RawPreparedStatementData , S > ,
4450}
4551
46- impl < S > CachingSession < S >
52+ pub type CachingSession < S = RandomState > = GenericCachingSession < CurrentDeserializationApi , S > ;
53+ pub type LegacyCachingSession < S = RandomState > = GenericCachingSession < LegacyDeserializationApi , S > ;
54+
55+ impl < DeserApi , S > GenericCachingSession < DeserApi , S >
4756where
4857 S : Default + BuildHasher + Clone ,
58+ DeserApi : DeserializationApiKind ,
4959{
50- pub fn from ( session : LegacySession , cache_size : usize ) -> Self {
60+ pub fn from ( session : GenericSession < DeserApi > , cache_size : usize ) -> Self {
5161 Self {
5262 session,
5363 max_capacity : cache_size,
@@ -56,20 +66,88 @@ where
5666 }
5767}
5868
59- impl < S > CachingSession < S >
69+ impl < DeserApi , S > GenericCachingSession < DeserApi , S >
6070where
6171 S : BuildHasher + Clone ,
72+ DeserApi : DeserializationApiKind ,
6273{
6374 /// Builds a [`CachingSession`] from a [`Session`], a cache size, and a [`BuildHasher`].,
6475 /// using a customer hasher.
65- pub fn with_hasher ( session : LegacySession , cache_size : usize , hasher : S ) -> Self {
76+ pub fn with_hasher ( session : GenericSession < DeserApi > , cache_size : usize , hasher : S ) -> Self {
6677 Self {
6778 session,
6879 max_capacity : cache_size,
6980 cache : DashMap :: with_hasher ( hasher) ,
7081 }
7182 }
83+ }
7284
85+ impl < S > GenericCachingSession < CurrentDeserializationApi , S >
86+ where
87+ S : BuildHasher + Clone ,
88+ {
89+ /// Does the same thing as [`Session::execute_unpaged`] but uses the prepared statement cache
90+ pub async fn execute_unpaged (
91+ & self ,
92+ query : impl Into < Query > ,
93+ values : impl SerializeRow ,
94+ ) -> Result < QueryResult , QueryError > {
95+ let query = query. into ( ) ;
96+ let prepared = self . add_prepared_statement_owned ( query) . await ?;
97+ self . session . execute_unpaged ( & prepared, values) . await
98+ }
99+
100+ /// Does the same thing as [`Session::execute_iter`] but uses the prepared statement cache
101+ pub async fn execute_iter (
102+ & self ,
103+ query : impl Into < Query > ,
104+ values : impl SerializeRow ,
105+ ) -> Result < QueryPager , QueryError > {
106+ let query = query. into ( ) ;
107+ let prepared = self . add_prepared_statement_owned ( query) . await ?;
108+ self . session . execute_iter ( prepared, values) . await
109+ }
110+
111+ /// Does the same thing as [`Session::execute_single_page`] but uses the prepared statement cache
112+ pub async fn execute_single_page (
113+ & self ,
114+ query : impl Into < Query > ,
115+ values : impl SerializeRow ,
116+ paging_state : PagingState ,
117+ ) -> Result < ( QueryResult , PagingStateResponse ) , QueryError > {
118+ let query = query. into ( ) ;
119+ let prepared = self . add_prepared_statement_owned ( query) . await ?;
120+ self . session
121+ . execute_single_page ( & prepared, values, paging_state)
122+ . await
123+ }
124+
125+ /// Does the same thing as [`Session::batch`] but uses the prepared statement cache\
126+ /// Prepares batch using CachingSession::prepare_batch if needed and then executes it
127+ pub async fn batch (
128+ & self ,
129+ batch : & Batch ,
130+ values : impl BatchValues ,
131+ ) -> Result < QueryResult , QueryError > {
132+ let all_prepared: bool = batch
133+ . statements
134+ . iter ( )
135+ . all ( |stmt| matches ! ( stmt, BatchStatement :: PreparedStatement ( _) ) ) ;
136+
137+ if all_prepared {
138+ self . session . batch ( batch, & values) . await
139+ } else {
140+ let prepared_batch: Batch = self . prepare_batch ( batch) . await ?;
141+
142+ self . session . batch ( & prepared_batch, & values) . await
143+ }
144+ }
145+ }
146+
147+ impl < S > GenericCachingSession < LegacyDeserializationApi , S >
148+ where
149+ S : BuildHasher + Clone ,
150+ {
73151 /// Does the same thing as [`Session::execute_unpaged`] but uses the prepared statement cache
74152 pub async fn execute_unpaged (
75153 & self ,
@@ -126,7 +204,13 @@ where
126204 self . session . batch ( & prepared_batch, & values) . await
127205 }
128206 }
207+ }
129208
209+ impl < DeserApi , S > GenericCachingSession < DeserApi , S >
210+ where
211+ S : BuildHasher + Clone ,
212+ DeserApi : DeserializationApiKind ,
213+ {
130214 /// Prepares all statements within the batch and returns a new batch where every
131215 /// statement is prepared.
132216 /// Uses the prepared statements cache.
@@ -212,7 +296,7 @@ where
212296 self . max_capacity
213297 }
214298
215- pub fn get_session ( & self ) -> & LegacySession {
299+ pub fn get_session ( & self ) -> & GenericSession < DeserApi > {
216300 & self . session
217301 }
218302}
@@ -229,7 +313,7 @@ mod tests {
229313 use crate :: {
230314 batch:: { Batch , BatchStatement } ,
231315 prepared_statement:: PreparedStatement ,
232- CachingSession , LegacySession ,
316+ LegacyCachingSession , LegacySession ,
233317 } ;
234318 use futures:: TryStreamExt ;
235319 use std:: collections:: BTreeSet ;
@@ -273,8 +357,8 @@ mod tests {
273357 session
274358 }
275359
276- async fn create_caching_session ( ) -> CachingSession {
277- let session = CachingSession :: from ( new_for_test ( true ) . await , 2 ) ;
360+ async fn create_caching_session ( ) -> LegacyCachingSession {
361+ let session = LegacyCachingSession :: from ( new_for_test ( true ) . await , 2 ) ;
278362
279363 // Add a row, this makes it easier to check if the caching works combined with the regular execute fn on Session
280364 session
@@ -385,7 +469,7 @@ mod tests {
385469 }
386470
387471 async fn assert_test_batch_table_rows_contain (
388- sess : & CachingSession ,
472+ sess : & LegacyCachingSession ,
389473 expected_rows : & [ ( i32 , i32 ) ] ,
390474 ) {
391475 let selected_rows: BTreeSet < ( i32 , i32 ) > = sess
@@ -431,18 +515,18 @@ mod tests {
431515 }
432516 }
433517
434- let _session: CachingSession < std:: collections:: hash_map:: RandomState > =
435- CachingSession :: from ( new_for_test ( true ) . await , 2 ) ;
436- let _session: CachingSession < CustomBuildHasher > =
437- CachingSession :: from ( new_for_test ( true ) . await , 2 ) ;
438- let _session: CachingSession < CustomBuildHasher > =
439- CachingSession :: with_hasher ( new_for_test ( true ) . await , 2 , Default :: default ( ) ) ;
518+ let _session: LegacyCachingSession < std:: collections:: hash_map:: RandomState > =
519+ LegacyCachingSession :: from ( new_for_test ( true ) . await , 2 ) ;
520+ let _session: LegacyCachingSession < CustomBuildHasher > =
521+ LegacyCachingSession :: from ( new_for_test ( true ) . await , 2 ) ;
522+ let _session: LegacyCachingSession < CustomBuildHasher > =
523+ LegacyCachingSession :: with_hasher ( new_for_test ( true ) . await , 2 , Default :: default ( ) ) ;
440524 }
441525
442526 #[ tokio:: test]
443527 async fn test_batch ( ) {
444528 setup_tracing ( ) ;
445- let session: CachingSession = create_caching_session ( ) . await ;
529+ let session: LegacyCachingSession = create_caching_session ( ) . await ;
446530
447531 session
448532 . execute_unpaged (
@@ -565,7 +649,8 @@ mod tests {
565649 #[ tokio:: test]
566650 async fn test_parameters_caching ( ) {
567651 setup_tracing ( ) ;
568- let session: CachingSession = CachingSession :: from ( new_for_test ( true ) . await , 100 ) ;
652+ let session: LegacyCachingSession =
653+ LegacyCachingSession :: from ( new_for_test ( true ) . await , 100 ) ;
569654
570655 session
571656 . execute_unpaged ( "CREATE TABLE tbl (a int PRIMARY KEY, b int)" , ( ) )
@@ -618,7 +703,8 @@ mod tests {
618703 }
619704
620705 // This test uses CDC which is not yet compatible with Scylla's tablets.
621- let session: CachingSession = CachingSession :: from ( new_for_test ( false ) . await , 100 ) ;
706+ let session: LegacyCachingSession =
707+ LegacyCachingSession :: from ( new_for_test ( false ) . await , 100 ) ;
622708
623709 session
624710 . execute_unpaged (
0 commit comments