@@ -12,9 +12,9 @@ use crate::cluster::node::CloudEndpoint;
1212use crate :: cluster:: node:: { InternalKnownNode , KnownNode , NodeRef } ;
1313use crate :: cluster:: { Cluster , ClusterNeatDebug , ClusterState } ;
1414use crate :: errors:: {
15- BadQuery , BrokenConnectionError , ExecutionError , MetadataError , NewSessionError ,
16- PagerExecutionError , PrepareError , RequestAttemptError , RequestError , SchemaAgreementError ,
17- TracingError , UseKeyspaceError ,
15+ BrokenConnectionError , ExecutionError , MetadataError , NewSessionError , PagerExecutionError ,
16+ PrepareError , RequestAttemptError , RequestError , SchemaAgreementError , TracingError ,
17+ UseKeyspaceError ,
1818} ;
1919use crate :: frame:: response:: result;
2020use crate :: network:: tls:: TlsProvider ;
@@ -36,7 +36,7 @@ use crate::response::{
3636} ;
3737use crate :: routing:: partitioner:: PartitionerName ;
3838use crate :: routing:: { Shard , ShardAwarePortRange } ;
39- use crate :: statement:: batch:: batch_values ;
39+ use crate :: statement:: batch:: BoundBatch ;
4040use crate :: statement:: batch:: { Batch , BatchStatement } ;
4141use crate :: statement:: bound:: BoundStatement ;
4242use crate :: statement:: prepared:: { PartitionKeyError , PreparedStatement } ;
@@ -47,9 +47,10 @@ use futures::future::join_all;
4747use futures:: future:: try_join_all;
4848use itertools:: Itertools ;
4949use scylla_cql:: frame:: response:: NonErrorResponse ;
50- use scylla_cql:: serialize:: batch:: BatchValues ;
50+ use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
5151use scylla_cql:: serialize:: row:: SerializeRow ;
52- use std:: borrow:: Borrow ;
52+ use std:: borrow:: { Borrow , Cow } ;
53+ use std:: collections:: { HashMap , HashSet } ;
5354use std:: future:: Future ;
5455use std:: net:: { IpAddr , SocketAddr } ;
5556use std:: num:: NonZeroU32 ;
@@ -809,7 +810,10 @@ impl Session {
809810 batch : & Batch ,
810811 values : impl BatchValues ,
811812 ) -> Result < QueryResult , ExecutionError > {
812- self . do_batch ( batch, values) . await
813+ let batch = self . last_minute_prepare_batch ( batch, & values) . await ?;
814+ let batch = BoundBatch :: from_batch ( batch. as_ref ( ) , values) ?;
815+
816+ self . do_batch ( & batch) . await
813817 }
814818
815819 /// Estabilishes a CQL session with the database
@@ -1189,7 +1193,10 @@ impl Session {
11891193 // Making QueryPager::new_for_query work with values is too hard (if even possible)
11901194 // so instead of sending one prepare to a specific connection on each iterator query,
11911195 // we fully prepare a statement beforehand.
1192- let bound = self . prepare_nongeneric ( & statement) . await ?. into_bind ( & values) ?;
1196+ let bound = self
1197+ . prepare_nongeneric ( & statement)
1198+ . await ?
1199+ . into_bind ( & values) ?;
11931200 QueryPager :: new_for_prepared_statement ( PreparedPagerConfig {
11941201 bound,
11951202 execution_profile,
@@ -1517,22 +1524,9 @@ impl Session {
15171524 . map_err ( PagerExecutionError :: NextPageError )
15181525 }
15191526
1520- async fn do_batch (
1521- & self ,
1522- batch : & Batch ,
1523- values : impl BatchValues ,
1524- ) -> Result < QueryResult , ExecutionError > {
1527+ async fn do_batch ( & self , batch : & BoundBatch ) -> Result < QueryResult , ExecutionError > {
15251528 // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
15261529 // If users batch statements by shard, they will be rewarded with full shard awareness
1527-
1528- // check to ensure that we don't send a batch statement with more than u16::MAX queries
1529- let batch_statements_length = batch. statements . len ( ) ;
1530- if batch_statements_length > u16:: MAX as usize {
1531- return Err ( ExecutionError :: BadQuery (
1532- BadQuery :: TooManyQueriesInBatchStatement ( batch_statements_length) ,
1533- ) ) ;
1534- }
1535-
15361530 let execution_profile = batch
15371531 . get_execution_profile_handle ( )
15381532 . unwrap_or_else ( || self . get_default_execution_profile_handle ( ) )
@@ -1548,22 +1542,17 @@ impl Session {
15481542 . serial_consistency
15491543 . unwrap_or ( execution_profile. serial_consistency ) ;
15501544
1551- let ( first_value_token, values) =
1552- batch_values:: peek_first_token ( values, batch. statements . first ( ) ) ?;
1553- let values_ref = & values;
1554-
1555- let table_spec =
1556- if let Some ( BatchStatement :: PreparedStatement ( ps) ) = batch. statements . first ( ) {
1557- ps. get_table_spec ( )
1558- } else {
1559- None
1560- } ;
1545+ let ( table, token) = batch
1546+ . first_prepared
1547+ . as_ref ( )
1548+ . and_then ( |( ps, token) | ps. get_table_spec ( ) . map ( |table| ( table, * token) ) )
1549+ . unzip ( ) ;
15611550
15621551 let statement_info = RoutingInfo {
15631552 consistency,
15641553 serial_consistency,
1565- token : first_value_token ,
1566- table : table_spec ,
1554+ token,
1555+ table,
15671556 is_confirmed_lwt : false ,
15681557 } ;
15691558
@@ -1586,12 +1575,7 @@ impl Session {
15861575 . unwrap_or ( execution_profile. serial_consistency ) ;
15871576 async move {
15881577 connection
1589- . batch_with_consistency (
1590- batch,
1591- values_ref,
1592- consistency,
1593- serial_consistency,
1594- )
1578+ . batch_with_consistency ( batch, consistency, serial_consistency)
15951579 . await
15961580 . and_then ( QueryResponse :: into_non_error_query_response)
15971581 }
@@ -1660,6 +1644,54 @@ impl Session {
16601644 Ok ( prepared_batch)
16611645 }
16621646
1647+ async fn last_minute_prepare_batch < ' b > (
1648+ & self ,
1649+ init_batch : & ' b Batch ,
1650+ values : impl BatchValues ,
1651+ ) -> Result < Cow < ' b , Batch > , PrepareError > {
1652+ let mut to_prepare = HashSet :: < & str > :: new ( ) ;
1653+
1654+ {
1655+ let mut values_iter = values. batch_values_iter ( ) ;
1656+ for stmt in & init_batch. statements {
1657+ if let BatchStatement :: Query ( query) = stmt {
1658+ if let Some ( false ) = values_iter. is_empty_next ( ) {
1659+ to_prepare. insert ( & query. contents ) ;
1660+ }
1661+ } else {
1662+ values_iter. skip_next ( ) ;
1663+ }
1664+ }
1665+ }
1666+
1667+ if to_prepare. is_empty ( ) {
1668+ return Ok ( Cow :: Borrowed ( init_batch) ) ;
1669+ }
1670+
1671+ let mut prepared_queries = HashMap :: < & str , PreparedStatement > :: new ( ) ;
1672+
1673+ for query in to_prepare {
1674+ let prepared = self . prepare ( query) . await ?;
1675+ prepared_queries. insert ( query, prepared) ;
1676+ }
1677+
1678+ let mut batch: Cow < Batch > = Cow :: Owned ( Batch :: new_from ( init_batch) ) ;
1679+ for stmt in & init_batch. statements {
1680+ match stmt {
1681+ BatchStatement :: Query ( query) => match prepared_queries. get ( query. contents . as_str ( ) )
1682+ {
1683+ Some ( prepared) => batch. to_mut ( ) . append_statement ( prepared. clone ( ) ) ,
1684+ None => batch. to_mut ( ) . append_statement ( query. clone ( ) ) ,
1685+ } ,
1686+ BatchStatement :: PreparedStatement ( prepared) => {
1687+ batch. to_mut ( ) . append_statement ( prepared. clone ( ) ) ;
1688+ }
1689+ }
1690+ }
1691+
1692+ Ok ( batch)
1693+ }
1694+
16631695 /// Sends `USE <keyspace_name>` request on all connections\
16641696 /// This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\
16651697 ///
0 commit comments