@@ -8,8 +8,9 @@ use scylla_cql::frame::frame_errors::{
8
8
} ;
9
9
use scylla_cql:: frame:: request;
10
10
use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
11
- use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializedValues } ;
11
+ use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializeRow , SerializedValues } ;
12
12
use scylla_cql:: serialize:: { RowWriter , SerializationError } ;
13
+ use thiserror:: Error ;
13
14
14
15
use crate :: client:: execution_profile:: ExecutionProfileHandle ;
15
16
use crate :: errors:: { BadQuery , ExecutionError , RequestAttemptError } ;
@@ -266,6 +267,20 @@ impl BoundBatch {
266
267
}
267
268
}
268
269
270
+ /// Appends a new statement to the batch.
271
+ pub fn append_statement < V : SerializeRow > (
272
+ & mut self ,
273
+ statement : impl Into < BoundBatchStatement < V > > ,
274
+ ) -> Result < ( ) , BoundBatchStatementError > {
275
+ let initial_len = self . buffer . len ( ) ;
276
+ self . raw_append_statement ( statement) . inspect_err ( |_| {
277
+ // if we error'd at any point we should put the buffer back to its old length to not
278
+ // corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
279
+ // tries with a successful statement later
280
+ self . buffer . truncate ( initial_len) ;
281
+ } )
282
+ }
283
+
269
284
pub ( crate ) fn from_batch (
270
285
batch : & Batch ,
271
286
values : impl BatchValues ,
@@ -372,6 +387,89 @@ impl BoundBatch {
372
387
self . batch_type
373
388
}
374
389
390
+ // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
391
+ // because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
392
+ // `self` to be modified if an error occured because the caller will not reset them.
393
+ fn raw_append_statement < V : SerializeRow > (
394
+ & mut self ,
395
+ statement : impl Into < BoundBatchStatement < V > > ,
396
+ ) -> Result < ( ) , BoundBatchStatementError > {
397
+ let mut statement = statement. into ( ) ;
398
+ let mut first_prepared = None ;
399
+
400
+ if self . statements_len == 0 {
401
+ // save it into a local variable for now in case a latter steps fails
402
+ first_prepared = match statement {
403
+ BoundBatchStatement :: Bound ( ref b) => {
404
+ b. token ( ) ?. map ( |token| ( b. prepared . clone ( ) , token) )
405
+ }
406
+ BoundBatchStatement :: Prepared ( ps, values) => {
407
+ let bound = ps
408
+ . bind ( & values)
409
+ . map_err ( BatchStatementSerializationError :: ValuesSerialiation ) ?;
410
+ let first_prepared =
411
+ bound. token ( ) ?. map ( |token| ( bound. prepared . clone ( ) , token) ) ;
412
+ // we already serialized it so to avoid re-serializing it, modify the statement to a
413
+ // BoundStatement
414
+ statement = BoundBatchStatement :: Bound ( bound) ;
415
+ first_prepared
416
+ }
417
+ BoundBatchStatement :: Query ( _) => None ,
418
+ } ;
419
+ }
420
+
421
+ let stmnt = match & statement {
422
+ BoundBatchStatement :: Prepared ( ps, _) => request:: batch:: BatchStatement :: Prepared {
423
+ id : Cow :: Borrowed ( ps. get_id ( ) ) ,
424
+ } ,
425
+ BoundBatchStatement :: Bound ( b) => request:: batch:: BatchStatement :: Prepared {
426
+ id : Cow :: Borrowed ( b. prepared . get_id ( ) ) ,
427
+ } ,
428
+ BoundBatchStatement :: Query ( q) => request:: batch:: BatchStatement :: Query {
429
+ text : Cow :: Borrowed ( & q. contents ) ,
430
+ } ,
431
+ } ;
432
+
433
+ serialize_statement ( stmnt, & mut self . buffer , |writer| match & statement {
434
+ BoundBatchStatement :: Prepared ( ps, values) => {
435
+ let ctx = RowSerializationContext :: from_prepared ( ps. get_prepared_metadata ( ) ) ;
436
+ values. serialize ( & ctx, writer) . map ( Some )
437
+ }
438
+ BoundBatchStatement :: Bound ( b) => {
439
+ writer. append_serialize_row ( & b. values ) ;
440
+ Ok ( Some ( ( ) ) )
441
+ }
442
+ // query has no values
443
+ BoundBatchStatement :: Query ( _) => Ok ( Some ( ( ) ) ) ,
444
+ } ) ?;
445
+
446
+ let new_statements_len = self
447
+ . statements_len
448
+ . checked_add ( 1 )
449
+ . ok_or ( BoundBatchStatementError :: TooManyQueriesInBatchStatement ) ?;
450
+
451
+ /*** at this point nothing else should be fallible as we are going to be modifying
452
+ * fields that do not get reset ***/
453
+
454
+ self . statements_len = new_statements_len;
455
+
456
+ if let Some ( first_prepared) = first_prepared {
457
+ self . first_prepared = Some ( first_prepared) ;
458
+ }
459
+
460
+ let prepared = match statement {
461
+ BoundBatchStatement :: Prepared ( ps, _) => ps,
462
+ BoundBatchStatement :: Bound ( b) => b. prepared ,
463
+ BoundBatchStatement :: Query ( _) => return Ok ( ( ) ) ,
464
+ } ;
465
+
466
+ if !self . prepared . contains_key ( prepared. get_id ( ) ) {
467
+ self . prepared . insert ( prepared. get_id ( ) . to_owned ( ) , prepared) ;
468
+ }
469
+
470
+ Ok ( ( ) )
471
+ }
472
+
375
473
fn serialize_from_batch_statement < T > (
376
474
& mut self ,
377
475
statement : & BatchStatement ,
@@ -446,3 +544,54 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
446
544
n_statements : n_statements as usize ,
447
545
}
448
546
}
547
+
548
+ /// This enum represents a CQL statement, that can be part of batch and its values
549
+ #[ derive( Clone ) ]
550
+ #[ non_exhaustive]
551
+ pub enum BoundBatchStatement < V : SerializeRow > {
552
+ /// A prepared statement and its not-yet serialized values
553
+ Prepared ( PreparedStatement , V ) ,
554
+ /// A statement whose values have already been bound (and thus serialized)
555
+ Bound ( BoundStatement ) ,
556
+ /// An unprepared statement with no values
557
+ Query ( Statement ) ,
558
+ }
559
+
560
+ impl From < BoundStatement > for BoundBatchStatement < ( ) > {
561
+ fn from ( b : BoundStatement ) -> Self {
562
+ BoundBatchStatement :: Bound ( b)
563
+ }
564
+ }
565
+
566
+ impl < V : SerializeRow > From < ( PreparedStatement , V ) > for BoundBatchStatement < V > {
567
+ fn from ( ( p, v) : ( PreparedStatement , V ) ) -> Self {
568
+ BoundBatchStatement :: Prepared ( p, v)
569
+ }
570
+ }
571
+
572
+ impl From < Statement > for BoundBatchStatement < ( ) > {
573
+ fn from ( s : Statement ) -> Self {
574
+ BoundBatchStatement :: Query ( s)
575
+ }
576
+ }
577
+
578
+ impl From < & str > for BoundBatchStatement < ( ) > {
579
+ fn from ( s : & str ) -> Self {
580
+ BoundBatchStatement :: Query ( Statement :: from ( s) )
581
+ }
582
+ }
583
+
584
+ /// An error type returned when adding a statement to a bounded batch fails
585
+ #[ non_exhaustive]
586
+ #[ derive( Error , Debug , Clone ) ]
587
+ pub enum BoundBatchStatementError {
588
+ /// Failed to serialize the batch statement
589
+ #[ error( transparent) ]
590
+ Statement ( #[ from] BatchStatementSerializationError ) ,
591
+ /// Failed to serialize statement's bound values.
592
+ #[ error( "Failed to calculate partition key" ) ]
593
+ PartitionKey ( #[ from] PartitionKeyError ) ,
594
+ /// Too many statements in the batch statement.
595
+ #[ error( "Added statement goes over exceeded max value of 65,535" ) ]
596
+ TooManyQueriesInBatchStatement ,
597
+ }
0 commit comments