@@ -8,8 +8,9 @@ use scylla_cql::frame::frame_errors::{
88} ;
99use scylla_cql:: frame:: request;
1010use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
11- use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializedValues } ;
11+ use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializeRow , SerializedValues } ;
1212use scylla_cql:: serialize:: { RowWriter , SerializationError } ;
13+ use thiserror:: Error ;
1314
1415use crate :: client:: execution_profile:: ExecutionProfileHandle ;
1516use crate :: errors:: { BadQuery , ExecutionError , RequestAttemptError } ;
@@ -266,6 +267,20 @@ impl BoundBatch {
266267 }
267268 }
268269
270+ /// Appends a new statement to the batch.
271+ pub fn append_statement < V : SerializeRow > (
272+ & mut self ,
273+ statement : impl Into < BoundBatchStatement < V > > ,
274+ ) -> Result < ( ) , BoundBatchStatementError > {
275+ let initial_len = self . buffer . len ( ) ;
276+ self . raw_append_statement ( statement) . inspect_err ( |_| {
277+ // if we error'd at any point we should put the buffer back to its old length to not
278+ // corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
279+ // tries with a successful statement later
280+ self . buffer . truncate ( initial_len) ;
281+ } )
282+ }
283+
269284 pub ( crate ) fn from_batch (
270285 batch : & Batch ,
271286 values : impl BatchValues ,
@@ -372,6 +387,89 @@ impl BoundBatch {
372387 self . batch_type
373388 }
374389
390+ // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
391+ // because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
392+ // `self` to be modified if an error occured because the caller will not reset them.
393+ fn raw_append_statement < V : SerializeRow > (
394+ & mut self ,
395+ statement : impl Into < BoundBatchStatement < V > > ,
396+ ) -> Result < ( ) , BoundBatchStatementError > {
397+ let mut statement = statement. into ( ) ;
398+ let mut first_prepared = None ;
399+
400+ if self . statements_len == 0 {
401+ // save it into a local variable for now in case a latter steps fails
402+ first_prepared = match statement {
403+ BoundBatchStatement :: Bound ( ref b) => {
404+ b. token ( ) ?. map ( |token| ( b. prepared . clone ( ) , token) )
405+ }
406+ BoundBatchStatement :: Prepared ( ps, values) => {
407+ let bound = ps
408+ . bind ( & values)
409+ . map_err ( BatchStatementSerializationError :: ValuesSerialiation ) ?;
410+ let first_prepared =
411+ bound. token ( ) ?. map ( |token| ( bound. prepared . clone ( ) , token) ) ;
412+ // we already serialized it so to avoid re-serializing it, modify the statement to a
413+ // BoundStatement
414+ statement = BoundBatchStatement :: Bound ( bound) ;
415+ first_prepared
416+ }
417+ BoundBatchStatement :: Query ( _) => None ,
418+ } ;
419+ }
420+
421+ let stmnt = match & statement {
422+ BoundBatchStatement :: Prepared ( ps, _) => request:: batch:: BatchStatement :: Prepared {
423+ id : Cow :: Borrowed ( ps. get_id ( ) ) ,
424+ } ,
425+ BoundBatchStatement :: Bound ( b) => request:: batch:: BatchStatement :: Prepared {
426+ id : Cow :: Borrowed ( b. prepared . get_id ( ) ) ,
427+ } ,
428+ BoundBatchStatement :: Query ( q) => request:: batch:: BatchStatement :: Query {
429+ text : Cow :: Borrowed ( & q. contents ) ,
430+ } ,
431+ } ;
432+
433+ serialize_statement ( stmnt, & mut self . buffer , |writer| match & statement {
434+ BoundBatchStatement :: Prepared ( ps, values) => {
435+ let ctx = RowSerializationContext :: from_prepared ( ps. get_prepared_metadata ( ) ) ;
436+ values. serialize ( & ctx, writer) . map ( Some )
437+ }
438+ BoundBatchStatement :: Bound ( b) => {
439+ writer. append_serialize_row ( & b. values ) ;
440+ Ok ( Some ( ( ) ) )
441+ }
442+ // query has no values
443+ BoundBatchStatement :: Query ( _) => Ok ( Some ( ( ) ) ) ,
444+ } ) ?;
445+
446+ let new_statements_len = self
447+ . statements_len
448+ . checked_add ( 1 )
449+ . ok_or ( BoundBatchStatementError :: TooManyQueriesInBatchStatement ) ?;
450+
451+ /*** at this point nothing else should be fallible as we are going to be modifying
452+ * fields that do not get reset ***/
453+
454+ self . statements_len = new_statements_len;
455+
456+ if let Some ( first_prepared) = first_prepared {
457+ self . first_prepared = Some ( first_prepared) ;
458+ }
459+
460+ let prepared = match statement {
461+ BoundBatchStatement :: Prepared ( ps, _) => ps,
462+ BoundBatchStatement :: Bound ( b) => b. prepared ,
463+ BoundBatchStatement :: Query ( _) => return Ok ( ( ) ) ,
464+ } ;
465+
466+ if !self . prepared . contains_key ( prepared. get_id ( ) ) {
467+ self . prepared . insert ( prepared. get_id ( ) . to_owned ( ) , prepared) ;
468+ }
469+
470+ Ok ( ( ) )
471+ }
472+
375473 fn serialize_from_batch_statement < T > (
376474 & mut self ,
377475 statement : & BatchStatement ,
@@ -446,3 +544,54 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
446544 n_statements : n_statements as usize ,
447545 }
448546}
547+
548+ /// This enum represents a CQL statement, that can be part of batch and its values
549+ #[ derive( Clone ) ]
550+ #[ non_exhaustive]
551+ pub enum BoundBatchStatement < V : SerializeRow > {
552+ /// A prepared statement and its not-yet serialized values
553+ Prepared ( PreparedStatement , V ) ,
554+ /// A statement whose values have already been bound (and thus serialized)
555+ Bound ( BoundStatement ) ,
556+ /// An unprepared statement with no values
557+ Query ( Statement ) ,
558+ }
559+
560+ impl From < BoundStatement > for BoundBatchStatement < ( ) > {
561+ fn from ( b : BoundStatement ) -> Self {
562+ BoundBatchStatement :: Bound ( b)
563+ }
564+ }
565+
566+ impl < V : SerializeRow > From < ( PreparedStatement , V ) > for BoundBatchStatement < V > {
567+ fn from ( ( p, v) : ( PreparedStatement , V ) ) -> Self {
568+ BoundBatchStatement :: Prepared ( p, v)
569+ }
570+ }
571+
572+ impl From < Statement > for BoundBatchStatement < ( ) > {
573+ fn from ( s : Statement ) -> Self {
574+ BoundBatchStatement :: Query ( s)
575+ }
576+ }
577+
578+ impl From < & str > for BoundBatchStatement < ( ) > {
579+ fn from ( s : & str ) -> Self {
580+ BoundBatchStatement :: Query ( Statement :: from ( s) )
581+ }
582+ }
583+
584+ /// An error type returned when adding a statement to a bounded batch fails
585+ #[ non_exhaustive]
586+ #[ derive( Error , Debug , Clone ) ]
587+ pub enum BoundBatchStatementError {
588+ /// Failed to serialize the batch statement
589+ #[ error( transparent) ]
590+ Statement ( #[ from] BatchStatementSerializationError ) ,
591+ /// Failed to serialize statement's bound values.
592+ #[ error( "Failed to calculate partition key" ) ]
593+ PartitionKey ( #[ from] PartitionKeyError ) ,
594+ /// Too many statements in the batch statement.
595+ #[ error( "Added statement goes over exceeded max value of 65,535" ) ]
596+ TooManyQueriesInBatchStatement ,
597+ }
0 commit comments