@@ -12,9 +12,9 @@ use crate::cluster::node::CloudEndpoint;
1212use crate :: cluster:: node:: { InternalKnownNode , KnownNode , NodeRef } ;
1313use crate :: cluster:: { Cluster , ClusterNeatDebug , ClusterState } ;
1414use crate :: errors:: {
15- BadQuery , BrokenConnectionError , ExecutionError , MetadataError , NewSessionError ,
16- PagerExecutionError , PrepareError , RequestAttemptError , RequestError , SchemaAgreementError ,
17- TracingError , UseKeyspaceError ,
15+ BrokenConnectionError , ExecutionError , MetadataError , NewSessionError , PagerExecutionError ,
16+ PrepareError , RequestAttemptError , RequestError , SchemaAgreementError , TracingError ,
17+ UseKeyspaceError ,
1818} ;
1919use crate :: frame:: response:: result;
2020use crate :: network:: tls:: TlsProvider ;
@@ -36,7 +36,7 @@ use crate::response::{
3636} ;
3737use crate :: routing:: partitioner:: PartitionerName ;
3838use crate :: routing:: { Shard , ShardAwarePortRange } ;
39- use crate :: statement:: batch:: batch_values ;
39+ use crate :: statement:: batch:: BoundBatch ;
4040use crate :: statement:: batch:: { Batch , BatchStatement } ;
4141use crate :: statement:: bound:: BoundStatement ;
4242use crate :: statement:: prepared:: { PartitionKeyError , PreparedStatement } ;
@@ -47,9 +47,10 @@ use futures::future::join_all;
4747use futures:: future:: try_join_all;
4848use itertools:: Itertools ;
4949use scylla_cql:: frame:: response:: NonErrorResponse ;
50- use scylla_cql:: serialize:: batch:: BatchValues ;
50+ use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
5151use scylla_cql:: serialize:: row:: SerializeRow ;
52- use std:: borrow:: Borrow ;
52+ use std:: borrow:: { Borrow , Cow } ;
53+ use std:: collections:: { HashMap , HashSet } ;
5354use std:: future:: Future ;
5455use std:: net:: { IpAddr , SocketAddr } ;
5556use std:: num:: NonZeroU32 ;
@@ -809,7 +810,10 @@ impl Session {
809810 batch : & Batch ,
810811 values : impl BatchValues ,
811812 ) -> Result < QueryResult , ExecutionError > {
812- self . do_batch ( batch, values) . await
813+ let batch = self . last_minute_prepare_batch ( batch, & values) . await ?;
814+ let batch = BoundBatch :: from_batch ( batch. as_ref ( ) , values) ?;
815+
816+ self . do_batch ( & batch) . await
813817 }
814818
815819 /// Estabilishes a CQL session with the database
@@ -1185,7 +1189,10 @@ impl Session {
11851189 // Making QueryPager::new_for_query work with values is too hard (if even possible)
11861190 // so instead of sending one prepare to a specific connection on each iterator query,
11871191 // we fully prepare a statement beforehand.
1188- let bound = self . prepare_nongeneric ( & statement) . await ?. into_bind ( & values) ?;
1192+ let bound = self
1193+ . prepare_nongeneric ( & statement)
1194+ . await ?
1195+ . into_bind ( & values) ?;
11891196 QueryPager :: new_for_prepared_statement ( PreparedPagerConfig {
11901197 bound,
11911198 execution_profile,
@@ -1509,22 +1516,9 @@ impl Session {
15091516 . map_err ( PagerExecutionError :: NextPageError )
15101517 }
15111518
1512- async fn do_batch (
1513- & self ,
1514- batch : & Batch ,
1515- values : impl BatchValues ,
1516- ) -> Result < QueryResult , ExecutionError > {
1519+ async fn do_batch ( & self , batch : & BoundBatch ) -> Result < QueryResult , ExecutionError > {
15171520 // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
15181521 // If users batch statements by shard, they will be rewarded with full shard awareness
1519-
1520- // check to ensure that we don't send a batch statement with more than u16::MAX queries
1521- let batch_statements_length = batch. statements . len ( ) ;
1522- if batch_statements_length > u16:: MAX as usize {
1523- return Err ( ExecutionError :: BadQuery (
1524- BadQuery :: TooManyQueriesInBatchStatement ( batch_statements_length) ,
1525- ) ) ;
1526- }
1527-
15281522 let execution_profile = batch
15291523 . get_execution_profile_handle ( )
15301524 . unwrap_or_else ( || self . get_default_execution_profile_handle ( ) )
@@ -1540,22 +1534,17 @@ impl Session {
15401534 . serial_consistency
15411535 . unwrap_or ( execution_profile. serial_consistency ) ;
15421536
1543- let ( first_value_token, values) =
1544- batch_values:: peek_first_token ( values, batch. statements . first ( ) ) ?;
1545- let values_ref = & values;
1546-
1547- let table_spec =
1548- if let Some ( BatchStatement :: PreparedStatement ( ps) ) = batch. statements . first ( ) {
1549- ps. get_table_spec ( )
1550- } else {
1551- None
1552- } ;
1537+ let ( table, token) = batch
1538+ . first_prepared
1539+ . as_ref ( )
1540+ . and_then ( |( ps, token) | ps. get_table_spec ( ) . map ( |table| ( table, * token) ) )
1541+ . unzip ( ) ;
15531542
15541543 let statement_info = RoutingInfo {
15551544 consistency,
15561545 serial_consistency,
1557- token : first_value_token ,
1558- table : table_spec ,
1546+ token,
1547+ table,
15591548 is_confirmed_lwt : false ,
15601549 } ;
15611550
@@ -1578,12 +1567,7 @@ impl Session {
15781567 . unwrap_or ( execution_profile. serial_consistency ) ;
15791568 async move {
15801569 connection
1581- . batch_with_consistency (
1582- batch,
1583- values_ref,
1584- consistency,
1585- serial_consistency,
1586- )
1570+ . batch_with_consistency ( batch, consistency, serial_consistency)
15871571 . await
15881572 . and_then ( QueryResponse :: into_non_error_query_response)
15891573 }
@@ -1652,6 +1636,54 @@ impl Session {
16521636 Ok ( prepared_batch)
16531637 }
16541638
1639+ async fn last_minute_prepare_batch < ' b > (
1640+ & self ,
1641+ init_batch : & ' b Batch ,
1642+ values : impl BatchValues ,
1643+ ) -> Result < Cow < ' b , Batch > , PrepareError > {
1644+ let mut to_prepare = HashSet :: < & str > :: new ( ) ;
1645+
1646+ {
1647+ let mut values_iter = values. batch_values_iter ( ) ;
1648+ for stmt in & init_batch. statements {
1649+ if let BatchStatement :: Query ( query) = stmt {
1650+ if let Some ( false ) = values_iter. is_empty_next ( ) {
1651+ to_prepare. insert ( & query. contents ) ;
1652+ }
1653+ } else {
1654+ values_iter. skip_next ( ) ;
1655+ }
1656+ }
1657+ }
1658+
1659+ if to_prepare. is_empty ( ) {
1660+ return Ok ( Cow :: Borrowed ( init_batch) ) ;
1661+ }
1662+
1663+ let mut prepared_queries = HashMap :: < & str , PreparedStatement > :: new ( ) ;
1664+
1665+ for query in to_prepare {
1666+ let prepared = self . prepare ( query) . await ?;
1667+ prepared_queries. insert ( query, prepared) ;
1668+ }
1669+
1670+ let mut batch: Cow < Batch > = Cow :: Owned ( Batch :: new_from ( init_batch) ) ;
1671+ for stmt in & init_batch. statements {
1672+ match stmt {
1673+ BatchStatement :: Query ( query) => match prepared_queries. get ( query. contents . as_str ( ) )
1674+ {
1675+ Some ( prepared) => batch. to_mut ( ) . append_statement ( prepared. clone ( ) ) ,
1676+ None => batch. to_mut ( ) . append_statement ( query. clone ( ) ) ,
1677+ } ,
1678+ BatchStatement :: PreparedStatement ( prepared) => {
1679+ batch. to_mut ( ) . append_statement ( prepared. clone ( ) ) ;
1680+ }
1681+ }
1682+ }
1683+
1684+ Ok ( batch)
1685+ }
1686+
16551687 /// Sends `USE <keyspace_name>` request on all connections\
16561688 /// This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\
16571689 ///
0 commit comments