Skip to content

Commit 611f2b3

Browse files
committed
Allow adding statements to a boundbatch and executing it
1 parent 55c0111 commit 611f2b3

File tree

2 files changed

+216
-2
lines changed

2 files changed

+216
-2
lines changed

scylla/src/statement/batch.rs

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ use scylla_cql::frame::frame_errors::{
88
};
99
use scylla_cql::frame::request;
1010
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
11-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
11+
use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
1212
use scylla_cql::serialize::{RowWriter, SerializationError};
13+
use thiserror::Error;
1314

1415
use crate::client::execution_profile::ExecutionProfileHandle;
1516
use crate::errors::{BadQuery, ExecutionError, RequestAttemptError};
@@ -242,6 +243,20 @@ impl BoundBatch {
242243
}
243244
}
244245

246+
/// Appends a new statement to the batch.
247+
pub fn append_statement<V: SerializeRow>(
248+
&mut self,
249+
statement: impl Into<BoundBatchStatement<V>>,
250+
) -> Result<(), BoundBatchStatementError> {
251+
let initial_len = self.buffer.len();
252+
self.raw_append_statement(statement).inspect_err(|_| {
253+
// if we error'd at any point we should put the buffer back to its old length to not
254+
// corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
255+
// tries with a successful statement later
256+
self.buffer.truncate(initial_len);
257+
})
258+
}
259+
245260
pub(crate) fn from_batch(
246261
batch: &Batch,
247262
values: impl BatchValues,
@@ -348,6 +363,89 @@ impl BoundBatch {
348363
self.batch_type
349364
}
350365

366+
// **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
367+
// because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
368+
// `self` to be modified if an error occured because the caller will not reset them.
369+
fn raw_append_statement<V: SerializeRow>(
370+
&mut self,
371+
statement: impl Into<BoundBatchStatement<V>>,
372+
) -> Result<(), BoundBatchStatementError> {
373+
let mut statement = statement.into();
374+
let mut first_prepared = None;
375+
376+
if self.statements_len == 0 {
377+
// save it into a local variable for now in case a latter steps fails
378+
first_prepared = match statement {
379+
BoundBatchStatement::Bound(ref b) => {
380+
b.token()?.map(|token| (b.prepared.clone(), token))
381+
}
382+
BoundBatchStatement::Prepared(ps, values) => {
383+
let bound = ps
384+
.bind(&values)
385+
.map_err(BatchStatementSerializationError::ValuesSerialiation)?;
386+
let first_prepared =
387+
bound.token()?.map(|token| (bound.prepared.clone(), token));
388+
// we already serialized it so to avoid re-serializing it, modify the statement to a
389+
// BoundStatement
390+
statement = BoundBatchStatement::Bound(bound);
391+
first_prepared
392+
}
393+
BoundBatchStatement::Query(_) => None,
394+
};
395+
}
396+
397+
let stmnt = match &statement {
398+
BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared {
399+
id: Cow::Borrowed(ps.get_id()),
400+
},
401+
BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared {
402+
id: Cow::Borrowed(b.prepared.get_id()),
403+
},
404+
BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query {
405+
text: Cow::Borrowed(&q.contents),
406+
},
407+
};
408+
409+
serialize_statement(stmnt, &mut self.buffer, |writer| match &statement {
410+
BoundBatchStatement::Prepared(ps, values) => {
411+
let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
412+
values.serialize(&ctx, writer).map(Some)
413+
}
414+
BoundBatchStatement::Bound(b) => {
415+
writer.append_serialize_row(&b.values);
416+
Ok(Some(()))
417+
}
418+
// query has no values
419+
BoundBatchStatement::Query(_) => Ok(Some(())),
420+
})?;
421+
422+
let new_statements_len = self
423+
.statements_len
424+
.checked_add(1)
425+
.ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?;
426+
427+
/*** at this point nothing else should be fallible as we are going to be modifying
428+
* fields that do not get reset ***/
429+
430+
self.statements_len = new_statements_len;
431+
432+
if let Some(first_prepared) = first_prepared {
433+
self.first_prepared = Some(first_prepared);
434+
}
435+
436+
let prepared = match statement {
437+
BoundBatchStatement::Prepared(ps, _) => ps,
438+
BoundBatchStatement::Bound(b) => b.prepared,
439+
BoundBatchStatement::Query(_) => return Ok(()),
440+
};
441+
442+
if !self.prepared.contains_key(prepared.get_id()) {
443+
self.prepared.insert(prepared.get_id().to_owned(), prepared);
444+
}
445+
446+
Ok(())
447+
}
448+
351449
fn serialize_from_batch_statement<T>(
352450
&mut self,
353451
statement: &BatchStatement,
@@ -422,3 +520,54 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
422520
n_statements: n_statements as usize,
423521
}
424522
}
523+
524+
/// This enum represents a CQL statement, that can be part of batch and its values
525+
#[derive(Clone)]
526+
#[non_exhaustive]
527+
pub enum BoundBatchStatement<V: SerializeRow> {
528+
/// A prepared statement and its not-yet serialized values
529+
Prepared(PreparedStatement, V),
530+
/// A statement whose values have already been bound (and thus serialized)
531+
Bound(BoundStatement),
532+
/// An unprepared statement with no values
533+
Query(Statement),
534+
}
535+
536+
impl From<BoundStatement> for BoundBatchStatement<()> {
537+
fn from(b: BoundStatement) -> Self {
538+
BoundBatchStatement::Bound(b)
539+
}
540+
}
541+
542+
impl<V: SerializeRow> From<(PreparedStatement, V)> for BoundBatchStatement<V> {
543+
fn from((p, v): (PreparedStatement, V)) -> Self {
544+
BoundBatchStatement::Prepared(p, v)
545+
}
546+
}
547+
548+
impl From<Statement> for BoundBatchStatement<()> {
549+
fn from(s: Statement) -> Self {
550+
BoundBatchStatement::Query(s)
551+
}
552+
}
553+
554+
impl From<&str> for BoundBatchStatement<()> {
555+
fn from(s: &str) -> Self {
556+
BoundBatchStatement::Query(Statement::from(s))
557+
}
558+
}
559+
560+
/// An error type returned when adding a statement to a bounded batch fails
561+
#[non_exhaustive]
562+
#[derive(Error, Debug, Clone)]
563+
pub enum BoundBatchStatementError {
564+
/// Failed to serialize the batch statement
565+
#[error(transparent)]
566+
Statement(#[from] BatchStatementSerializationError),
567+
/// Failed to serialize statement's bound values.
568+
#[error("Failed to calculate partition key")]
569+
PartitionKey(#[from] PartitionKeyError),
570+
/// Too many statements in the batch statement.
571+
#[error("Added statement goes over exceeded max value of 65,535")]
572+
TooManyQueriesInBatchStatement,
573+
}

scylla/tests/integration/session.rs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ use scylla::routing::partitioner::PartitionerName;
2424
use scylla::routing::Token;
2525
use scylla::serialize::row::SerializeRow;
2626
use scylla::serialize::value::SerializeValue;
27-
use scylla::statement::batch::{Batch, BatchStatement, BatchType};
27+
use scylla::statement::batch::{Batch, BatchStatement, BatchType, BoundBatch};
28+
use scylla::statement::execute::Execute;
2829
use scylla::statement::prepared::PreparedStatement;
2930
use scylla::statement::unprepared::Statement;
3031
use scylla::statement::Consistency;
@@ -3269,3 +3270,67 @@ async fn test_vector_type_prepared() {
32693270

32703271
// TODO: Implement and test SELECT statements and bind values (`?`)
32713272
}
3273+
3274+
#[tokio::test]
3275+
async fn test_bound_batch() {
3276+
setup_tracing();
3277+
let session = Arc::new(create_new_session_builder().build().await.unwrap());
3278+
let ks = unique_keyspace_name();
3279+
3280+
session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks)).await.unwrap();
3281+
session
3282+
.ddl(format!(
3283+
"CREATE TABLE IF NOT EXISTS {}.t_batch (a int, b int, c text, primary key (a, b))",
3284+
ks
3285+
))
3286+
.await
3287+
.unwrap();
3288+
3289+
let prepared_statement = session
3290+
.prepare(format!(
3291+
"INSERT INTO {}.t_batch (a, b, c) VALUES (?, ?, ?)",
3292+
ks
3293+
))
3294+
.await
3295+
.unwrap();
3296+
3297+
let four_value: i32 = 4;
3298+
let hello_value: String = String::from("hello");
3299+
3300+
let bound_statement = prepared_statement
3301+
.clone()
3302+
.bind(&(1_i32, &four_value, hello_value.as_str()))
3303+
.unwrap();
3304+
3305+
let mut batch: BoundBatch = Default::default();
3306+
batch
3307+
.append_statement((prepared_statement, (1_i32, 2_i32, "abc")))
3308+
.unwrap();
3309+
batch
3310+
.append_statement(&format!("INSERT INTO {}.t_batch (a, b, c) VALUES (7, 11, '')", ks)[..])
3311+
.unwrap();
3312+
batch.append_statement(bound_statement).unwrap();
3313+
3314+
batch.execute(&session).await.unwrap();
3315+
3316+
let mut results: Vec<(i32, i32, String)> = session
3317+
.query_unpaged(format!("SELECT a, b, c FROM {}.t_batch", ks), &[])
3318+
.await
3319+
.unwrap()
3320+
.into_rows_result()
3321+
.unwrap()
3322+
.rows::<(i32, i32, String)>()
3323+
.unwrap()
3324+
.collect::<Result<_, _>>()
3325+
.unwrap();
3326+
3327+
results.sort();
3328+
assert_eq!(
3329+
results,
3330+
vec![
3331+
(1, 2, String::from("abc")),
3332+
(1, 4, String::from("hello")),
3333+
(7, 11, String::from(""))
3334+
]
3335+
);
3336+
}

0 commit comments

Comments
 (0)