@@ -12,9 +12,9 @@ use crate::cluster::node::CloudEndpoint;
12
12
use crate :: cluster:: node:: { InternalKnownNode , KnownNode , NodeRef } ;
13
13
use crate :: cluster:: { Cluster , ClusterNeatDebug , ClusterState } ;
14
14
use crate :: errors:: {
15
- BadQuery , BrokenConnectionError , ExecutionError , MetadataError , NewSessionError ,
16
- PagerExecutionError , PrepareError , RequestAttemptError , RequestError , SchemaAgreementError ,
17
- TracingError , UseKeyspaceError ,
15
+ BrokenConnectionError , ExecutionError , MetadataError , NewSessionError , PagerExecutionError ,
16
+ PrepareError , RequestAttemptError , RequestError , SchemaAgreementError , TracingError ,
17
+ UseKeyspaceError ,
18
18
} ;
19
19
use crate :: frame:: response:: result;
20
20
use crate :: network:: tls:: TlsProvider ;
@@ -36,7 +36,7 @@ use crate::response::{
36
36
} ;
37
37
use crate :: routing:: partitioner:: PartitionerName ;
38
38
use crate :: routing:: { Shard , ShardAwarePortRange } ;
39
- use crate :: statement:: batch:: batch_values ;
39
+ use crate :: statement:: batch:: BoundBatch ;
40
40
use crate :: statement:: batch:: { Batch , BatchStatement } ;
41
41
use crate :: statement:: bound:: BoundStatement ;
42
42
use crate :: statement:: prepared:: { PartitionKeyError , PreparedStatement } ;
@@ -47,9 +47,10 @@ use futures::future::join_all;
47
47
use futures:: future:: try_join_all;
48
48
use itertools:: Itertools ;
49
49
use scylla_cql:: frame:: response:: NonErrorResponse ;
50
- use scylla_cql:: serialize:: batch:: BatchValues ;
50
+ use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
51
51
use scylla_cql:: serialize:: row:: SerializeRow ;
52
- use std:: borrow:: Borrow ;
52
+ use std:: borrow:: { Borrow , Cow } ;
53
+ use std:: collections:: { HashMap , HashSet } ;
53
54
use std:: future:: Future ;
54
55
use std:: net:: { IpAddr , SocketAddr } ;
55
56
use std:: num:: NonZeroU32 ;
@@ -809,7 +810,10 @@ impl Session {
809
810
batch : & Batch ,
810
811
values : impl BatchValues ,
811
812
) -> Result < QueryResult , ExecutionError > {
812
- self . do_batch ( batch, values) . await
813
+ let batch = self . last_minute_prepare_batch ( batch, & values) . await ?;
814
+ let batch = BoundBatch :: from_batch ( batch. as_ref ( ) , values) ?;
815
+
816
+ self . do_batch ( & batch) . await
813
817
}
814
818
815
819
/// Estabilishes a CQL session with the database
@@ -1185,7 +1189,10 @@ impl Session {
1185
1189
// Making QueryPager::new_for_query work with values is too hard (if even possible)
1186
1190
// so instead of sending one prepare to a specific connection on each iterator query,
1187
1191
// we fully prepare a statement beforehand.
1188
- let bound = self . prepare_nongeneric ( & statement) . await ?. into_bind ( & values) ?;
1192
+ let bound = self
1193
+ . prepare_nongeneric ( & statement)
1194
+ . await ?
1195
+ . into_bind ( & values) ?;
1189
1196
QueryPager :: new_for_prepared_statement ( PreparedPagerConfig {
1190
1197
bound,
1191
1198
execution_profile,
@@ -1509,22 +1516,9 @@ impl Session {
1509
1516
. map_err ( PagerExecutionError :: NextPageError )
1510
1517
}
1511
1518
1512
- async fn do_batch (
1513
- & self ,
1514
- batch : & Batch ,
1515
- values : impl BatchValues ,
1516
- ) -> Result < QueryResult , ExecutionError > {
1519
+ async fn do_batch ( & self , batch : & BoundBatch ) -> Result < QueryResult , ExecutionError > {
1517
1520
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
1518
1521
// If users batch statements by shard, they will be rewarded with full shard awareness
1519
-
1520
- // check to ensure that we don't send a batch statement with more than u16::MAX queries
1521
- let batch_statements_length = batch. statements . len ( ) ;
1522
- if batch_statements_length > u16:: MAX as usize {
1523
- return Err ( ExecutionError :: BadQuery (
1524
- BadQuery :: TooManyQueriesInBatchStatement ( batch_statements_length) ,
1525
- ) ) ;
1526
- }
1527
-
1528
1522
let execution_profile = batch
1529
1523
. get_execution_profile_handle ( )
1530
1524
. unwrap_or_else ( || self . get_default_execution_profile_handle ( ) )
@@ -1540,22 +1534,17 @@ impl Session {
1540
1534
. serial_consistency
1541
1535
. unwrap_or ( execution_profile. serial_consistency ) ;
1542
1536
1543
- let ( first_value_token, values) =
1544
- batch_values:: peek_first_token ( values, batch. statements . first ( ) ) ?;
1545
- let values_ref = & values;
1546
-
1547
- let table_spec =
1548
- if let Some ( BatchStatement :: PreparedStatement ( ps) ) = batch. statements . first ( ) {
1549
- ps. get_table_spec ( )
1550
- } else {
1551
- None
1552
- } ;
1537
+ let ( table, token) = batch
1538
+ . first_prepared
1539
+ . as_ref ( )
1540
+ . and_then ( |( ps, token) | ps. get_table_spec ( ) . map ( |table| ( table, * token) ) )
1541
+ . unzip ( ) ;
1553
1542
1554
1543
let statement_info = RoutingInfo {
1555
1544
consistency,
1556
1545
serial_consistency,
1557
- token : first_value_token ,
1558
- table : table_spec ,
1546
+ token,
1547
+ table,
1559
1548
is_confirmed_lwt : false ,
1560
1549
} ;
1561
1550
@@ -1578,12 +1567,7 @@ impl Session {
1578
1567
. unwrap_or ( execution_profile. serial_consistency ) ;
1579
1568
async move {
1580
1569
connection
1581
- . batch_with_consistency (
1582
- batch,
1583
- values_ref,
1584
- consistency,
1585
- serial_consistency,
1586
- )
1570
+ . batch_with_consistency ( batch, consistency, serial_consistency)
1587
1571
. await
1588
1572
. and_then ( QueryResponse :: into_non_error_query_response)
1589
1573
}
@@ -1652,6 +1636,54 @@ impl Session {
1652
1636
Ok ( prepared_batch)
1653
1637
}
1654
1638
1639
+ async fn last_minute_prepare_batch < ' b > (
1640
+ & self ,
1641
+ init_batch : & ' b Batch ,
1642
+ values : impl BatchValues ,
1643
+ ) -> Result < Cow < ' b , Batch > , PrepareError > {
1644
+ let mut to_prepare = HashSet :: < & str > :: new ( ) ;
1645
+
1646
+ {
1647
+ let mut values_iter = values. batch_values_iter ( ) ;
1648
+ for stmt in & init_batch. statements {
1649
+ if let BatchStatement :: Query ( query) = stmt {
1650
+ if let Some ( false ) = values_iter. is_empty_next ( ) {
1651
+ to_prepare. insert ( & query. contents ) ;
1652
+ }
1653
+ } else {
1654
+ values_iter. skip_next ( ) ;
1655
+ }
1656
+ }
1657
+ }
1658
+
1659
+ if to_prepare. is_empty ( ) {
1660
+ return Ok ( Cow :: Borrowed ( init_batch) ) ;
1661
+ }
1662
+
1663
+ let mut prepared_queries = HashMap :: < & str , PreparedStatement > :: new ( ) ;
1664
+
1665
+ for query in to_prepare {
1666
+ let prepared = self . prepare ( query) . await ?;
1667
+ prepared_queries. insert ( query, prepared) ;
1668
+ }
1669
+
1670
+ let mut batch: Cow < Batch > = Cow :: Owned ( Batch :: new_from ( init_batch) ) ;
1671
+ for stmt in & init_batch. statements {
1672
+ match stmt {
1673
+ BatchStatement :: Query ( query) => match prepared_queries. get ( query. contents . as_str ( ) )
1674
+ {
1675
+ Some ( prepared) => batch. to_mut ( ) . append_statement ( prepared. clone ( ) ) ,
1676
+ None => batch. to_mut ( ) . append_statement ( query. clone ( ) ) ,
1677
+ } ,
1678
+ BatchStatement :: PreparedStatement ( prepared) => {
1679
+ batch. to_mut ( ) . append_statement ( prepared. clone ( ) ) ;
1680
+ }
1681
+ }
1682
+ }
1683
+
1684
+ Ok ( batch)
1685
+ }
1686
+
1655
1687
/// Sends `USE <keyspace_name>` request on all connections\
1656
1688
/// This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\
1657
1689
///
0 commit comments