Skip to content

Commit 0cc0531

Browse files
committed
Allow adding statements to a boundbatch and executing it
1 parent f365843 commit 0cc0531

File tree

2 files changed

+216
-2
lines changed

2 files changed

+216
-2
lines changed

scylla/src/statement/batch.rs

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ use scylla_cql::frame::frame_errors::{
88
};
99
use scylla_cql::frame::request;
1010
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
11-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
11+
use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
1212
use scylla_cql::serialize::{RowWriter, SerializationError};
13+
use thiserror::Error;
1314

1415
use crate::client::execution_profile::ExecutionProfileHandle;
1516
use crate::errors::{BadQuery, ExecutionError, RequestAttemptError};
@@ -266,6 +267,20 @@ impl BoundBatch {
266267
}
267268
}
268269

270+
/// Appends a new statement to the batch.
271+
pub fn append_statement<V: SerializeRow>(
272+
&mut self,
273+
statement: impl Into<BoundBatchStatement<V>>,
274+
) -> Result<(), BoundBatchStatementError> {
275+
let initial_len = self.buffer.len();
276+
self.raw_append_statement(statement).inspect_err(|_| {
277+
// if we error'd at any point we should put the buffer back to its old length to not
278+
// corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
279+
// tries with a successful statement later
280+
self.buffer.truncate(initial_len);
281+
})
282+
}
283+
269284
pub(crate) fn from_batch(
270285
batch: &Batch,
271286
values: impl BatchValues,
@@ -372,6 +387,89 @@ impl BoundBatch {
372387
self.batch_type
373388
}
374389

390+
// **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
391+
// because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
392+
// `self` to be modified if an error occured because the caller will not reset them.
393+
fn raw_append_statement<V: SerializeRow>(
394+
&mut self,
395+
statement: impl Into<BoundBatchStatement<V>>,
396+
) -> Result<(), BoundBatchStatementError> {
397+
let mut statement = statement.into();
398+
let mut first_prepared = None;
399+
400+
if self.statements_len == 0 {
401+
// save it into a local variable for now in case a latter steps fails
402+
first_prepared = match statement {
403+
BoundBatchStatement::Bound(ref b) => {
404+
b.token()?.map(|token| (b.prepared.clone(), token))
405+
}
406+
BoundBatchStatement::Prepared(ps, values) => {
407+
let bound = ps
408+
.bind(&values)
409+
.map_err(BatchStatementSerializationError::ValuesSerialiation)?;
410+
let first_prepared =
411+
bound.token()?.map(|token| (bound.prepared.clone(), token));
412+
// we already serialized it so to avoid re-serializing it, modify the statement to a
413+
// BoundStatement
414+
statement = BoundBatchStatement::Bound(bound);
415+
first_prepared
416+
}
417+
BoundBatchStatement::Query(_) => None,
418+
};
419+
}
420+
421+
let stmnt = match &statement {
422+
BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared {
423+
id: Cow::Borrowed(ps.get_id()),
424+
},
425+
BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared {
426+
id: Cow::Borrowed(b.prepared.get_id()),
427+
},
428+
BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query {
429+
text: Cow::Borrowed(&q.contents),
430+
},
431+
};
432+
433+
serialize_statement(stmnt, &mut self.buffer, |writer| match &statement {
434+
BoundBatchStatement::Prepared(ps, values) => {
435+
let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
436+
values.serialize(&ctx, writer).map(Some)
437+
}
438+
BoundBatchStatement::Bound(b) => {
439+
writer.append_serialize_row(&b.values);
440+
Ok(Some(()))
441+
}
442+
// query has no values
443+
BoundBatchStatement::Query(_) => Ok(Some(())),
444+
})?;
445+
446+
let new_statements_len = self
447+
.statements_len
448+
.checked_add(1)
449+
.ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?;
450+
451+
/*** at this point nothing else should be fallible as we are going to be modifying
452+
* fields that do not get reset ***/
453+
454+
self.statements_len = new_statements_len;
455+
456+
if let Some(first_prepared) = first_prepared {
457+
self.first_prepared = Some(first_prepared);
458+
}
459+
460+
let prepared = match statement {
461+
BoundBatchStatement::Prepared(ps, _) => ps,
462+
BoundBatchStatement::Bound(b) => b.prepared,
463+
BoundBatchStatement::Query(_) => return Ok(()),
464+
};
465+
466+
if !self.prepared.contains_key(prepared.get_id()) {
467+
self.prepared.insert(prepared.get_id().to_owned(), prepared);
468+
}
469+
470+
Ok(())
471+
}
472+
375473
fn serialize_from_batch_statement<T>(
376474
&mut self,
377475
statement: &BatchStatement,
@@ -446,3 +544,54 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
446544
n_statements: n_statements as usize,
447545
}
448546
}
547+
548+
/// This enum represents a CQL statement, that can be part of batch and its values
549+
#[derive(Clone)]
550+
#[non_exhaustive]
551+
pub enum BoundBatchStatement<V: SerializeRow> {
552+
/// A prepared statement and its not-yet serialized values
553+
Prepared(PreparedStatement, V),
554+
/// A statement whose values have already been bound (and thus serialized)
555+
Bound(BoundStatement),
556+
/// An unprepared statement with no values
557+
Query(Statement),
558+
}
559+
560+
impl From<BoundStatement> for BoundBatchStatement<()> {
561+
fn from(b: BoundStatement) -> Self {
562+
BoundBatchStatement::Bound(b)
563+
}
564+
}
565+
566+
impl<V: SerializeRow> From<(PreparedStatement, V)> for BoundBatchStatement<V> {
567+
fn from((p, v): (PreparedStatement, V)) -> Self {
568+
BoundBatchStatement::Prepared(p, v)
569+
}
570+
}
571+
572+
impl From<Statement> for BoundBatchStatement<()> {
573+
fn from(s: Statement) -> Self {
574+
BoundBatchStatement::Query(s)
575+
}
576+
}
577+
578+
impl From<&str> for BoundBatchStatement<()> {
579+
fn from(s: &str) -> Self {
580+
BoundBatchStatement::Query(Statement::from(s))
581+
}
582+
}
583+
584+
/// An error type returned when adding a statement to a bounded batch fails
585+
#[non_exhaustive]
586+
#[derive(Error, Debug, Clone)]
587+
pub enum BoundBatchStatementError {
588+
/// Failed to serialize the batch statement
589+
#[error(transparent)]
590+
Statement(#[from] BatchStatementSerializationError),
591+
/// Failed to serialize statement's bound values.
592+
#[error("Failed to calculate partition key")]
593+
PartitionKey(#[from] PartitionKeyError),
594+
/// Too many statements in the batch statement.
595+
#[error("Added statement goes over exceeded max value of 65,535")]
596+
TooManyQueriesInBatchStatement,
597+
}

scylla/tests/integration/session.rs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ use scylla::routing::partitioner::PartitionerName;
2424
use scylla::routing::Token;
2525
use scylla::serialize::row::SerializeRow;
2626
use scylla::serialize::value::SerializeValue;
27-
use scylla::statement::batch::{Batch, BatchStatement, BatchType};
27+
use scylla::statement::batch::{Batch, BatchStatement, BatchType, BoundBatch};
28+
use scylla::statement::execute::Execute;
2829
use scylla::statement::prepared::PreparedStatement;
2930
use scylla::statement::unprepared::Statement;
3031
use scylla::statement::Consistency;
@@ -3269,3 +3270,67 @@ async fn test_vector_type_prepared() {
32693270

32703271
// TODO: Implement and test SELECT statements and bind values (`?`)
32713272
}
3273+
3274+
#[tokio::test]
3275+
async fn test_bound_batch() {
3276+
setup_tracing();
3277+
let session = Arc::new(create_new_session_builder().build().await.unwrap());
3278+
let ks = unique_keyspace_name();
3279+
3280+
session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks)).await.unwrap();
3281+
session
3282+
.ddl(format!(
3283+
"CREATE TABLE IF NOT EXISTS {}.t_batch (a int, b int, c text, primary key (a, b))",
3284+
ks
3285+
))
3286+
.await
3287+
.unwrap();
3288+
3289+
let prepared_statement = session
3290+
.prepare(format!(
3291+
"INSERT INTO {}.t_batch (a, b, c) VALUES (?, ?, ?)",
3292+
ks
3293+
))
3294+
.await
3295+
.unwrap();
3296+
3297+
let four_value: i32 = 4;
3298+
let hello_value: String = String::from("hello");
3299+
3300+
let bound_statement = prepared_statement
3301+
.clone()
3302+
.bind(&(1_i32, &four_value, hello_value.as_str()))
3303+
.unwrap();
3304+
3305+
let mut batch: BoundBatch = Default::default();
3306+
batch
3307+
.append_statement((prepared_statement, (1_i32, 2_i32, "abc")))
3308+
.unwrap();
3309+
batch
3310+
.append_statement(&format!("INSERT INTO {}.t_batch (a, b, c) VALUES (7, 11, '')", ks)[..])
3311+
.unwrap();
3312+
batch.append_statement(bound_statement).unwrap();
3313+
3314+
batch.execute(&session).await.unwrap();
3315+
3316+
let mut results: Vec<(i32, i32, String)> = session
3317+
.query_unpaged(format!("SELECT a, b, c FROM {}.t_batch", ks), &[])
3318+
.await
3319+
.unwrap()
3320+
.into_rows_result()
3321+
.unwrap()
3322+
.rows::<(i32, i32, String)>()
3323+
.unwrap()
3324+
.collect::<Result<_, _>>()
3325+
.unwrap();
3326+
3327+
results.sort();
3328+
assert_eq!(
3329+
results,
3330+
vec![
3331+
(1, 2, String::from("abc")),
3332+
(1, 4, String::from("hello")),
3333+
(7, 11, String::from(""))
3334+
]
3335+
);
3336+
}

0 commit comments

Comments
 (0)