@@ -8,8 +8,9 @@ use scylla_cql::frame::frame_errors::{
8
8
} ;
9
9
use scylla_cql:: frame:: request;
10
10
use scylla_cql:: serialize:: batch:: { BatchValues , BatchValuesIterator } ;
11
- use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializedValues } ;
11
+ use scylla_cql:: serialize:: row:: { RowSerializationContext , SerializeRow , SerializedValues } ;
12
12
use scylla_cql:: serialize:: { RowWriter , SerializationError } ;
13
+ use thiserror:: Error ;
13
14
14
15
use crate :: client:: execution_profile:: ExecutionProfileHandle ;
15
16
use crate :: errors:: { BadQuery , ExecutionError , RequestAttemptError } ;
@@ -242,6 +243,20 @@ impl BoundBatch {
242
243
}
243
244
}
244
245
246
+ /// Appends a new statement to the batch.
247
+ pub fn append_statement < V : SerializeRow > (
248
+ & mut self ,
249
+ statement : impl Into < BoundBatchStatement < V > > ,
250
+ ) -> Result < ( ) , BoundBatchStatementError > {
251
+ let initial_len = self . buffer . len ( ) ;
252
+ self . raw_append_statement ( statement) . inspect_err ( |_| {
253
+ // if we error'd at any point we should put the buffer back to its old length to not
254
+ // corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
255
+ // tries with a successful statement later
256
+ self . buffer . truncate ( initial_len) ;
257
+ } )
258
+ }
259
+
245
260
pub ( crate ) fn from_batch (
246
261
batch : & Batch ,
247
262
values : impl BatchValues ,
@@ -348,6 +363,89 @@ impl BoundBatch {
348
363
self . batch_type
349
364
}
350
365
366
+ // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
367
+ // because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
368
+ // `self` to be modified if an error occured because the caller will not reset them.
369
+ fn raw_append_statement < V : SerializeRow > (
370
+ & mut self ,
371
+ statement : impl Into < BoundBatchStatement < V > > ,
372
+ ) -> Result < ( ) , BoundBatchStatementError > {
373
+ let mut statement = statement. into ( ) ;
374
+ let mut first_prepared = None ;
375
+
376
+ if self . statements_len == 0 {
377
+ // save it into a local variable for now in case a latter steps fails
378
+ first_prepared = match statement {
379
+ BoundBatchStatement :: Bound ( ref b) => {
380
+ b. token ( ) ?. map ( |token| ( b. prepared . clone ( ) , token) )
381
+ }
382
+ BoundBatchStatement :: Prepared ( ps, values) => {
383
+ let bound = ps
384
+ . bind ( & values)
385
+ . map_err ( BatchStatementSerializationError :: ValuesSerialiation ) ?;
386
+ let first_prepared =
387
+ bound. token ( ) ?. map ( |token| ( bound. prepared . clone ( ) , token) ) ;
388
+ // we already serialized it so to avoid re-serializing it, modify the statement to a
389
+ // BoundStatement
390
+ statement = BoundBatchStatement :: Bound ( bound) ;
391
+ first_prepared
392
+ }
393
+ BoundBatchStatement :: Query ( _) => None ,
394
+ } ;
395
+ }
396
+
397
+ let stmnt = match & statement {
398
+ BoundBatchStatement :: Prepared ( ps, _) => request:: batch:: BatchStatement :: Prepared {
399
+ id : Cow :: Borrowed ( ps. get_id ( ) ) ,
400
+ } ,
401
+ BoundBatchStatement :: Bound ( b) => request:: batch:: BatchStatement :: Prepared {
402
+ id : Cow :: Borrowed ( b. prepared . get_id ( ) ) ,
403
+ } ,
404
+ BoundBatchStatement :: Query ( q) => request:: batch:: BatchStatement :: Query {
405
+ text : Cow :: Borrowed ( & q. contents ) ,
406
+ } ,
407
+ } ;
408
+
409
+ serialize_statement ( stmnt, & mut self . buffer , |writer| match & statement {
410
+ BoundBatchStatement :: Prepared ( ps, values) => {
411
+ let ctx = RowSerializationContext :: from_prepared ( ps. get_prepared_metadata ( ) ) ;
412
+ values. serialize ( & ctx, writer) . map ( Some )
413
+ }
414
+ BoundBatchStatement :: Bound ( b) => {
415
+ writer. append_serialize_row ( & b. values ) ;
416
+ Ok ( Some ( ( ) ) )
417
+ }
418
+ // query has no values
419
+ BoundBatchStatement :: Query ( _) => Ok ( Some ( ( ) ) ) ,
420
+ } ) ?;
421
+
422
+ let new_statements_len = self
423
+ . statements_len
424
+ . checked_add ( 1 )
425
+ . ok_or ( BoundBatchStatementError :: TooManyQueriesInBatchStatement ) ?;
426
+
427
+ /*** at this point nothing else should be fallible as we are going to be modifying
428
+ * fields that do not get reset ***/
429
+
430
+ self . statements_len = new_statements_len;
431
+
432
+ if let Some ( first_prepared) = first_prepared {
433
+ self . first_prepared = Some ( first_prepared) ;
434
+ }
435
+
436
+ let prepared = match statement {
437
+ BoundBatchStatement :: Prepared ( ps, _) => ps,
438
+ BoundBatchStatement :: Bound ( b) => b. prepared ,
439
+ BoundBatchStatement :: Query ( _) => return Ok ( ( ) ) ,
440
+ } ;
441
+
442
+ if !self . prepared . contains_key ( prepared. get_id ( ) ) {
443
+ self . prepared . insert ( prepared. get_id ( ) . to_owned ( ) , prepared) ;
444
+ }
445
+
446
+ Ok ( ( ) )
447
+ }
448
+
351
449
fn serialize_from_batch_statement < T > (
352
450
& mut self ,
353
451
statement : & BatchStatement ,
@@ -422,3 +520,54 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
422
520
n_statements : n_statements as usize ,
423
521
}
424
522
}
523
+
524
+ /// This enum represents a CQL statement, that can be part of batch and its values
525
+ #[ derive( Clone ) ]
526
+ #[ non_exhaustive]
527
+ pub enum BoundBatchStatement < V : SerializeRow > {
528
+ /// A prepared statement and its not-yet serialized values
529
+ Prepared ( PreparedStatement , V ) ,
530
+ /// A statement whose values have already been bound (and thus serialized)
531
+ Bound ( BoundStatement ) ,
532
+ /// An unprepared statement with no values
533
+ Query ( Statement ) ,
534
+ }
535
+
536
+ impl From < BoundStatement > for BoundBatchStatement < ( ) > {
537
+ fn from ( b : BoundStatement ) -> Self {
538
+ BoundBatchStatement :: Bound ( b)
539
+ }
540
+ }
541
+
542
+ impl < V : SerializeRow > From < ( PreparedStatement , V ) > for BoundBatchStatement < V > {
543
+ fn from ( ( p, v) : ( PreparedStatement , V ) ) -> Self {
544
+ BoundBatchStatement :: Prepared ( p, v)
545
+ }
546
+ }
547
+
548
+ impl From < Statement > for BoundBatchStatement < ( ) > {
549
+ fn from ( s : Statement ) -> Self {
550
+ BoundBatchStatement :: Query ( s)
551
+ }
552
+ }
553
+
554
+ impl From < & str > for BoundBatchStatement < ( ) > {
555
+ fn from ( s : & str ) -> Self {
556
+ BoundBatchStatement :: Query ( Statement :: from ( s) )
557
+ }
558
+ }
559
+
560
+ /// An error type returned when adding a statement to a bounded batch fails
561
+ #[ non_exhaustive]
562
+ #[ derive( Error , Debug , Clone ) ]
563
+ pub enum BoundBatchStatementError {
564
+ /// Failed to serialize the batch statement
565
+ #[ error( transparent) ]
566
+ Statement ( #[ from] BatchStatementSerializationError ) ,
567
+ /// Failed to serialize statement's bound values.
568
+ #[ error( "Failed to calculate partition key" ) ]
569
+ PartitionKey ( #[ from] PartitionKeyError ) ,
570
+ /// Too many statements in the batch statement.
571
+ #[ error( "Added statement goes over exceeded max value of 65,535" ) ]
572
+ TooManyQueriesInBatchStatement ,
573
+ }
0 commit comments