@@ -8,19 +8,28 @@ use scylla_cql::frame::frame_errors::{
88} ;
99use scylla_cql:: frame:: request;
1010use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
11- use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializedValues } ;
11+ use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializeRow , SerializedValues } ;
1212use scylla_cql:: serialize:: { RowWriter , SerializationError } ;
13+ use thiserror:: Error ;
14+ use tracing:: Instrument ;
1315
14- use crate :: client:: execution_profile:: ExecutionProfileHandle ;
16+ use crate :: client:: execution_profile:: { ExecutionProfileHandle , ExecutionProfileInner } ;
17+ use crate :: client:: session:: { RunRequestResult , Session } ;
1518use crate :: errors:: { BadQuery , ExecutionError , RequestAttemptError } ;
19+ use crate :: network:: Connection ;
20+ use crate :: observability:: driver_tracing:: RequestSpan ;
1621use crate :: observability:: history:: HistoryListener ;
1722use crate :: policies:: load_balancing:: LoadBalancingPolicy ;
23+ use crate :: policies:: load_balancing:: RoutingInfo ;
1824use crate :: policies:: retry:: RetryPolicy ;
25+ use crate :: response:: query_result:: QueryResult ;
26+ use crate :: response:: { NonErrorQueryResponse , QueryResponse } ;
1927use crate :: routing:: Token ;
2028use crate :: statement:: prepared:: { PartitionKeyError , PreparedStatement } ;
2129use crate :: statement:: unprepared:: Statement ;
2230
2331use super :: bound:: BoundStatement ;
32+ use super :: execute:: Execute ;
2433use super :: StatementConfig ;
2534use super :: { Consistency , SerialConsistency } ;
2635pub use crate :: frame:: request:: batch:: BatchType ;
@@ -254,8 +263,8 @@ pub struct BoundBatch {
254263 batch_type : BatchType ,
255264 pub ( crate ) buffer : Vec < u8 > ,
256265 pub ( crate ) prepared : HashMap < Bytes , PreparedStatement > ,
257- pub ( crate ) first_prepared : Option < ( PreparedStatement , Token ) > ,
258- pub ( crate ) statements_len : u16 ,
266+ first_prepared : Option < ( PreparedStatement , Token ) > ,
267+ statements_len : u16 ,
259268}
260269
261270impl BoundBatch {
@@ -266,6 +275,20 @@ impl BoundBatch {
266275 }
267276 }
268277
278+ /// Appends a new statement to the batch.
279+ pub fn append_statement < ' p , V : SerializeRow > (
280+ & mut self ,
281+ statement : impl Into < BoundBatchStatement < ' p , V > > ,
282+ ) -> Result < ( ) , BoundBatchStatementError > {
283+ let initial_len = self . buffer . len ( ) ;
284+ self . raw_append_statement ( statement) . inspect_err ( |_| {
285+ // if we error'd at any point we should put the buffer back to its old length to not
286+ // corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
287+ // tries with a successful statement later
288+ self . buffer . truncate ( initial_len) ;
289+ } )
290+ }
291+
269292 pub ( crate ) fn from_batch (
270293 batch : & Batch ,
271294 values : impl BatchValues ,
@@ -374,6 +397,95 @@ impl BoundBatch {
374397 self . batch_type
375398 }
376399
400+ pub fn statements_len ( & self ) -> u16 {
401+ self . statements_len
402+ }
403+
404+ // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
405+ // because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
406+ // `self` to be modified if an error occured because the caller will not reset them.
407+ fn raw_append_statement < ' p , V : SerializeRow > (
408+ & mut self ,
409+ statement : impl Into < BoundBatchStatement < ' p , V > > ,
410+ ) -> Result < ( ) , BoundBatchStatementError > {
411+ let mut statement = statement. into ( ) ;
412+ let mut first_prepared = None ;
413+
414+ if self . statements_len == 0 {
415+ // save it into a local variable for now in case a latter steps fails
416+ first_prepared = match statement {
417+ BoundBatchStatement :: Bound ( ref b) => b
418+ . token ( ) ?
419+ . map ( |token| ( b. prepared . clone ( ) . into_owned ( ) , token) ) ,
420+ BoundBatchStatement :: Prepared ( ps, values) => {
421+ let bound = ps
422+ . into_bind ( & values)
423+ . map_err ( BatchStatementSerializationError :: ValuesSerialiation ) ?;
424+ let first_prepared = bound
425+ . token ( ) ?
426+ . map ( |token| ( bound. prepared . clone ( ) . into_owned ( ) , token) ) ;
427+ // we already serialized it so to avoid re-serializing it, modify the statement to a
428+ // BoundStatement
429+ statement = BoundBatchStatement :: Bound ( bound) ;
430+ first_prepared
431+ }
432+ BoundBatchStatement :: Query ( _) => None ,
433+ } ;
434+ }
435+
436+ let stmnt = match & statement {
437+ BoundBatchStatement :: Prepared ( ps, _) => request:: batch:: BatchStatement :: Prepared {
438+ id : Cow :: Borrowed ( ps. get_id ( ) ) ,
439+ } ,
440+ BoundBatchStatement :: Bound ( b) => request:: batch:: BatchStatement :: Prepared {
441+ id : Cow :: Borrowed ( b. prepared . get_id ( ) ) ,
442+ } ,
443+ BoundBatchStatement :: Query ( q) => request:: batch:: BatchStatement :: Query {
444+ text : Cow :: Borrowed ( & q. contents ) ,
445+ } ,
446+ } ;
447+
448+ serialize_statement ( stmnt, & mut self . buffer , |writer| match & statement {
449+ BoundBatchStatement :: Prepared ( ps, values) => {
450+ let ctx = RowSerializationContext :: from_prepared ( ps. get_prepared_metadata ( ) ) ;
451+ values. serialize ( & ctx, writer) . map ( Some )
452+ }
453+ BoundBatchStatement :: Bound ( b) => {
454+ writer. append_serialize_row ( & b. values ) ;
455+ Ok ( Some ( ( ) ) )
456+ }
457+ // query has no values
458+ BoundBatchStatement :: Query ( _) => Ok ( Some ( ( ) ) ) ,
459+ } ) ?;
460+
461+ let new_statements_len = self
462+ . statements_len
463+ . checked_add ( 1 )
464+ . ok_or ( BoundBatchStatementError :: TooManyQueriesInBatchStatement ) ?;
465+
466+ /*** at this point nothing else should be fallible as we are going to be modifying
467+ * fields that do not get reset ***/
468+
469+ self . statements_len = new_statements_len;
470+
471+ if let Some ( first_prepared) = first_prepared {
472+ self . first_prepared = Some ( first_prepared) ;
473+ }
474+
475+ let prepared = match statement {
476+ BoundBatchStatement :: Prepared ( ps, _) => Cow :: Owned ( ps) ,
477+ BoundBatchStatement :: Bound ( b) => b. prepared ,
478+ BoundBatchStatement :: Query ( _) => return Ok ( ( ) ) ,
479+ } ;
480+
481+ if !self . prepared . contains_key ( prepared. get_id ( ) ) {
482+ self . prepared
483+ . insert ( prepared. get_id ( ) . to_owned ( ) , prepared. into_owned ( ) ) ;
484+ }
485+
486+ Ok ( ( ) )
487+ }
488+
377489 fn serialize_from_batch_statement < T > (
378490 & mut self ,
379491 statement : & BatchStatement ,
@@ -448,3 +560,126 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
448560 n_statements : n_statements as usize ,
449561 }
450562}
563+
564+ /// This enum represents a CQL statement, that can be part of batch and its values
565+ #[ derive( Clone ) ]
566+ #[ non_exhaustive]
567+ pub enum BoundBatchStatement < ' p , V : SerializeRow > {
568+ /// A prepared statement and its not-yet serialized values
569+ Prepared ( PreparedStatement , V ) ,
570+ /// A statement whose values have already been bound (and thus serialized)
571+ Bound ( BoundStatement < ' p > ) ,
572+ /// An unprepared statement with no values
573+ Query ( Statement ) ,
574+ }
575+
576+ impl < ' p > From < BoundStatement < ' p > > for BoundBatchStatement < ' p , ( ) > {
577+ fn from ( b : BoundStatement < ' p > ) -> Self {
578+ BoundBatchStatement :: Bound ( b)
579+ }
580+ }
581+
582+ impl < V : SerializeRow > From < ( PreparedStatement , V ) > for BoundBatchStatement < ' static , V > {
583+ fn from ( ( p, v) : ( PreparedStatement , V ) ) -> Self {
584+ BoundBatchStatement :: Prepared ( p, v)
585+ }
586+ }
587+
588+ impl From < Statement > for BoundBatchStatement < ' static , ( ) > {
589+ fn from ( s : Statement ) -> Self {
590+ BoundBatchStatement :: Query ( s)
591+ }
592+ }
593+
594+ impl From < & str > for BoundBatchStatement < ' static , ( ) > {
595+ fn from ( s : & str ) -> Self {
596+ BoundBatchStatement :: Query ( Statement :: from ( s) )
597+ }
598+ }
599+
600+ /// An error type returned when adding a statement to a bounded batch fails
601+ #[ non_exhaustive]
602+ #[ derive( Error , Debug , Clone ) ]
603+ pub enum BoundBatchStatementError {
604+ /// Failed to serialize the batch statement
605+ #[ error( transparent) ]
606+ Statement ( #[ from] BatchStatementSerializationError ) ,
607+ /// Failed to serialize statement's bound values.
608+ #[ error( "Failed to calculate partition key" ) ]
609+ PartitionKey ( #[ from] PartitionKeyError ) ,
610+ /// Too many statements in the batch statement.
611+ #[ error( "Added statement goes over exceeded max value of 65,535" ) ]
612+ TooManyQueriesInBatchStatement ,
613+ }
614+
615+ impl Execute for BoundBatch {
616+ async fn execute ( & self , session : & Session ) -> Result < QueryResult , ExecutionError > {
617+ // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
618+ // If users batch statements by shard, they will be rewarded with full shard awareness
619+ let execution_profile = self
620+ . get_execution_profile_handle ( )
621+ . unwrap_or_else ( || session. get_default_execution_profile_handle ( ) )
622+ . access ( ) ;
623+
624+ let consistency = self
625+ . config
626+ . consistency
627+ . unwrap_or ( execution_profile. consistency ) ;
628+
629+ let serial_consistency = self
630+ . config
631+ . serial_consistency
632+ . unwrap_or ( execution_profile. serial_consistency ) ;
633+
634+ let ( table, token) = self
635+ . first_prepared
636+ . as_ref ( )
637+ . and_then ( |( ps, token) | ps. get_table_spec ( ) . map ( |table| ( table, * token) ) )
638+ . unzip ( ) ;
639+
640+ let statement_info = RoutingInfo {
641+ consistency,
642+ serial_consistency,
643+ token,
644+ table,
645+ is_confirmed_lwt : false ,
646+ } ;
647+
648+ let span = RequestSpan :: new_batch ( ) ;
649+
650+ let run_request_result: RunRequestResult < NonErrorQueryResponse > = session
651+ . run_request (
652+ statement_info,
653+ & self . config ,
654+ execution_profile,
655+ |connection : Arc < Connection > ,
656+ consistency : Consistency ,
657+ execution_profile : & ExecutionProfileInner | {
658+ let serial_consistency = self
659+ . config
660+ . serial_consistency
661+ . unwrap_or ( execution_profile. serial_consistency ) ;
662+ async move {
663+ connection
664+ . batch_with_consistency ( self , consistency, serial_consistency)
665+ . await
666+ . and_then ( QueryResponse :: into_non_error_query_response)
667+ }
668+ } ,
669+ & span,
670+ )
671+ . instrument ( span. span ( ) . clone ( ) )
672+ . await ?;
673+
674+ let result = match run_request_result {
675+ RunRequestResult :: IgnoredWriteError => QueryResult :: mock_empty ( ) ,
676+ RunRequestResult :: Completed ( non_error_query_response) => {
677+ let result = non_error_query_response. into_query_result ( ) ?;
678+ span. record_result_fields ( & result) ;
679+ result
680+ }
681+ } ;
682+
683+ Ok ( result)
684+ }
685+ }
0 commit comments