@@ -8,19 +8,28 @@ use scylla_cql::frame::frame_errors::{
 };
 use scylla_cql::frame::request;
 use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
-use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
+use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
 use scylla_cql::serialize::{RowWriter, SerializationError};
+use thiserror::Error;
+use tracing::Instrument;

-use crate::client::execution_profile::ExecutionProfileHandle;
+use crate::client::execution_profile::{ExecutionProfileHandle, ExecutionProfileInner};
+use crate::client::session::{RunRequestResult, Session};
 use crate::errors::{BadQuery, ExecutionError, RequestAttemptError};
+use crate::network::Connection;
+use crate::observability::driver_tracing::RequestSpan;
 use crate::observability::history::HistoryListener;
 use crate::policies::load_balancing::LoadBalancingPolicy;
+use crate::policies::load_balancing::RoutingInfo;
 use crate::policies::retry::RetryPolicy;
+use crate::response::query_result::QueryResult;
+use crate::response::{Coordinator, NonErrorQueryResponse, QueryResponse};
 use crate::routing::Token;
 use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
 use crate::statement::unprepared::Statement;

 use super::bound::BoundStatement;
+use super::execute::Execute;
 use super::StatementConfig;
 use super::{Consistency, SerialConsistency};
 pub use crate::frame::request::batch::BatchType;
@@ -254,8 +263,8 @@ pub struct BoundBatch {
     batch_type: BatchType,
     pub(crate) buffer: Vec<u8>,
     pub(crate) prepared: HashMap<Bytes, PreparedStatement>,
-    pub(crate) first_prepared: Option<(PreparedStatement, Token)>,
-    pub(crate) statements_len: u16,
+    first_prepared: Option<(PreparedStatement, Token)>,
+    statements_len: u16,
 }

 impl BoundBatch {
@@ -266,6 +275,20 @@ impl BoundBatch {
         }
     }

+    /// Appends a new statement to the batch.
+    pub fn append_statement<'p, V: SerializeRow>(
+        &mut self,
+        statement: impl Into<BoundBatchStatement<'p, V>>,
+    ) -> Result<(), BoundBatchStatementError> {
+        let initial_len = self.buffer.len();
+        self.raw_append_statement(statement).inspect_err(|_| {
+            // If anything errored, restore the buffer to its old length so it is not left
+            // corrupted in case the user does not drop the BoundBatch but instead skips this
+            // statement and retries with a successful one later.
+            self.buffer.truncate(initial_len);
+        })
+    }
+
     #[allow(clippy::result_large_err)]
     pub(crate) fn from_batch(
         batch: &Batch,
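
For orientation, here is a hedged sketch (not part of the diff) of how the new `append_statement` API could be called. The `BoundBatch::new(BatchType)` constructor, the public module paths in the `use` lines, and the keyspace, table, and values are all assumptions made for illustration; `append_statement` itself and the `Into<BoundBatchStatement>` conversions come from this change.

```rust
use scylla::client::session::Session;
use scylla::response::query_result::QueryResult;
use scylla::statement::batch::{BatchType, BoundBatch};
use scylla::statement::execute::Execute; // assumed public path for the trait added in this diff

async fn run_example_batch(session: &Session) -> Result<QueryResult, Box<dyn std::error::Error>> {
    // Assumed constructor: not shown in this diff.
    let mut batch = BoundBatch::new(BatchType::Unlogged);

    // `(PreparedStatement, V)` converts into `BoundBatchStatement::Prepared`; its values are
    // serialized during `append_statement`, and the buffer is rolled back if that fails.
    let prepared = session
        .prepare("INSERT INTO examples.events (id, kind) VALUES (?, ?)")
        .await?;
    batch.append_statement((prepared, (1_i64, "login")))?;

    // `&str` converts into `BoundBatchStatement::Query`: an unprepared statement with no values.
    batch.append_statement("UPDATE examples.counters SET hits = hits + 1 WHERE id = 0")?;

    // Execution goes through the `Execute` impl for `BoundBatch` added further down in this diff.
    Ok(batch.execute(session).await?)
}
```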
@@ -375,6 +398,96 @@ impl BoundBatch {
         self.batch_type
     }

+    pub fn statements_len(&self) -> u16 {
+        self.statements_len
+    }
+
+    // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors,
+    // because the caller will fix the buffer. HOWEVER, it is *NOT OK* for *ANY* other field in
+    // `self` to be modified if an error occurred, because the caller will not reset them.
+    fn raw_append_statement<'p, V: SerializeRow>(
+        &mut self,
+        statement: impl Into<BoundBatchStatement<'p, V>>,
+    ) -> Result<(), BoundBatchStatementError> {
+        let mut statement = statement.into();
+        let mut first_prepared = None;
+
+        if self.statements_len == 0 {
+            // Save it into a local variable for now in case a later step fails.
+            first_prepared = match statement {
+                BoundBatchStatement::Bound(ref b) => b
+                    .token()?
+                    .map(|token| (b.prepared.clone().into_owned(), token)),
+                BoundBatchStatement::Prepared(ps, values) => {
+                    let bound = ps
+                        .into_bind(&values)
+                        .map_err(BatchStatementSerializationError::ValuesSerialiation)?;
+                    let first_prepared = bound
+                        .token()?
+                        .map(|token| (bound.prepared.clone().into_owned(), token));
+                    // We already serialized the values, so to avoid re-serializing them, turn the
+                    // statement into a BoundStatement.
+                    statement = BoundBatchStatement::Bound(bound);
+                    first_prepared
+                }
+                BoundBatchStatement::Query(_) => None,
+            };
+        }
+
+        let stmnt = match &statement {
+            BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared {
+                id: Cow::Borrowed(ps.get_id()),
+            },
+            BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared {
+                id: Cow::Borrowed(b.prepared.get_id()),
+            },
+            BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query {
+                text: Cow::Borrowed(&q.contents),
+            },
+        };
+
+        serialize_statement(stmnt, &mut self.buffer, |writer| match &statement {
+            BoundBatchStatement::Prepared(ps, values) => {
+                let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
+                values.serialize(&ctx, writer).map(Some)
+            }
+            BoundBatchStatement::Bound(b) => {
+                writer.append_serialize_row(&b.values);
+                Ok(Some(()))
+            }
+            // A query has no values.
+            BoundBatchStatement::Query(_) => Ok(Some(())),
+        })?;
+
+        let new_statements_len = self
+            .statements_len
+            .checked_add(1)
+            .ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?;
+
+        /*** At this point nothing else should be fallible, as we are going to be modifying
+         * fields that do not get reset. ***/
+
+        self.statements_len = new_statements_len;
+
+        if let Some(first_prepared) = first_prepared {
+            self.first_prepared = Some(first_prepared);
+        }
+
+        let prepared = match statement {
+            BoundBatchStatement::Prepared(ps, _) => Cow::Owned(ps),
+            BoundBatchStatement::Bound(b) => b.prepared,
+            BoundBatchStatement::Query(_) => return Ok(()),
+        };
+
+        if !self.prepared.contains_key(prepared.get_id()) {
+            self.prepared
+                .insert(prepared.get_id().to_owned(), prepared.into_owned());
+        }
+
+        Ok(())
+    }
+
+    #[allow(clippy::result_large_err)]
     fn serialize_from_batch_statement<T>(
         &mut self,
         statement: &BatchStatement,
@@ -449,3 +562,129 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
             n_statements: n_statements as usize,
         }
 }
+
+/// This enum represents a CQL statement that can be part of a batch, together with its values.
+#[derive(Clone)]
+#[non_exhaustive]
+pub enum BoundBatchStatement<'p, V: SerializeRow> {
+    /// A prepared statement and its not-yet-serialized values
+    Prepared(PreparedStatement, V),
+    /// A statement whose values have already been bound (and thus serialized)
+    Bound(BoundStatement<'p>),
+    /// An unprepared statement with no values
+    Query(Statement),
+}
+
+impl<'p> From<BoundStatement<'p>> for BoundBatchStatement<'p, ()> {
+    fn from(b: BoundStatement<'p>) -> Self {
+        BoundBatchStatement::Bound(b)
+    }
+}
+
+impl<V: SerializeRow> From<(PreparedStatement, V)> for BoundBatchStatement<'static, V> {
+    fn from((p, v): (PreparedStatement, V)) -> Self {
+        BoundBatchStatement::Prepared(p, v)
+    }
+}
+
+impl From<Statement> for BoundBatchStatement<'static, ()> {
+    fn from(s: Statement) -> Self {
+        BoundBatchStatement::Query(s)
+    }
+}
+
+impl From<&str> for BoundBatchStatement<'static, ()> {
+    fn from(s: &str) -> Self {
+        BoundBatchStatement::Query(Statement::from(s))
+    }
+}
+
+/// An error type returned when adding a statement to a bound batch fails.
+#[non_exhaustive]
+#[derive(Error, Debug, Clone)]
+pub enum BoundBatchStatementError {
+    /// Failed to serialize the batch statement.
+    #[error(transparent)]
+    Statement(#[from] BatchStatementSerializationError),
+    /// Failed to calculate the partition key.
+    #[error("Failed to calculate partition key")]
+    PartitionKey(#[from] PartitionKeyError),
+    /// Too many statements in the batch.
+    #[error("Adding this statement would exceed the maximum of 65,535 statements per batch")]
+    TooManyQueriesInBatchStatement,
+}
+
+impl Execute for BoundBatch {
+    async fn execute(&self, session: &Session) -> Result<QueryResult, ExecutionError> {
+        // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
+        // If users batch statements by shard, they will be rewarded with full shard awareness
+        let execution_profile = self
+            .get_execution_profile_handle()
+            .unwrap_or_else(|| session.get_default_execution_profile_handle())
+            .access();
+
+        let consistency = self
+            .config
+            .consistency
+            .unwrap_or(execution_profile.consistency);
+
+        let serial_consistency = self
+            .config
+            .serial_consistency
+            .unwrap_or(execution_profile.serial_consistency);
+
+        let (table, token) = self
+            .first_prepared
+            .as_ref()
+            .and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
+            .unzip();
+
+        let statement_info = RoutingInfo {
+            consistency,
+            serial_consistency,
+            token,
+            table,
+            is_confirmed_lwt: false,
+        };
+
+        let span = RequestSpan::new_batch();
+
+        let (run_request_result, coordinator): (
+            RunRequestResult<NonErrorQueryResponse>,
+            Coordinator,
+        ) = session
+            .run_request(
+                statement_info,
+                &self.config,
+                execution_profile,
+                |connection: Arc<Connection>,
+                 consistency: Consistency,
+                 execution_profile: &ExecutionProfileInner| {
+                    let serial_consistency = self
+                        .config
+                        .serial_consistency
+                        .unwrap_or(execution_profile.serial_consistency);
+                    async move {
+                        connection
+                            .batch_with_consistency(self, consistency, serial_consistency)
+                            .await
+                            .and_then(QueryResponse::into_non_error_query_response)
+                    }
+                },
+                &span,
+            )
+            .instrument(span.span().clone())
+            .await?;
+
+        let result = match run_request_result {
+            RunRequestResult::IgnoredWriteError => QueryResult::mock_empty(coordinator),
+            RunRequestResult::Completed(non_error_query_response) => {
+                let result = non_error_query_response.into_query_result(coordinator)?;
+                span.record_result_fields(&result);
+                result
+            }
+        };

+        Ok(result)
+    }
+}
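
The comment at the top of the `Execute` impl notes that routing for a batch is derived from the first batch statement's token (captured in `first_prepared`). A hedged sketch of what that implies for callers: keep every statement in a batch on the same partition if you want that token to be meaningful for the whole batch. As in the earlier sketch, the `BoundBatch::new` constructor, module paths, and schema are assumptions, not part of this diff.

```rust
use scylla::client::session::Session;
use scylla::statement::batch::{BatchType, BoundBatch};
use scylla::statement::execute::Execute; // assumed public path for the trait

async fn same_partition_batch(
    session: &Session,
    device_id: i64,
) -> Result<(), Box<dyn std::error::Error>> {
    let insert = session
        .prepare("INSERT INTO examples.readings (device_id, ts, value) VALUES (?, ?, ?)")
        .await?;

    // Assumed constructor, as in the earlier sketch.
    let mut batch = BoundBatch::new(BatchType::Unlogged);
    for (ts, value) in [(1_i64, 20.5_f32), (2, 21.0), (3, 21.4)] {
        // Every statement uses the same partition key (`device_id`), so the token computed from
        // the first appended statement routes the whole batch to the right replica and shard.
        batch.append_statement((insert.clone(), (device_id, ts, value)))?;
    }

    batch.execute(session).await?;
    Ok(())
}
```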