@@ -9,19 +9,28 @@ use scylla_cql::frame::frame_errors::{
};
use scylla_cql::frame::request;
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
- use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
+ use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
use scylla_cql::serialize::{RowWriter, SerializationError};
+ use thiserror::Error;
+ use tracing::Instrument;

- use crate::client::execution_profile::ExecutionProfileHandle;
+ use crate::client::execution_profile::{ExecutionProfileHandle, ExecutionProfileInner};
+ use crate::client::session::{RunRequestResult, Session};
use crate::errors::{BadQuery, ExecutionError, RequestAttemptError};
+ use crate::network::Connection;
+ use crate::observability::driver_tracing::RequestSpan;
use crate::observability::history::HistoryListener;
use crate::policies::load_balancing::LoadBalancingPolicy;
+ use crate::policies::load_balancing::RoutingInfo;
use crate::policies::retry::RetryPolicy;
+ use crate::response::query_result::QueryResult;
+ use crate::response::{Coordinator, NonErrorQueryResponse, QueryResponse};
use crate::routing::Token;
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
use crate::statement::unprepared::Statement;

use super::bound::BoundStatement;
+ use super::execute::Execute;
use super::StatementConfig;
use super::{Consistency, SerialConsistency};
pub use crate::frame::request::batch::BatchType;
@@ -282,8 +291,8 @@ pub struct BoundBatch {
    batch_type: BatchType,
    pub(crate) buffer: Vec<u8>,
    pub(crate) prepared: HashMap<Bytes, PreparedStatement>,
-     pub(crate) first_prepared: Option<(PreparedStatement, Token)>,
-     pub(crate) statements_len: u16,
+     first_prepared: Option<(PreparedStatement, Token)>,
+     statements_len: u16,
}

impl BoundBatch {
@@ -294,6 +303,20 @@ impl BoundBatch {
        }
    }

+     /// Appends a new statement to the batch.
+     pub fn append_statement<'p, V: SerializeRow>(
+         &mut self,
+         statement: impl Into<BoundBatchStatement<'p, V>>,
+     ) -> Result<(), BoundBatchStatementError> {
+         let initial_len = self.buffer.len();
+         self.raw_append_statement(statement).inspect_err(|_| {
+             // If we errored at any point, put the buffer back to its old length so it is not
+             // left corrupted in case the user doesn't drop the BoundBatch but instead skips
+             // this statement and tries a successful one later.
+             self.buffer.truncate(initial_len);
+         })
+     }
+
    #[allow(clippy::result_large_err)]
    pub(crate) fn from_batch(
        batch: &Batch,
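
Note (outside the diff): a minimal sketch of how the new `append_statement` API introduced above might be used, assuming a `BoundBatch` and a `PreparedStatement` have already been obtained through the driver's public API; the table name, statement text, and bind values are illustrative.

    // Sketch only: fills an existing BoundBatch with one prepared and one unprepared statement.
    fn fill_batch(
        batch: &mut BoundBatch,
        prepared: PreparedStatement,
    ) -> Result<(), BoundBatchStatementError> {
        // A prepared statement together with its not-yet-serialized values.
        batch.append_statement((prepared, (1_i32, "first")))?;
        // An unprepared statement with no values, converted via `From<&str>`.
        batch.append_statement("UPDATE ks.tab SET v = 'second' WHERE pk = 2")?;
        // On error, `append_statement` truncates the buffer back, so the batch stays usable.
        Ok(())
    }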
@@ -403,6 +426,96 @@ impl BoundBatch {
        self.batch_type
    }

+     pub fn statements_len(&self) -> u16 {
+         self.statements_len
+     }
+
+     // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors,
+     // because the caller will fix the buffer. HOWEVER, it is *NOT OK* for *ANY* other field in
+     // `self` to be modified if an error occurred, because the caller will not reset them.
+     fn raw_append_statement<'p, V: SerializeRow>(
+         &mut self,
+         statement: impl Into<BoundBatchStatement<'p, V>>,
+     ) -> Result<(), BoundBatchStatementError> {
+         let mut statement = statement.into();
+         let mut first_prepared = None;
+
+         if self.statements_len == 0 {
+             // Save it into a local variable for now, in case a later step fails.
+             first_prepared = match statement {
+                 BoundBatchStatement::Bound(ref b) => b
+                     .token()?
+                     .map(|token| (b.prepared.clone().into_owned(), token)),
+                 BoundBatchStatement::Prepared(ps, values) => {
+                     let bound = ps
+                         .into_bind(&values)
+                         .map_err(BatchStatementSerializationError::ValuesSerialiation)?;
+                     let first_prepared = bound
+                         .token()?
+                         .map(|token| (bound.prepared.clone().into_owned(), token));
+                     // We already serialized the values, so to avoid re-serializing them, turn
+                     // the statement into a BoundStatement.
+                     statement = BoundBatchStatement::Bound(bound);
+                     first_prepared
+                 }
+                 BoundBatchStatement::Query(_) => None,
+             };
+         }
+
+         let stmnt = match &statement {
+             BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared {
+                 id: Cow::Borrowed(ps.get_id()),
+             },
+             BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared {
+                 id: Cow::Borrowed(b.prepared.get_id()),
+             },
+             BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query {
+                 text: Cow::Borrowed(&q.contents),
+             },
+         };
+
+         serialize_statement(stmnt, &mut self.buffer, |writer| match &statement {
+             BoundBatchStatement::Prepared(ps, values) => {
+                 let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
+                 values.serialize(&ctx, writer).map(Some)
+             }
+             BoundBatchStatement::Bound(b) => {
+                 writer.append_serialize_row(&b.values);
+                 Ok(Some(()))
+             }
+             // A query has no values.
+             BoundBatchStatement::Query(_) => Ok(Some(())),
+         })?;
+
+         let new_statements_len = self
+             .statements_len
+             .checked_add(1)
+             .ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?;
+
+         /*** At this point nothing else should be fallible, as we are going to be modifying
+          * fields that do not get reset. ***/
+
+         self.statements_len = new_statements_len;
+
+         if let Some(first_prepared) = first_prepared {
+             self.first_prepared = Some(first_prepared);
+         }
+
+         let prepared = match statement {
+             BoundBatchStatement::Prepared(ps, _) => Cow::Owned(ps),
+             BoundBatchStatement::Bound(b) => b.prepared,
+             BoundBatchStatement::Query(_) => return Ok(()),
+         };
+
+         if !self.prepared.contains_key(prepared.get_id()) {
+             self.prepared
+                 .insert(prepared.get_id().to_owned(), prepared.into_owned());
+         }
+
+         Ok(())
+     }
+
+     #[allow(clippy::result_large_err)]
    fn serialize_from_batch_statement<T>(
        &mut self,
        statement: &BatchStatement,
@@ -477,3 +590,129 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
        n_statements: n_statements as usize,
    }
}
+
+ /// A CQL statement that can be part of a batch, together with its values.
+ #[derive(Clone)]
+ #[non_exhaustive]
+ pub enum BoundBatchStatement<'p, V: SerializeRow> {
+     /// A prepared statement and its not-yet-serialized values.
+     Prepared(PreparedStatement, V),
+     /// A statement whose values have already been bound (and thus serialized).
+     Bound(BoundStatement<'p>),
+     /// An unprepared statement with no values.
+     Query(Statement),
+ }
+
+ impl<'p> From<BoundStatement<'p>> for BoundBatchStatement<'p, ()> {
+     fn from(b: BoundStatement<'p>) -> Self {
+         BoundBatchStatement::Bound(b)
+     }
+ }
+
+ impl<V: SerializeRow> From<(PreparedStatement, V)> for BoundBatchStatement<'static, V> {
+     fn from((p, v): (PreparedStatement, V)) -> Self {
+         BoundBatchStatement::Prepared(p, v)
+     }
+ }
+
+ impl From<Statement> for BoundBatchStatement<'static, ()> {
+     fn from(s: Statement) -> Self {
+         BoundBatchStatement::Query(s)
+     }
+ }
+
+ impl From<&str> for BoundBatchStatement<'static, ()> {
+     fn from(s: &str) -> Self {
+         BoundBatchStatement::Query(Statement::from(s))
+     }
+ }
+
+ /// An error returned when adding a statement to a bound batch fails.
+ #[non_exhaustive]
+ #[derive(Error, Debug, Clone)]
+ pub enum BoundBatchStatementError {
+     /// Failed to serialize the batch statement.
+     #[error(transparent)]
+     Statement(#[from] BatchStatementSerializationError),
+     /// Failed to calculate the partition key.
+     #[error("Failed to calculate partition key")]
+     PartitionKey(#[from] PartitionKeyError),
+     /// Too many statements in the batch.
+     #[error("Added statement exceeds the maximum of 65,535 statements per batch")]
+     TooManyQueriesInBatchStatement,
+ }
+
+ impl Execute for BoundBatch {
+     async fn execute(&self, session: &Session) -> Result<QueryResult, ExecutionError> {
+         // Shard-awareness behavior for a batch is to pick the shard based on the first batch statement's shard.
+         // If users batch statements by shard, they will be rewarded with full shard awareness.
+         let execution_profile = self
+             .get_execution_profile_handle()
+             .unwrap_or_else(|| session.get_default_execution_profile_handle())
+             .access();
+
+         let consistency = self
+             .config
+             .consistency
+             .unwrap_or(execution_profile.consistency);
+
+         let serial_consistency = self
+             .config
+             .serial_consistency
+             .unwrap_or(execution_profile.serial_consistency);
+
+         let (table, token) = self
+             .first_prepared
+             .as_ref()
+             .and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
+             .unzip();
+
+         let statement_info = RoutingInfo {
+             consistency,
+             serial_consistency,
+             token,
+             table,
+             is_confirmed_lwt: false,
+         };
+
+         let span = RequestSpan::new_batch();
+
+         let (run_request_result, coordinator): (
+             RunRequestResult<NonErrorQueryResponse>,
+             Coordinator,
+         ) = session
+             .run_request(
+                 statement_info,
+                 &self.config,
+                 execution_profile,
+                 |connection: Arc<Connection>,
+                  consistency: Consistency,
+                  execution_profile: &ExecutionProfileInner| {
+                     let serial_consistency = self
+                         .config
+                         .serial_consistency
+                         .unwrap_or(execution_profile.serial_consistency);
+                     async move {
+                         connection
+                             .batch_with_consistency(self, consistency, serial_consistency)
+                             .await
+                             .and_then(QueryResponse::into_non_error_query_response)
+                     }
+                 },
+                 &span,
+             )
+             .instrument(span.span().clone())
+             .await?;
+
+         let result = match run_request_result {
+             RunRequestResult::IgnoredWriteError => QueryResult::mock_empty(coordinator),
+             RunRequestResult::Completed(non_error_query_response) => {
+                 let result = non_error_query_response.into_query_result(coordinator)?;
+                 span.record_result_fields(&result);
+                 result
+             }
+         };
+
+         Ok(result)
+     }
+ }
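
Note (outside the diff): a minimal sketch of executing a filled `BoundBatch` through the new `Execute` impl, assuming the `Execute` trait, `Session`, `QueryResult`, and `ExecutionError` are reachable from the caller's scope (public re-export paths are not shown in this diff).

    // Sketch only: runs the batch and uses the new `statements_len` accessor.
    async fn run_batch(
        session: &Session,
        batch: &BoundBatch,
    ) -> Result<QueryResult, ExecutionError> {
        // `statements_len` is the new accessor for the number of appended statements.
        debug_assert!(batch.statements_len() > 0);
        // Routing uses the token/table of the batch's first prepared statement, if any.
        batch.execute(session).await
    }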