Skip to content

Commit e86afc9

Browse files
committed
Allow adding statements to a boundbatch and executing it
1 parent 52c78ad commit e86afc9

File tree

8 files changed

+655
-365
lines changed

8 files changed

+655
-365
lines changed

scylla/src/client/session.rs

Lines changed: 21 additions & 351 deletions
Large diffs are not rendered by default.

scylla/src/network/connection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,7 @@ impl Connection {
10381038
consistency,
10391039
serial_consistency,
10401040
timestamp,
1041-
statements_len: batch.statements_len,
1041+
statements_len: batch.statements_len(),
10421042
};
10431043

10441044
loop {

scylla/src/statement/batch.rs

Lines changed: 243 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,28 @@ use scylla_cql::frame::frame_errors::{
88
};
99
use scylla_cql::frame::request;
1010
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
11-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
11+
use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
1212
use scylla_cql::serialize::{RowWriter, SerializationError};
13+
use thiserror::Error;
14+
use tracing::Instrument;
1315

14-
use crate::client::execution_profile::ExecutionProfileHandle;
16+
use crate::client::execution_profile::{ExecutionProfileHandle, ExecutionProfileInner};
17+
use crate::client::session::{RunRequestResult, Session};
1518
use crate::errors::{BadQuery, ExecutionError, RequestAttemptError};
19+
use crate::network::Connection;
20+
use crate::observability::driver_tracing::RequestSpan;
1621
use crate::observability::history::HistoryListener;
1722
use crate::policies::load_balancing::LoadBalancingPolicy;
23+
use crate::policies::load_balancing::RoutingInfo;
1824
use crate::policies::retry::RetryPolicy;
25+
use crate::response::query_result::QueryResult;
26+
use crate::response::{Coordinator, NonErrorQueryResponse, QueryResponse};
1927
use crate::routing::Token;
2028
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
2129
use crate::statement::unprepared::Statement;
2230

2331
use super::bound::BoundStatement;
32+
use super::execute::Execute;
2433
use super::StatementConfig;
2534
use super::{Consistency, SerialConsistency};
2635
pub use crate::frame::request::batch::BatchType;
@@ -254,8 +263,8 @@ pub struct BoundBatch {
254263
batch_type: BatchType,
255264
pub(crate) buffer: Vec<u8>,
256265
pub(crate) prepared: HashMap<Bytes, PreparedStatement>,
257-
pub(crate) first_prepared: Option<(PreparedStatement, Token)>,
258-
pub(crate) statements_len: u16,
266+
first_prepared: Option<(PreparedStatement, Token)>,
267+
statements_len: u16,
259268
}
260269

261270
impl BoundBatch {
@@ -266,6 +275,20 @@ impl BoundBatch {
266275
}
267276
}
268277

278+
/// Appends a new statement to the batch.
279+
pub fn append_statement<'p, V: SerializeRow>(
280+
&mut self,
281+
statement: impl Into<BoundBatchStatement<'p, V>>,
282+
) -> Result<(), BoundBatchStatementError> {
283+
let initial_len = self.buffer.len();
284+
self.raw_append_statement(statement).inspect_err(|_| {
285+
// if we error'd at any point we should put the buffer back to its old length to not
286+
// corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
287+
// tries with a successful statement later
288+
self.buffer.truncate(initial_len);
289+
})
290+
}
291+
269292
#[allow(clippy::result_large_err)]
270293
pub(crate) fn from_batch(
271294
batch: &Batch,
@@ -375,6 +398,96 @@ impl BoundBatch {
375398
self.batch_type
376399
}
377400

401+
pub fn statements_len(&self) -> u16 {
402+
self.statements_len
403+
}
404+
405+
// **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
406+
// because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
407+
// `self` to be modified if an error occured because the caller will not reset them.
408+
fn raw_append_statement<'p, V: SerializeRow>(
409+
&mut self,
410+
statement: impl Into<BoundBatchStatement<'p, V>>,
411+
) -> Result<(), BoundBatchStatementError> {
412+
let mut statement = statement.into();
413+
let mut first_prepared = None;
414+
415+
if self.statements_len == 0 {
416+
// save it into a local variable for now in case a latter steps fails
417+
first_prepared = match statement {
418+
BoundBatchStatement::Bound(ref b) => b
419+
.token()?
420+
.map(|token| (b.prepared.clone().into_owned(), token)),
421+
BoundBatchStatement::Prepared(ps, values) => {
422+
let bound = ps
423+
.into_bind(&values)
424+
.map_err(BatchStatementSerializationError::ValuesSerialiation)?;
425+
let first_prepared = bound
426+
.token()?
427+
.map(|token| (bound.prepared.clone().into_owned(), token));
428+
// we already serialized it so to avoid re-serializing it, modify the statement to a
429+
// BoundStatement
430+
statement = BoundBatchStatement::Bound(bound);
431+
first_prepared
432+
}
433+
BoundBatchStatement::Query(_) => None,
434+
};
435+
}
436+
437+
let stmnt = match &statement {
438+
BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared {
439+
id: Cow::Borrowed(ps.get_id()),
440+
},
441+
BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared {
442+
id: Cow::Borrowed(b.prepared.get_id()),
443+
},
444+
BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query {
445+
text: Cow::Borrowed(&q.contents),
446+
},
447+
};
448+
449+
serialize_statement(stmnt, &mut self.buffer, |writer| match &statement {
450+
BoundBatchStatement::Prepared(ps, values) => {
451+
let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
452+
values.serialize(&ctx, writer).map(Some)
453+
}
454+
BoundBatchStatement::Bound(b) => {
455+
writer.append_serialize_row(&b.values);
456+
Ok(Some(()))
457+
}
458+
// query has no values
459+
BoundBatchStatement::Query(_) => Ok(Some(())),
460+
})?;
461+
462+
let new_statements_len = self
463+
.statements_len
464+
.checked_add(1)
465+
.ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?;
466+
467+
/*** at this point nothing else should be fallible as we are going to be modifying
468+
* fields that do not get reset ***/
469+
470+
self.statements_len = new_statements_len;
471+
472+
if let Some(first_prepared) = first_prepared {
473+
self.first_prepared = Some(first_prepared);
474+
}
475+
476+
let prepared = match statement {
477+
BoundBatchStatement::Prepared(ps, _) => Cow::Owned(ps),
478+
BoundBatchStatement::Bound(b) => b.prepared,
479+
BoundBatchStatement::Query(_) => return Ok(()),
480+
};
481+
482+
if !self.prepared.contains_key(prepared.get_id()) {
483+
self.prepared
484+
.insert(prepared.get_id().to_owned(), prepared.into_owned());
485+
}
486+
487+
Ok(())
488+
}
489+
490+
#[allow(clippy::result_large_err)]
378491
fn serialize_from_batch_statement<T>(
379492
&mut self,
380493
statement: &BatchStatement,
@@ -449,3 +562,129 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
449562
n_statements: n_statements as usize,
450563
}
451564
}
565+
566+
/// This enum represents a CQL statement, that can be part of batch and its values
567+
#[derive(Clone)]
568+
#[non_exhaustive]
569+
pub enum BoundBatchStatement<'p, V: SerializeRow> {
570+
/// A prepared statement and its not-yet serialized values
571+
Prepared(PreparedStatement, V),
572+
/// A statement whose values have already been bound (and thus serialized)
573+
Bound(BoundStatement<'p>),
574+
/// An unprepared statement with no values
575+
Query(Statement),
576+
}
577+
578+
impl<'p> From<BoundStatement<'p>> for BoundBatchStatement<'p, ()> {
579+
fn from(b: BoundStatement<'p>) -> Self {
580+
BoundBatchStatement::Bound(b)
581+
}
582+
}
583+
584+
impl<V: SerializeRow> From<(PreparedStatement, V)> for BoundBatchStatement<'static, V> {
585+
fn from((p, v): (PreparedStatement, V)) -> Self {
586+
BoundBatchStatement::Prepared(p, v)
587+
}
588+
}
589+
590+
impl From<Statement> for BoundBatchStatement<'static, ()> {
591+
fn from(s: Statement) -> Self {
592+
BoundBatchStatement::Query(s)
593+
}
594+
}
595+
596+
impl From<&str> for BoundBatchStatement<'static, ()> {
597+
fn from(s: &str) -> Self {
598+
BoundBatchStatement::Query(Statement::from(s))
599+
}
600+
}
601+
602+
/// An error type returned when adding a statement to a bounded batch fails
603+
#[non_exhaustive]
604+
#[derive(Error, Debug, Clone)]
605+
pub enum BoundBatchStatementError {
606+
/// Failed to serialize the batch statement
607+
#[error(transparent)]
608+
Statement(#[from] BatchStatementSerializationError),
609+
/// Failed to serialize statement's bound values.
610+
#[error("Failed to calculate partition key")]
611+
PartitionKey(#[from] PartitionKeyError),
612+
/// Too many statements in the batch statement.
613+
#[error("Added statement goes over exceeded max value of 65,535")]
614+
TooManyQueriesInBatchStatement,
615+
}
616+
617+
impl Execute for BoundBatch {
618+
async fn execute(&self, session: &Session) -> Result<QueryResult, ExecutionError> {
619+
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
620+
// If users batch statements by shard, they will be rewarded with full shard awareness
621+
let execution_profile = self
622+
.get_execution_profile_handle()
623+
.unwrap_or_else(|| session.get_default_execution_profile_handle())
624+
.access();
625+
626+
let consistency = self
627+
.config
628+
.consistency
629+
.unwrap_or(execution_profile.consistency);
630+
631+
let serial_consistency = self
632+
.config
633+
.serial_consistency
634+
.unwrap_or(execution_profile.serial_consistency);
635+
636+
let (table, token) = self
637+
.first_prepared
638+
.as_ref()
639+
.and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
640+
.unzip();
641+
642+
let statement_info = RoutingInfo {
643+
consistency,
644+
serial_consistency,
645+
token,
646+
table,
647+
is_confirmed_lwt: false,
648+
};
649+
650+
let span = RequestSpan::new_batch();
651+
652+
let (run_request_result, coordinator): (
653+
RunRequestResult<NonErrorQueryResponse>,
654+
Coordinator,
655+
) = session
656+
.run_request(
657+
statement_info,
658+
&self.config,
659+
execution_profile,
660+
|connection: Arc<Connection>,
661+
consistency: Consistency,
662+
execution_profile: &ExecutionProfileInner| {
663+
let serial_consistency = self
664+
.config
665+
.serial_consistency
666+
.unwrap_or(execution_profile.serial_consistency);
667+
async move {
668+
connection
669+
.batch_with_consistency(self, consistency, serial_consistency)
670+
.await
671+
.and_then(QueryResponse::into_non_error_query_response)
672+
}
673+
},
674+
&span,
675+
)
676+
.instrument(span.span().clone())
677+
.await?;
678+
679+
let result = match run_request_result {
680+
RunRequestResult::IgnoredWriteError => QueryResult::mock_empty(coordinator),
681+
RunRequestResult::Completed(non_error_query_response) => {
682+
let result = non_error_query_response.into_query_result(coordinator)?;
683+
span.record_result_fields(&result);
684+
result
685+
}
686+
};
687+
688+
Ok(result)
689+
}
690+
}

0 commit comments

Comments
 (0)