@@ -8,8 +8,9 @@ use scylla_cql::frame::frame_errors::{
88} ;
99use scylla_cql:: frame:: request;
1010use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
11- use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializedValues } ;
11+ use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializeRow , SerializedValues } ;
1212use scylla_cql:: serialize:: { RowWriter , SerializationError } ;
13+ use thiserror:: Error ;
1314
1415use crate :: client:: execution_profile:: ExecutionProfileHandle ;
1516use crate :: errors:: { BadQuery , ExecutionError , RequestAttemptError } ;
@@ -242,6 +243,20 @@ impl BoundBatch {
242243 }
243244 }
244245
246+ /// Appends a new statement to the batch.
247+ pub fn append_statement < V : SerializeRow > (
248+ & mut self ,
249+ statement : impl Into < BoundBatchStatement < V > > ,
250+ ) -> Result < ( ) , BoundBatchStatementError > {
251+ let initial_len = self . buffer . len ( ) ;
252+ self . raw_append_statement ( statement) . inspect_err ( |_| {
253+ // if we error'd at any point we should put the buffer back to its old length to not
254+ // corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
255+ // tries with a successful statement later
256+ self . buffer . truncate ( initial_len) ;
257+ } )
258+ }
259+
245260 pub ( crate ) fn from_batch (
246261 batch : & Batch ,
247262 values : impl BatchValues ,
@@ -348,6 +363,89 @@ impl BoundBatch {
348363 self . batch_type
349364 }
350365
366+ // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
367+ // because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
368+ // `self` to be modified if an error occured because the caller will not reset them.
369+ fn raw_append_statement < V : SerializeRow > (
370+ & mut self ,
371+ statement : impl Into < BoundBatchStatement < V > > ,
372+ ) -> Result < ( ) , BoundBatchStatementError > {
373+ let mut statement = statement. into ( ) ;
374+ let mut first_prepared = None ;
375+
376+ if self . statements_len == 0 {
377+ // save it into a local variable for now in case a latter steps fails
378+ first_prepared = match statement {
379+ BoundBatchStatement :: Bound ( ref b) => {
380+ b. token ( ) ?. map ( |token| ( b. prepared . clone ( ) , token) )
381+ }
382+ BoundBatchStatement :: Prepared ( ps, values) => {
383+ let bound = ps
384+ . bind ( & values)
385+ . map_err ( BatchStatementSerializationError :: ValuesSerialiation ) ?;
386+ let first_prepared =
387+ bound. token ( ) ?. map ( |token| ( bound. prepared . clone ( ) , token) ) ;
388+ // we already serialized it so to avoid re-serializing it, modify the statement to a
389+ // BoundStatement
390+ statement = BoundBatchStatement :: Bound ( bound) ;
391+ first_prepared
392+ }
393+ BoundBatchStatement :: Query ( _) => None ,
394+ } ;
395+ }
396+
397+ let stmnt = match & statement {
398+ BoundBatchStatement :: Prepared ( ps, _) => request:: batch:: BatchStatement :: Prepared {
399+ id : Cow :: Borrowed ( ps. get_id ( ) ) ,
400+ } ,
401+ BoundBatchStatement :: Bound ( b) => request:: batch:: BatchStatement :: Prepared {
402+ id : Cow :: Borrowed ( b. prepared . get_id ( ) ) ,
403+ } ,
404+ BoundBatchStatement :: Query ( q) => request:: batch:: BatchStatement :: Query {
405+ text : Cow :: Borrowed ( & q. contents ) ,
406+ } ,
407+ } ;
408+
409+ serialize_statement ( stmnt, & mut self . buffer , |writer| match & statement {
410+ BoundBatchStatement :: Prepared ( ps, values) => {
411+ let ctx = RowSerializationContext :: from_prepared ( ps. get_prepared_metadata ( ) ) ;
412+ values. serialize ( & ctx, writer) . map ( Some )
413+ }
414+ BoundBatchStatement :: Bound ( b) => {
415+ writer. append_serialize_row ( & b. values ) ;
416+ Ok ( Some ( ( ) ) )
417+ }
418+ // query has no values
419+ BoundBatchStatement :: Query ( _) => Ok ( Some ( ( ) ) ) ,
420+ } ) ?;
421+
422+ let new_statements_len = self
423+ . statements_len
424+ . checked_add ( 1 )
425+ . ok_or ( BoundBatchStatementError :: TooManyQueriesInBatchStatement ) ?;
426+
427+ /*** at this point nothing else should be fallible as we are going to be modifying
428+ * fields that do not get reset ***/
429+
430+ self . statements_len = new_statements_len;
431+
432+ if let Some ( first_prepared) = first_prepared {
433+ self . first_prepared = Some ( first_prepared) ;
434+ }
435+
436+ let prepared = match statement {
437+ BoundBatchStatement :: Prepared ( ps, _) => ps,
438+ BoundBatchStatement :: Bound ( b) => b. prepared ,
439+ BoundBatchStatement :: Query ( _) => return Ok ( ( ) ) ,
440+ } ;
441+
442+ if !self . prepared . contains_key ( prepared. get_id ( ) ) {
443+ self . prepared . insert ( prepared. get_id ( ) . to_owned ( ) , prepared) ;
444+ }
445+
446+ Ok ( ( ) )
447+ }
448+
351449 fn serialize_from_batch_statement < T > (
352450 & mut self ,
353451 statement : & BatchStatement ,
@@ -422,3 +520,54 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
422520 n_statements : n_statements as usize ,
423521 }
424522}
523+
524+ /// This enum represents a CQL statement, that can be part of batch and its values
525+ #[ derive( Clone ) ]
526+ #[ non_exhaustive]
527+ pub enum BoundBatchStatement < V : SerializeRow > {
528+ /// A prepared statement and its not-yet serialized values
529+ Prepared ( PreparedStatement , V ) ,
530+ /// A statement whose values have already been bound (and thus serialized)
531+ Bound ( BoundStatement ) ,
532+ /// An unprepared statement with no values
533+ Query ( Statement ) ,
534+ }
535+
536+ impl From < BoundStatement > for BoundBatchStatement < ( ) > {
537+ fn from ( b : BoundStatement ) -> Self {
538+ BoundBatchStatement :: Bound ( b)
539+ }
540+ }
541+
542+ impl < V : SerializeRow > From < ( PreparedStatement , V ) > for BoundBatchStatement < V > {
543+ fn from ( ( p, v) : ( PreparedStatement , V ) ) -> Self {
544+ BoundBatchStatement :: Prepared ( p, v)
545+ }
546+ }
547+
548+ impl From < Statement > for BoundBatchStatement < ( ) > {
549+ fn from ( s : Statement ) -> Self {
550+ BoundBatchStatement :: Query ( s)
551+ }
552+ }
553+
554+ impl From < & str > for BoundBatchStatement < ( ) > {
555+ fn from ( s : & str ) -> Self {
556+ BoundBatchStatement :: Query ( Statement :: from ( s) )
557+ }
558+ }
559+
560+ /// An error type returned when adding a statement to a bounded batch fails
561+ #[ non_exhaustive]
562+ #[ derive( Error , Debug , Clone ) ]
563+ pub enum BoundBatchStatementError {
564+ /// Failed to serialize the batch statement
565+ #[ error( transparent) ]
566+ Statement ( #[ from] BatchStatementSerializationError ) ,
567+ /// Failed to serialize statement's bound values.
568+ #[ error( "Failed to calculate partition key" ) ]
569+ PartitionKey ( #[ from] PartitionKeyError ) ,
570+ /// Too many statements in the batch statement.
571+ #[ error( "Added statement goes over exceeded max value of 65,535" ) ]
572+ TooManyQueriesInBatchStatement ,
573+ }
0 commit comments