Skip to content

Commit dca585a

Browse files
committed
Allow adding statements to a boundbatch and executing it
1 parent 3cfdf5b commit dca585a

File tree

8 files changed

+654
-359
lines changed

8 files changed

+654
-359
lines changed

scylla/src/client/session.rs

Lines changed: 22 additions & 345 deletions
Large diffs are not rendered by default.

scylla/src/network/connection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,7 @@ impl Connection {
11101110
consistency,
11111111
serial_consistency,
11121112
timestamp,
1113-
statements_len: batch.statements_len,
1113+
statements_len: batch.statements_len(),
11141114
};
11151115

11161116
loop {

scylla/src/statement/batch.rs

Lines changed: 239 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,28 @@ use scylla_cql::frame::frame_errors::{
88
};
99
use scylla_cql::frame::request;
1010
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
11-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
11+
use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
1212
use scylla_cql::serialize::{RowWriter, SerializationError};
13+
use thiserror::Error;
14+
use tracing::Instrument;
1315

14-
use crate::client::execution_profile::ExecutionProfileHandle;
16+
use crate::client::execution_profile::{ExecutionProfileHandle, ExecutionProfileInner};
17+
use crate::client::session::{RunRequestResult, Session};
1518
use crate::errors::{BadQuery, ExecutionError, RequestAttemptError};
19+
use crate::network::Connection;
20+
use crate::observability::driver_tracing::RequestSpan;
1621
use crate::observability::history::HistoryListener;
1722
use crate::policies::load_balancing::LoadBalancingPolicy;
23+
use crate::policies::load_balancing::RoutingInfo;
1824
use crate::policies::retry::RetryPolicy;
25+
use crate::response::query_result::QueryResult;
26+
use crate::response::{NonErrorQueryResponse, QueryResponse};
1927
use crate::routing::Token;
2028
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
2129
use crate::statement::unprepared::Statement;
2230

2331
use super::bound::BoundStatement;
32+
use super::execute::Execute;
2433
use super::StatementConfig;
2534
use super::{Consistency, SerialConsistency};
2635
pub use crate::frame::request::batch::BatchType;
@@ -254,8 +263,8 @@ pub struct BoundBatch {
254263
batch_type: BatchType,
255264
pub(crate) buffer: Vec<u8>,
256265
pub(crate) prepared: HashMap<Bytes, PreparedStatement>,
257-
pub(crate) first_prepared: Option<(PreparedStatement, Token)>,
258-
pub(crate) statements_len: u16,
266+
first_prepared: Option<(PreparedStatement, Token)>,
267+
statements_len: u16,
259268
}
260269

261270
impl BoundBatch {
@@ -266,6 +275,20 @@ impl BoundBatch {
266275
}
267276
}
268277

278+
/// Appends a new statement to the batch.
279+
pub fn append_statement<'p, V: SerializeRow>(
280+
&mut self,
281+
statement: impl Into<BoundBatchStatement<'p, V>>,
282+
) -> Result<(), BoundBatchStatementError> {
283+
let initial_len = self.buffer.len();
284+
self.raw_append_statement(statement).inspect_err(|_| {
285+
// if we error'd at any point we should put the buffer back to its old length to not
286+
// corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
287+
// tries with a successful statement later
288+
self.buffer.truncate(initial_len);
289+
})
290+
}
291+
269292
pub(crate) fn from_batch(
270293
batch: &Batch,
271294
values: impl BatchValues,
@@ -374,6 +397,95 @@ impl BoundBatch {
374397
self.batch_type
375398
}
376399

400+
pub fn statements_len(&self) -> u16 {
401+
self.statements_len
402+
}
403+
404+
// **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
405+
// because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
406+
// `self` to be modified if an error occured because the caller will not reset them.
407+
fn raw_append_statement<'p, V: SerializeRow>(
408+
&mut self,
409+
statement: impl Into<BoundBatchStatement<'p, V>>,
410+
) -> Result<(), BoundBatchStatementError> {
411+
let mut statement = statement.into();
412+
let mut first_prepared = None;
413+
414+
if self.statements_len == 0 {
415+
// save it into a local variable for now in case a latter steps fails
416+
first_prepared = match statement {
417+
BoundBatchStatement::Bound(ref b) => b
418+
.token()?
419+
.map(|token| (b.prepared.clone().into_owned(), token)),
420+
BoundBatchStatement::Prepared(ps, values) => {
421+
let bound = ps
422+
.into_bind(&values)
423+
.map_err(BatchStatementSerializationError::ValuesSerialiation)?;
424+
let first_prepared = bound
425+
.token()?
426+
.map(|token| (bound.prepared.clone().into_owned(), token));
427+
// we already serialized it so to avoid re-serializing it, modify the statement to a
428+
// BoundStatement
429+
statement = BoundBatchStatement::Bound(bound);
430+
first_prepared
431+
}
432+
BoundBatchStatement::Query(_) => None,
433+
};
434+
}
435+
436+
let stmnt = match &statement {
437+
BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared {
438+
id: Cow::Borrowed(ps.get_id()),
439+
},
440+
BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared {
441+
id: Cow::Borrowed(b.prepared.get_id()),
442+
},
443+
BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query {
444+
text: Cow::Borrowed(&q.contents),
445+
},
446+
};
447+
448+
serialize_statement(stmnt, &mut self.buffer, |writer| match &statement {
449+
BoundBatchStatement::Prepared(ps, values) => {
450+
let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
451+
values.serialize(&ctx, writer).map(Some)
452+
}
453+
BoundBatchStatement::Bound(b) => {
454+
writer.append_serialize_row(&b.values);
455+
Ok(Some(()))
456+
}
457+
// query has no values
458+
BoundBatchStatement::Query(_) => Ok(Some(())),
459+
})?;
460+
461+
let new_statements_len = self
462+
.statements_len
463+
.checked_add(1)
464+
.ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?;
465+
466+
/*** at this point nothing else should be fallible as we are going to be modifying
467+
* fields that do not get reset ***/
468+
469+
self.statements_len = new_statements_len;
470+
471+
if let Some(first_prepared) = first_prepared {
472+
self.first_prepared = Some(first_prepared);
473+
}
474+
475+
let prepared = match statement {
476+
BoundBatchStatement::Prepared(ps, _) => Cow::Owned(ps),
477+
BoundBatchStatement::Bound(b) => b.prepared,
478+
BoundBatchStatement::Query(_) => return Ok(()),
479+
};
480+
481+
if !self.prepared.contains_key(prepared.get_id()) {
482+
self.prepared
483+
.insert(prepared.get_id().to_owned(), prepared.into_owned());
484+
}
485+
486+
Ok(())
487+
}
488+
377489
fn serialize_from_batch_statement<T>(
378490
&mut self,
379491
statement: &BatchStatement,
@@ -448,3 +560,126 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
448560
n_statements: n_statements as usize,
449561
}
450562
}
563+
564+
/// This enum represents a CQL statement, that can be part of batch and its values
565+
#[derive(Clone)]
566+
#[non_exhaustive]
567+
pub enum BoundBatchStatement<'p, V: SerializeRow> {
568+
/// A prepared statement and its not-yet serialized values
569+
Prepared(PreparedStatement, V),
570+
/// A statement whose values have already been bound (and thus serialized)
571+
Bound(BoundStatement<'p>),
572+
/// An unprepared statement with no values
573+
Query(Statement),
574+
}
575+
576+
impl<'p> From<BoundStatement<'p>> for BoundBatchStatement<'p, ()> {
577+
fn from(b: BoundStatement<'p>) -> Self {
578+
BoundBatchStatement::Bound(b)
579+
}
580+
}
581+
582+
impl<V: SerializeRow> From<(PreparedStatement, V)> for BoundBatchStatement<'static, V> {
583+
fn from((p, v): (PreparedStatement, V)) -> Self {
584+
BoundBatchStatement::Prepared(p, v)
585+
}
586+
}
587+
588+
impl From<Statement> for BoundBatchStatement<'static, ()> {
589+
fn from(s: Statement) -> Self {
590+
BoundBatchStatement::Query(s)
591+
}
592+
}
593+
594+
impl From<&str> for BoundBatchStatement<'static, ()> {
595+
fn from(s: &str) -> Self {
596+
BoundBatchStatement::Query(Statement::from(s))
597+
}
598+
}
599+
600+
/// An error type returned when adding a statement to a bounded batch fails
601+
#[non_exhaustive]
602+
#[derive(Error, Debug, Clone)]
603+
pub enum BoundBatchStatementError {
604+
/// Failed to serialize the batch statement
605+
#[error(transparent)]
606+
Statement(#[from] BatchStatementSerializationError),
607+
/// Failed to serialize statement's bound values.
608+
#[error("Failed to calculate partition key")]
609+
PartitionKey(#[from] PartitionKeyError),
610+
/// Too many statements in the batch statement.
611+
#[error("Added statement goes over exceeded max value of 65,535")]
612+
TooManyQueriesInBatchStatement,
613+
}
614+
615+
impl Execute for BoundBatch {
616+
async fn execute(&self, session: &Session) -> Result<QueryResult, ExecutionError> {
617+
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
618+
// If users batch statements by shard, they will be rewarded with full shard awareness
619+
let execution_profile = self
620+
.get_execution_profile_handle()
621+
.unwrap_or_else(|| session.get_default_execution_profile_handle())
622+
.access();
623+
624+
let consistency = self
625+
.config
626+
.consistency
627+
.unwrap_or(execution_profile.consistency);
628+
629+
let serial_consistency = self
630+
.config
631+
.serial_consistency
632+
.unwrap_or(execution_profile.serial_consistency);
633+
634+
let (table, token) = self
635+
.first_prepared
636+
.as_ref()
637+
.and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
638+
.unzip();
639+
640+
let statement_info = RoutingInfo {
641+
consistency,
642+
serial_consistency,
643+
token,
644+
table,
645+
is_confirmed_lwt: false,
646+
};
647+
648+
let span = RequestSpan::new_batch();
649+
650+
let run_request_result: RunRequestResult<NonErrorQueryResponse> = session
651+
.run_request(
652+
statement_info,
653+
&self.config,
654+
execution_profile,
655+
|connection: Arc<Connection>,
656+
consistency: Consistency,
657+
execution_profile: &ExecutionProfileInner| {
658+
let serial_consistency = self
659+
.config
660+
.serial_consistency
661+
.unwrap_or(execution_profile.serial_consistency);
662+
async move {
663+
connection
664+
.batch_with_consistency(self, consistency, serial_consistency)
665+
.await
666+
.and_then(QueryResponse::into_non_error_query_response)
667+
}
668+
},
669+
&span,
670+
)
671+
.instrument(span.span().clone())
672+
.await?;
673+
674+
let result = match run_request_result {
675+
RunRequestResult::IgnoredWriteError => QueryResult::mock_empty(),
676+
RunRequestResult::Completed(non_error_query_response) => {
677+
let result = non_error_query_response.into_query_result()?;
678+
span.record_result_fields(&result);
679+
result
680+
}
681+
};
682+
683+
Ok(result)
684+
}
685+
}

0 commit comments

Comments
 (0)