Skip to content

Commit af88283

Browse files
committed
Allow adding statements to a boundbatch and executing it
1 parent 8979702 commit af88283

File tree

8 files changed

+653
-365
lines changed

8 files changed

+653
-365
lines changed

scylla/src/client/session.rs

Lines changed: 21 additions & 351 deletions
Large diffs are not rendered by default.

scylla/src/network/connection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,7 @@ impl Connection {
10381038
consistency,
10391039
serial_consistency,
10401040
timestamp,
1041-
statements_len: batch.statements_len,
1041+
statements_len: batch.statements_len(),
10421042
};
10431043

10441044
loop {

scylla/src/statement/batch.rs

Lines changed: 243 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,28 @@ use scylla_cql::frame::frame_errors::{
99
};
1010
use scylla_cql::frame::request;
1111
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
12-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
12+
use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
1313
use scylla_cql::serialize::{RowWriter, SerializationError};
14+
use thiserror::Error;
15+
use tracing::Instrument;
1416

15-
use crate::client::execution_profile::ExecutionProfileHandle;
17+
use crate::client::execution_profile::{ExecutionProfileHandle, ExecutionProfileInner};
18+
use crate::client::session::{RunRequestResult, Session};
1619
use crate::errors::{BadQuery, ExecutionError, RequestAttemptError};
20+
use crate::network::Connection;
21+
use crate::observability::driver_tracing::RequestSpan;
1722
use crate::observability::history::HistoryListener;
1823
use crate::policies::load_balancing::LoadBalancingPolicy;
24+
use crate::policies::load_balancing::RoutingInfo;
1925
use crate::policies::retry::RetryPolicy;
26+
use crate::response::query_result::QueryResult;
27+
use crate::response::{Coordinator, NonErrorQueryResponse, QueryResponse};
2028
use crate::routing::Token;
2129
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
2230
use crate::statement::unprepared::Statement;
2331

2432
use super::bound::BoundStatement;
33+
use super::execute::Execute;
2534
use super::StatementConfig;
2635
use super::{Consistency, SerialConsistency};
2736
pub use crate::frame::request::batch::BatchType;
@@ -282,8 +291,8 @@ pub struct BoundBatch {
282291
batch_type: BatchType,
283292
pub(crate) buffer: Vec<u8>,
284293
pub(crate) prepared: HashMap<Bytes, PreparedStatement>,
285-
pub(crate) first_prepared: Option<(PreparedStatement, Token)>,
286-
pub(crate) statements_len: u16,
294+
first_prepared: Option<(PreparedStatement, Token)>,
295+
statements_len: u16,
287296
}
288297

289298
impl BoundBatch {
@@ -294,6 +303,20 @@ impl BoundBatch {
294303
}
295304
}
296305

306+
/// Appends a new statement to the batch.
307+
pub fn append_statement<'p, V: SerializeRow>(
308+
&mut self,
309+
statement: impl Into<BoundBatchStatement<'p, V>>,
310+
) -> Result<(), BoundBatchStatementError> {
311+
let initial_len = self.buffer.len();
312+
self.raw_append_statement(statement).inspect_err(|_| {
313+
// if we error'd at any point we should put the buffer back to its old length to not
314+
// corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and
315+
// tries with a successful statement later
316+
self.buffer.truncate(initial_len);
317+
})
318+
}
319+
297320
#[allow(clippy::result_large_err)]
298321
pub(crate) fn from_batch(
299322
batch: &Batch,
@@ -403,6 +426,96 @@ impl BoundBatch {
403426
self.batch_type
404427
}
405428

429+
pub fn statements_len(&self) -> u16 {
430+
self.statements_len
431+
}
432+
433+
// **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors
434+
// because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in
435+
// `self` to be modified if an error occured because the caller will not reset them.
436+
fn raw_append_statement<'p, V: SerializeRow>(
437+
&mut self,
438+
statement: impl Into<BoundBatchStatement<'p, V>>,
439+
) -> Result<(), BoundBatchStatementError> {
440+
let mut statement = statement.into();
441+
let mut first_prepared = None;
442+
443+
if self.statements_len == 0 {
444+
// save it into a local variable for now in case a latter steps fails
445+
first_prepared = match statement {
446+
BoundBatchStatement::Bound(ref b) => b
447+
.token()?
448+
.map(|token| (b.prepared.clone().into_owned(), token)),
449+
BoundBatchStatement::Prepared(ps, values) => {
450+
let bound = ps
451+
.into_bind(&values)
452+
.map_err(BatchStatementSerializationError::ValuesSerialiation)?;
453+
let first_prepared = bound
454+
.token()?
455+
.map(|token| (bound.prepared.clone().into_owned(), token));
456+
// we already serialized it so to avoid re-serializing it, modify the statement to a
457+
// BoundStatement
458+
statement = BoundBatchStatement::Bound(bound);
459+
first_prepared
460+
}
461+
BoundBatchStatement::Query(_) => None,
462+
};
463+
}
464+
465+
let stmnt = match &statement {
466+
BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared {
467+
id: Cow::Borrowed(ps.get_id()),
468+
},
469+
BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared {
470+
id: Cow::Borrowed(b.prepared.get_id()),
471+
},
472+
BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query {
473+
text: Cow::Borrowed(&q.contents),
474+
},
475+
};
476+
477+
serialize_statement(stmnt, &mut self.buffer, |writer| match &statement {
478+
BoundBatchStatement::Prepared(ps, values) => {
479+
let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata());
480+
values.serialize(&ctx, writer).map(Some)
481+
}
482+
BoundBatchStatement::Bound(b) => {
483+
writer.append_serialize_row(&b.values);
484+
Ok(Some(()))
485+
}
486+
// query has no values
487+
BoundBatchStatement::Query(_) => Ok(Some(())),
488+
})?;
489+
490+
let new_statements_len = self
491+
.statements_len
492+
.checked_add(1)
493+
.ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?;
494+
495+
/*** at this point nothing else should be fallible as we are going to be modifying
496+
* fields that do not get reset ***/
497+
498+
self.statements_len = new_statements_len;
499+
500+
if let Some(first_prepared) = first_prepared {
501+
self.first_prepared = Some(first_prepared);
502+
}
503+
504+
let prepared = match statement {
505+
BoundBatchStatement::Prepared(ps, _) => Cow::Owned(ps),
506+
BoundBatchStatement::Bound(b) => b.prepared,
507+
BoundBatchStatement::Query(_) => return Ok(()),
508+
};
509+
510+
if !self.prepared.contains_key(prepared.get_id()) {
511+
self.prepared
512+
.insert(prepared.get_id().to_owned(), prepared.into_owned());
513+
}
514+
515+
Ok(())
516+
}
517+
518+
#[allow(clippy::result_large_err)]
406519
fn serialize_from_batch_statement<T>(
407520
&mut self,
408521
statement: &BatchStatement,
@@ -477,3 +590,129 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ
477590
n_statements: n_statements as usize,
478591
}
479592
}
593+
594+
/// A CQL statement, together with its values (if any), that can be appended to a
/// [`BoundBatch`].
///
/// `'p` is the lifetime parameter of the [`BoundStatement`] carried by the `Bound` variant;
/// `V` is the type of the not-yet-serialized values carried by the `Prepared` variant.
595+
#[derive(Clone)]
596+
#[non_exhaustive]
597+
pub enum BoundBatchStatement<'p, V: SerializeRow> {
598+
/// A prepared statement and its not-yet-serialized values.
599+
Prepared(PreparedStatement, V),
600+
/// A statement whose values have already been bound (and thus serialized).
601+
Bound(BoundStatement<'p>),
602+
/// An unprepared statement with no values.
603+
Query(Statement),
604+
}
605+
606+
/// Wraps an already-bound statement; no unserialized values are involved, hence `V = ()`.
impl<'p> From<BoundStatement<'p>> for BoundBatchStatement<'p, ()> {
    fn from(bound: BoundStatement<'p>) -> Self {
        Self::Bound(bound)
    }
}
611+
612+
/// Pairs a prepared statement with the values that are yet to be serialized for it.
impl<V: SerializeRow> From<(PreparedStatement, V)> for BoundBatchStatement<'static, V> {
    fn from(pair: (PreparedStatement, V)) -> Self {
        let (prepared, values) = pair;
        Self::Prepared(prepared, values)
    }
}
617+
618+
/// Wraps an unprepared statement, which carries no values.
impl From<Statement> for BoundBatchStatement<'static, ()> {
    fn from(unprepared: Statement) -> Self {
        Self::Query(unprepared)
    }
}
623+
624+
impl From<&str> for BoundBatchStatement<'static, ()> {
625+
fn from(s: &str) -> Self {
626+
BoundBatchStatement::Query(Statement::from(s))
627+
}
628+
}
629+
630+
/// An error type returned when adding a statement to a bounded batch fails
631+
#[non_exhaustive]
632+
#[derive(Error, Debug, Clone)]
633+
pub enum BoundBatchStatementError {
634+
/// Failed to serialize the batch statement
635+
#[error(transparent)]
636+
Statement(#[from] BatchStatementSerializationError),
637+
/// Failed to serialize statement's bound values.
638+
#[error("Failed to calculate partition key")]
639+
PartitionKey(#[from] PartitionKeyError),
640+
/// Too many statements in the batch statement.
641+
#[error("Added statement goes over exceeded max value of 65,535")]
642+
TooManyQueriesInBatchStatement,
643+
}
644+
645+
impl Execute for BoundBatch {
646+
async fn execute(&self, session: &Session) -> Result<QueryResult, ExecutionError> {
647+
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
648+
// If users batch statements by shard, they will be rewarded with full shard awareness
649+
let execution_profile = self
650+
.get_execution_profile_handle()
651+
.unwrap_or_else(|| session.get_default_execution_profile_handle())
652+
.access();
653+
654+
let consistency = self
655+
.config
656+
.consistency
657+
.unwrap_or(execution_profile.consistency);
658+
659+
let serial_consistency = self
660+
.config
661+
.serial_consistency
662+
.unwrap_or(execution_profile.serial_consistency);
663+
664+
let (table, token) = self
665+
.first_prepared
666+
.as_ref()
667+
.and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
668+
.unzip();
669+
670+
let statement_info = RoutingInfo {
671+
consistency,
672+
serial_consistency,
673+
token,
674+
table,
675+
is_confirmed_lwt: false,
676+
};
677+
678+
let span = RequestSpan::new_batch();
679+
680+
let (run_request_result, coordinator): (
681+
RunRequestResult<NonErrorQueryResponse>,
682+
Coordinator,
683+
) = session
684+
.run_request(
685+
statement_info,
686+
&self.config,
687+
execution_profile,
688+
|connection: Arc<Connection>,
689+
consistency: Consistency,
690+
execution_profile: &ExecutionProfileInner| {
691+
let serial_consistency = self
692+
.config
693+
.serial_consistency
694+
.unwrap_or(execution_profile.serial_consistency);
695+
async move {
696+
connection
697+
.batch_with_consistency(self, consistency, serial_consistency)
698+
.await
699+
.and_then(QueryResponse::into_non_error_query_response)
700+
}
701+
},
702+
&span,
703+
)
704+
.instrument(span.span().clone())
705+
.await?;
706+
707+
let result = match run_request_result {
708+
RunRequestResult::IgnoredWriteError => QueryResult::mock_empty(coordinator),
709+
RunRequestResult::Completed(non_error_query_response) => {
710+
let result = non_error_query_response.into_query_result(coordinator)?;
711+
span.record_result_fields(&result);
712+
result
713+
}
714+
};
715+
716+
Ok(result)
717+
}
718+
}

0 commit comments

Comments
 (0)