From 62aa9611b107d0ed73b06b3479c780bbc5adb7e2 Mon Sep 17 00:00:00 2001 From: Andres Medina Date: Thu, 3 Apr 2025 14:48:27 -0700 Subject: [PATCH 1/6] Add an internal only `BoundStatement` this struct keeps track of a PreparedStatement and SerializedValues --- scylla/src/client/pager.rs | 44 ++++++------ scylla/src/client/session.rs | 85 +++++++++++------------- scylla/src/cluster/control_connection.rs | 10 +-- scylla/src/cluster/metadata.rs | 5 +- scylla/src/network/connection.rs | 73 +++++++++----------- scylla/src/statement/bound.rs | 61 +++++++++++++++++ scylla/src/statement/mod.rs | 1 + scylla/src/statement/prepared.rs | 60 ++++++++--------- 8 files changed, 186 insertions(+), 153 deletions(-) create mode 100644 scylla/src/statement/bound.rs diff --git a/scylla/src/client/pager.rs b/scylla/src/client/pager.rs index 514b8a2fe6..5ee0c0d8c7 100644 --- a/scylla/src/client/pager.rs +++ b/scylla/src/client/pager.rs @@ -18,7 +18,6 @@ use scylla_cql::frame::request::query::PagingState; use scylla_cql::frame::response::result::RawMetadataAndRawRows; use scylla_cql::frame::response::NonErrorResponse; use scylla_cql::frame::types::SerialConsistency; -use scylla_cql::serialize::row::SerializedValues; use scylla_cql::Consistency; use std::result::Result; use thiserror::Error; @@ -38,7 +37,8 @@ use crate::policies::load_balancing::{self, LoadBalancingPolicy, RoutingInfo}; use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession}; use crate::response::query_result::ColumnSpecs; use crate::response::{NonErrorQueryResponse, QueryResponse}; -use crate::statement::prepared::{PartitionKeyError, PreparedStatement}; +use crate::statement::bound::BoundStatement; +use crate::statement::prepared::PartitionKeyError; use crate::statement::unprepared::Statement; use tracing::{trace, trace_span, warn, Instrument}; use uuid::Uuid; @@ -64,8 +64,7 @@ struct ReceivedPage { } pub(crate) struct PreparedPagerConfig { - pub(crate) prepared: PreparedStatement, - pub(crate) values: SerializedValues, + pub(crate) bound: BoundStatement<'static>, pub(crate) execution_profile: Arc, pub(crate) cluster_state: Arc, #[cfg(feature = "metrics")] @@ -803,26 +802,30 @@ impl QueryPager { let (sender, receiver) = mpsc::channel::>(1); let consistency = config + .bound .prepared .config .consistency .unwrap_or(config.execution_profile.consistency); let serial_consistency = config + .bound .prepared .config .serial_consistency .unwrap_or(config.execution_profile.serial_consistency); - let page_size = config.prepared.get_validated_page_size(); + let page_size = config.bound.prepared.get_validated_page_size(); let load_balancing_policy = Arc::clone( config + .bound .prepared .get_load_balancing_policy() .unwrap_or(&config.execution_profile.load_balancing_policy), ); let retry_session = config + .bound .prepared .get_retry_policy() .map(|rp| &**rp) @@ -831,14 +834,7 @@ impl QueryPager { let parent_span = tracing::Span::current(); let worker_task = async move { - let prepared_ref = &config.prepared; - let values_ref = &config.values; - - let (partition_key, token) = match prepared_ref - .extract_partition_key_and_calculate_token( - prepared_ref.get_partitioner_name(), - values_ref, - ) { + let (partition_key, token) = match config.bound.pk_and_token() { Ok(res) => res.unzip(), Err(err) => { let (proof, _res) = ProvingSender::from(sender) @@ -848,22 +844,22 @@ impl QueryPager { } }; - let table_spec = config.prepared.get_table_spec(); + let table_spec = config.bound.prepared.get_table_spec(); let statement_info = RoutingInfo { consistency, serial_consistency, token, table: table_spec, - is_confirmed_lwt: config.prepared.is_confirmed_lwt(), + is_confirmed_lwt: config.bound.prepared.is_confirmed_lwt(), }; + let statement = &config.bound; let page_query = |connection: Arc, consistency: Consistency, paging_state: PagingState| async move { connection .execute_raw_with_consistency( - prepared_ref, - values_ref, + statement, consistency, serial_consistency, Some(page_size), @@ -872,7 +868,7 @@ impl QueryPager { .await }; - let serialized_values_size = config.values.buffer_size(); + let serialized_values_size = config.bound.values.buffer_size(); let replicas: Option> = if let (Some(table_spec), Some(token)) = @@ -905,14 +901,14 @@ impl QueryPager { sender: sender.into(), page_query, statement_info, - query_is_idempotent: config.prepared.config.is_idempotent, + query_is_idempotent: config.bound.prepared.config.is_idempotent, query_consistency: consistency, load_balancing_policy, retry_session, #[cfg(feature = "metrics")] metrics: config.metrics, paging_state: PagingState::start(), - history_listener: config.prepared.config.history_listener.clone(), + history_listener: config.bound.prepared.config.history_listener.clone(), current_request_id: None, current_attempt_id: None, parent_span, @@ -955,23 +951,21 @@ impl QueryPager { } pub(crate) async fn new_for_connection_execute_iter( - prepared: PreparedStatement, - values: SerializedValues, + bound: BoundStatement<'static>, connection: Arc, consistency: Consistency, serial_consistency: Option, ) -> Result { let (sender, receiver) = mpsc::channel::>(1); - let page_size = prepared.get_validated_page_size(); + let page_size = bound.prepared.get_validated_page_size(); let worker_task = async move { let worker = SingleConnectionPagerWorker { sender: sender.into(), fetcher: |paging_state| { connection.execute_raw_with_consistency( - &prepared, - &values, + &bound, consistency, serial_consistency, Some(page_size), diff --git a/scylla/src/client/session.rs b/scylla/src/client/session.rs index 9cbe7d97d9..bdea56f7d9 100644 --- a/scylla/src/client/session.rs +++ b/scylla/src/client/session.rs @@ -38,6 +38,7 @@ use crate::routing::partitioner::PartitionerName; use crate::routing::{Shard, ShardAwarePortRange}; use crate::statement::batch::batch_values; use crate::statement::batch::{Batch, BatchStatement}; +use crate::statement::bound::BoundStatement; use crate::statement::prepared::{PartitionKeyError, PreparedStatement}; use crate::statement::unprepared::Statement; use crate::statement::{Consistency, PageSize, StatementConfig}; @@ -47,7 +48,7 @@ use futures::future::try_join_all; use itertools::Itertools; use scylla_cql::frame::response::NonErrorResponse; use scylla_cql::serialize::batch::BatchValues; -use scylla_cql::serialize::row::{SerializeRow, SerializedValues}; +use scylla_cql::serialize::row::SerializeRow; use std::borrow::Borrow; use std::future::Future; use std::net::{IpAddr, SocketAddr}; @@ -640,7 +641,8 @@ impl Session { prepared: &PreparedStatement, values: impl SerializeRow, ) -> Result { - self.do_execute_unpaged(prepared, values).await + let bound = prepared.bind(&values)?; + self.do_execute_unpaged(&bound).await } /// Executes a prepared statement, restricting results to single page. @@ -705,8 +707,8 @@ impl Session { values: impl SerializeRow, paging_state: PagingState, ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { - self.do_execute_single_page(prepared, values, paging_state) - .await + let bound = prepared.bind(&values)?; + self.do_execute_single_page(&bound, paging_state).await } /// Execute a prepared statement with paging.\ @@ -753,7 +755,8 @@ impl Session { prepared: impl Into, values: impl SerializeRow, ) -> Result { - self.do_execute_iter(prepared.into(), values).await + let bound = prepared.into().into_bind(&values)?; + self.do_execute_iter(bound).await } /// Execute a batch statement\ @@ -1085,12 +1088,11 @@ impl Session { .and_then(QueryResponse::into_non_error_query_response) } else { let prepared = connection.prepare(statement).await?; - let serialized = prepared.serialize_values(values_ref)?; - span_ref.record_request_size(serialized.buffer_size()); + let bound = prepared.bind(values_ref)?; + span_ref.record_request_size(bound.values.buffer_size()); connection .execute_raw_with_consistency( - &prepared, - &serialized, + &bound, consistency, serial_consistency, page_size, @@ -1183,11 +1185,9 @@ impl Session { // Making QueryPager::new_for_query work with values is too hard (if even possible) // so instead of sending one prepare to a specific connection on each iterator query, // we fully prepare a statement beforehand. - let prepared = self.prepare_nongeneric(&statement).await?; - let values = prepared.serialize_values(&values)?; + let bound = self.prepare_nongeneric(&statement).await?.into_bind(&values)?; QueryPager::new_for_prepared_statement(PreparedPagerConfig { - prepared, - values, + bound, execution_profile, cluster_state: self.cluster.get_state(), #[cfg(feature = "metrics")] @@ -1356,13 +1356,9 @@ impl Session { async fn do_execute_unpaged( &self, - prepared: &PreparedStatement, - values: impl SerializeRow, + bound: &BoundStatement<'_>, ) -> Result { - let serialized_values = prepared.serialize_values(&values)?; - let (result, paging_state) = self - .execute(prepared, &serialized_values, None, PagingState::start()) - .await?; + let (result, paging_state) = self.execute(bound, None, PagingState::start()).await?; if !paging_state.finished() { error!("Unpaged prepared query returned a non-empty paging state! This is a driver-side or server-side bug."); return Err(ExecutionError::LastAttemptError( @@ -1374,14 +1370,11 @@ impl Session { async fn do_execute_single_page( &self, - prepared: &PreparedStatement, - values: impl SerializeRow, + bound: &BoundStatement<'_>, paging_state: PagingState, ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { - let serialized_values = prepared.serialize_values(&values)?; - let page_size = prepared.get_validated_page_size(); - self.execute(prepared, &serialized_values, Some(page_size), paging_state) - .await + let page_size = bound.prepared.get_validated_page_size(); + self.execute(bound, Some(page_size), paging_state).await } /// Sends a prepared request to the database, optionally continuing from a saved point. @@ -1396,44 +1389,45 @@ impl Session { /// should be made. async fn execute( &self, - prepared: &PreparedStatement, - serialized_values: &SerializedValues, + bound: &BoundStatement<'_>, page_size: Option, paging_state: PagingState, ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { - let values_ref = &serialized_values; let paging_state_ref = &paging_state; - let (partition_key, token) = prepared - .extract_partition_key_and_calculate_token(prepared.get_partitioner_name(), values_ref) + let (partition_key, token) = bound + .pk_and_token() .map_err(PartitionKeyError::into_execution_error)? .unzip(); - let execution_profile = prepared + let execution_profile = bound + .prepared .get_execution_profile_handle() .unwrap_or_else(|| self.get_default_execution_profile_handle()) .access(); - let table_spec = prepared.get_table_spec(); + let table_spec = bound.prepared.get_table_spec(); let statement_info = RoutingInfo { - consistency: prepared + consistency: bound + .prepared .config .consistency .unwrap_or(execution_profile.consistency), - serial_consistency: prepared + serial_consistency: bound + .prepared .config .serial_consistency .unwrap_or(execution_profile.serial_consistency), token, table: table_spec, - is_confirmed_lwt: prepared.is_confirmed_lwt(), + is_confirmed_lwt: bound.prepared.is_confirmed_lwt(), }; let span = RequestSpan::new_prepared( partition_key.as_ref().map(|pk| pk.iter()), token, - serialized_values.buffer_size(), + bound.values.buffer_size(), ); if !span.span().is_disabled() { @@ -1450,20 +1444,20 @@ impl Session { ) = self .run_request( statement_info, - &prepared.config, + &bound.prepared.config, execution_profile, |connection: Arc, consistency: Consistency, execution_profile: &ExecutionProfileInner| { - let serial_consistency = prepared + let serial_consistency = bound + .prepared .config .serial_consistency .unwrap_or(execution_profile.serial_consistency); async move { connection .execute_raw_with_consistency( - prepared, - values_ref, + bound, consistency, serial_consistency, page_size, @@ -1496,19 +1490,16 @@ impl Session { async fn do_execute_iter( &self, - prepared: PreparedStatement, - values: impl SerializeRow, + bound: BoundStatement<'static>, ) -> Result { - let serialized_values = prepared.serialize_values(&values)?; - - let execution_profile = prepared + let execution_profile = bound + .prepared .get_execution_profile_handle() .unwrap_or_else(|| self.get_default_execution_profile_handle()) .access(); QueryPager::new_for_prepared_statement(PreparedPagerConfig { - prepared, - values: serialized_values, + bound, execution_profile, cluster_state: self.cluster.get_state(), #[cfg(feature = "metrics")] diff --git a/scylla/src/cluster/control_connection.rs b/scylla/src/cluster/control_connection.rs index b4538d70e6..fd957b4cb7 100644 --- a/scylla/src/cluster/control_connection.rs +++ b/scylla/src/cluster/control_connection.rs @@ -6,11 +6,10 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use scylla_cql::serialize::row::SerializedValues; - use crate::client::pager::QueryPager; use crate::errors::{NextRowError, RequestAttemptError}; use crate::network::Connection; +use crate::statement::bound::BoundStatement; use crate::statement::prepared::PreparedStatement; use crate::statement::Statement; @@ -84,12 +83,9 @@ impl ControlConnection { /// the asynchronous iterator interface. pub(crate) async fn execute_iter( &self, - prepared_statement: PreparedStatement, - values: SerializedValues, + bound: BoundStatement<'static>, ) -> Result { - Arc::clone(&self.conn) - .execute_iter(prepared_statement, values) - .await + Arc::clone(&self.conn).execute_iter(bound).await } } diff --git a/scylla/src/cluster/metadata.rs b/scylla/src/cluster/metadata.rs index 73f2d2d058..bc00425b6d 100644 --- a/scylla/src/cluster/metadata.rs +++ b/scylla/src/cluster/metadata.rs @@ -938,9 +938,8 @@ impl ControlConnection { let mut query = Statement::new(query_str); query.set_page_size(METADATA_QUERY_PAGE_SIZE); - let prepared = conn.prepare(query).await?; - let serialized_values = prepared.serialize_values(&keyspaces)?; - conn.execute_iter(prepared, serialized_values) + let bound = conn.prepare(query).await?.into_bind(&keyspaces)?; + conn.execute_iter(bound) .await .map_err(MetadataFetchErrorKind::NextRowError) } diff --git a/scylla/src/network/connection.rs b/scylla/src/network/connection.rs index 2a396b6360..ee04e23ef9 100644 --- a/scylla/src/network/connection.rs +++ b/scylla/src/network/connection.rs @@ -28,6 +28,7 @@ use crate::response::{ use crate::routing::locator::tablets::{RawTablet, TabletParsingError}; use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError}; use crate::statement::batch::{Batch, BatchStatement}; +use crate::statement::bound::BoundStatement; use crate::statement::prepared::PreparedStatement; use crate::statement::unprepared::Statement; use crate::statement::{Consistency, PageSize}; @@ -880,19 +881,18 @@ impl Connection { } #[cfg(test)] - async fn execute_raw_unpaged( + pub(crate) async fn execute_raw_unpaged( &self, - prepared: &PreparedStatement, - values: SerializedValues, + bound: &BoundStatement<'_>, ) -> Result { // This method is used only for driver internal queries, so no need to consult execution profile here. self.execute_raw_with_consistency( - prepared, - &values, - prepared + bound, + bound + .prepared .config .determine_consistency(self.config.default_consistency), - prepared.config.serial_consistency.flatten(), + bound.prepared.config.serial_consistency.flatten(), None, PagingState::start(), ) @@ -901,8 +901,7 @@ impl Connection { pub(crate) async fn execute_raw_with_consistency( &self, - prepared_statement: &PreparedStatement, - values: &SerializedValues, + bound: &BoundStatement<'_>, consistency: Consistency, serial_consistency: Option, page_size: Option, @@ -914,37 +913,39 @@ impl Connection { .as_ref() .map(|gen| gen.next_timestamp()) }; - let timestamp = prepared_statement + let timestamp = bound + .prepared .get_timestamp() .or_else(get_timestamp_from_gen); let execute_frame = execute::Execute { - id: prepared_statement.get_id().to_owned(), + id: bound.prepared.get_id().to_owned(), parameters: query::QueryParameters { consistency, serial_consistency, - values: Cow::Borrowed(values), + values: Cow::Borrowed(&bound.values), page_size: page_size.map(Into::into), timestamp, - skip_metadata: prepared_statement.get_use_cached_result_metadata(), + skip_metadata: bound.prepared.get_use_cached_result_metadata(), paging_state, }, }; - let cached_metadata = prepared_statement + let cached_metadata = bound + .prepared .get_use_cached_result_metadata() - .then(|| prepared_statement.get_result_metadata()); + .then(|| bound.prepared.get_result_metadata()); let query_response = self .send_request( &execute_frame, true, - prepared_statement.config.tracing, + bound.prepared.config.tracing, cached_metadata, ) .await?; - if let Some(spec) = prepared_statement.get_table_spec() { + if let Some(spec) = bound.prepared.get_table_spec() { if let Err(e) = self .update_tablets_from_response(spec, &query_response) .await @@ -960,18 +961,18 @@ impl Connection { }) => { debug!("Connection::execute: Got DbError::Unprepared - repreparing statement with id {:?}", statement_id); // Repreparation of a statement is needed - self.reprepare(prepared_statement.get_statement(), prepared_statement) + self.reprepare(bound.prepared.get_statement(), &bound.prepared) .await?; let new_response = self .send_request( &execute_frame, true, - prepared_statement.config.tracing, + bound.prepared.config.tracing, cached_metadata, ) .await?; - if let Some(spec) = prepared_statement.get_table_spec() { + if let Some(spec) = bound.prepared.get_table_spec() { if let Err(e) = self.update_tablets_from_response(spec, &new_response).await { tracing::warn!( "Error while parsing tablet info from custom payload: {}", @@ -1006,23 +1007,17 @@ impl Connection { /// the asynchronous iterator interface. pub(crate) async fn execute_iter( self: Arc, - prepared_statement: PreparedStatement, - values: SerializedValues, + bound: BoundStatement<'static>, ) -> Result { - let consistency = prepared_statement + let consistency = bound + .prepared .config .determine_consistency(self.config.default_consistency); - let serial_consistency = prepared_statement.config.serial_consistency.flatten(); + let serial_consistency = bound.prepared.config.serial_consistency.flatten(); - QueryPager::new_for_connection_execute_iter( - prepared_statement, - values, - self, - consistency, - serial_consistency, - ) - .await - .map_err(NextRowError::NextPageError) + QueryPager::new_for_connection_execute_iter(bound, self, consistency, serial_consistency) + .await + .map_err(NextRowError::NextPageError) } pub(crate) async fn batch_with_consistency( @@ -2324,8 +2319,9 @@ mod tests { let prepared = connection.prepare(&insert_query).await.unwrap(); let mut insert_futures = Vec::new(); for v in &values { - let values = prepared.serialize_values(&(*v,)).unwrap(); - let fut = async { connection.execute_raw_unpaged(&prepared, values).await }; + let bound = prepared.bind(&(*v,)).unwrap(); + let connection = &connection; + let fut = async move { connection.execute_raw_unpaged(&bound).await }; insert_futures.push(fut); } @@ -2429,11 +2425,8 @@ mod tests { let conn = conn.clone(); async move { let prepared = conn.prepare(&q).await.unwrap(); - let values = prepared - .serialize_values(&(j, vec![j as u8; j as usize])) - .unwrap(); - let response = - conn.execute_raw_unpaged(&prepared, values).await.unwrap(); + let bound = prepared.bind(&(j, vec![j as u8; j as usize])).unwrap(); + let response = conn.execute_raw_unpaged(&bound).await.unwrap(); // QueryResponse might contain an error - make sure that there were no errors let _nonerror_response = response.into_non_error_query_response().unwrap(); diff --git a/scylla/src/statement/bound.rs b/scylla/src/statement/bound.rs new file mode 100644 index 0000000000..3d74ec9ba0 --- /dev/null +++ b/scylla/src/statement/bound.rs @@ -0,0 +1,61 @@ +use std::borrow::Cow; + +use scylla_cql::serialize::{ + row::{SerializeRow, SerializedValues}, + SerializationError, +}; + +use crate::routing::Token; + +use super::prepared::{ + PartitionKey, PartitionKeyError, PartitionKeyExtractionError, PreparedStatement, +}; + +/// Represents a statement that already had all its values bound +#[derive(Debug, Clone)] +pub struct BoundStatement<'p> { + pub(crate) prepared: Cow<'p, PreparedStatement>, + pub(crate) values: SerializedValues, +} + +impl<'p> BoundStatement<'p> { + pub(crate) fn new( + prepared: Cow<'p, PreparedStatement>, + values: &impl SerializeRow, + ) -> Result { + let values = prepared.serialize_values(values)?; + Ok(Self { prepared, values }) + } + + /// Determines which values constitute the partition key and puts them in order. + /// + /// This is a preparation step necessary for calculating token based on a prepared statement. + pub(crate) fn pk(&self) -> Result, PartitionKeyExtractionError> { + PartitionKey::new(self.prepared.get_prepared_metadata(), &self.values) + } + + pub(crate) fn pk_and_token( + &self, + ) -> Result, Token)>, PartitionKeyError> { + if !self.prepared.is_token_aware() { + return Ok(None); + } + + let partition_key = self.pk()?; + let token = partition_key.calculate_token(self.prepared.get_partitioner_name())?; + Ok(Some((partition_key, token))) + } + + /// Calculates the token for the prepared statement and its bound values + /// + /// Returns the token that would be computed for executing the provided prepared statement with + /// the provided values. + pub fn token(&self) -> Result, PartitionKeyError> { + self.pk_and_token().map(|p| p.map(|(_, t)| t)) + } + + /// Returns the prepared statement behind the `BoundStatement` + pub fn prepared(&self) -> &PreparedStatement { + &self.prepared + } +} diff --git a/scylla/src/statement/mod.rs b/scylla/src/statement/mod.rs index e98614ec4d..a03c400f4d 100644 --- a/scylla/src/statement/mod.rs +++ b/scylla/src/statement/mod.rs @@ -15,6 +15,7 @@ use crate::policies::load_balancing::LoadBalancingPolicy; use crate::policies::retry::RetryPolicy; pub mod batch; +pub mod bound; pub mod prepared; pub mod unprepared; diff --git a/scylla/src/statement/prepared.rs b/scylla/src/statement/prepared.rs index 9c02902ad4..d075a74298 100644 --- a/scylla/src/statement/prepared.rs +++ b/scylla/src/statement/prepared.rs @@ -6,12 +6,14 @@ use scylla_cql::frame::types::RawValue; use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues}; use scylla_cql::serialize::SerializationError; use smallvec::{smallvec, SmallVec}; +use std::borrow::Cow; use std::convert::TryInto; use std::sync::Arc; use std::time::Duration; use thiserror::Error; use uuid::Uuid; +use super::bound::BoundStatement; use super::{PageSize, StatementConfig}; use crate::client::execution_profile::ExecutionProfileHandle; use crate::errors::{BadQuery, ExecutionError}; @@ -210,8 +212,8 @@ impl PreparedStatement { &self, bound_values: &impl SerializeRow, ) -> Result { - let serialized = self.serialize_values(bound_values)?; - let partition_key = self.extract_partition_key(&serialized)?; + let bound = self.bind(bound_values)?; + let partition_key = bound.pk()?; let mut buf = BytesMut::new(); let mut writer = |chunk: &[u8]| buf.extend_from_slice(chunk); @@ -220,31 +222,6 @@ impl PreparedStatement { Ok(buf.freeze()) } - /// Determines which values constitute the partition key and puts them in order. - /// - /// This is a preparation step necessary for calculating token based on a prepared statement. - pub(crate) fn extract_partition_key<'ps>( - &'ps self, - bound_values: &'ps SerializedValues, - ) -> Result, PartitionKeyExtractionError> { - PartitionKey::new(self.get_prepared_metadata(), bound_values) - } - - pub(crate) fn extract_partition_key_and_calculate_token<'ps>( - &'ps self, - partitioner_name: &'ps PartitionerName, - serialized_values: &'ps SerializedValues, - ) -> Result, Token)>, PartitionKeyError> { - if !self.is_token_aware() { - return Ok(None); - } - - let partition_key = self.extract_partition_key(serialized_values)?; - let token = partition_key.calculate_token(partitioner_name)?; - - Ok(Some((partition_key, token))) - } - /// Calculates the token for given prepared statement and values. /// /// Returns the token that would be computed for executing the provided @@ -256,7 +233,8 @@ impl PreparedStatement { &self, values: &impl SerializeRow, ) -> Result, PartitionKeyError> { - self.calculate_token_untyped(&self.serialize_values(values)?) + let bound = self.bind(values)?; + bound.token() } // A version of calculate_token which skips serialization and uses SerializedValues directly. @@ -265,8 +243,14 @@ impl PreparedStatement { &self, values: &SerializedValues, ) -> Result, PartitionKeyError> { - self.extract_partition_key_and_calculate_token(&self.partitioner_name, values) - .map(|opt| opt.map(|(_pk, token)| token)) + if !self.is_token_aware() { + return Ok(None); + } + + let partition_key = PartitionKey::new(self.get_prepared_metadata(), values)?; + let token = partition_key.calculate_token(&self.partitioner_name)?; + + Ok(Some(token)) } /// Return keyspace name and table name this statement is operating on. @@ -494,6 +478,20 @@ impl PreparedStatement { self.config.execution_profile_handle.as_ref() } + pub(crate) fn bind( + &self, + values: &impl SerializeRow, + ) -> Result, SerializationError> { + BoundStatement::new(Cow::Borrowed(self), values) + } + + pub(crate) fn into_bind( + self, + values: &impl SerializeRow, + ) -> Result, SerializationError> { + BoundStatement::new(Cow::Owned(self), values) + } + pub(crate) fn serialize_values( &self, values: &impl SerializeRow, @@ -562,7 +560,7 @@ pub(crate) struct PartitionKey<'ps> { impl<'ps> PartitionKey<'ps> { const SMALLVEC_ON_STACK_SIZE: usize = 8; - fn new( + pub(crate) fn new( prepared_metadata: &'ps PreparedMetadata, bound_values: &'ps SerializedValues, ) -> Result { From dd7d532926108cb0da3fbf2ffb7c57be1967709d Mon Sep 17 00:00:00 2001 From: Andres Medina Date: Wed, 9 Apr 2025 15:39:03 -0700 Subject: [PATCH 2/6] Add an internal only `BoundBatch` --- scylla-cql/src/frame/request/batch.rs | 48 ++++- scylla/src/client/session.rs | 112 ++++++---- scylla/src/network/connection.rs | 85 +------- scylla/src/statement/batch.rs | 298 +++++++++++++++----------- scylla/src/statement/bound.rs | 7 + scylla/src/statement/prepared.rs | 16 -- 6 files changed, 311 insertions(+), 255 deletions(-) diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs index 32788a341b..88556ac6e3 100644 --- a/scylla-cql/src/frame/request/batch.rs +++ b/scylla-cql/src/frame/request/batch.rs @@ -35,6 +35,52 @@ where pub values: Values, } +#[cfg_attr(test, derive(Debug, PartialEq, Eq))] +pub struct BatchV2<'b> { + pub statements_and_values: Cow<'b, [u8]>, + pub batch_type: BatchType, + pub consistency: types::Consistency, + pub serial_consistency: Option, + pub timestamp: Option, + pub statements_len: u16, +} + +impl SerializableRequest for BatchV2<'_> { + const OPCODE: RequestOpcode = RequestOpcode::Batch; + + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { + // Serializing type of batch + buf.put_u8(self.batch_type as u8); + + // Serializing queries + types::write_short(self.statements_len, buf); + buf.extend_from_slice(&self.statements_and_values); + + // Serializing consistency + types::write_consistency(self.consistency, buf); + + // Serializing flags + let mut flags = 0; + if self.serial_consistency.is_some() { + flags |= FLAG_WITH_SERIAL_CONSISTENCY; + } + if self.timestamp.is_some() { + flags |= FLAG_WITH_DEFAULT_TIMESTAMP; + } + + buf.put_u8(flags); + + if let Some(serial_consistency) = self.serial_consistency { + types::write_serial_consistency(serial_consistency, buf); + } + if let Some(timestamp) = self.timestamp { + types::write_long(timestamp, buf); + } + + Ok(()) + } +} + /// The type of a batch. #[derive(Clone, Copy)] #[cfg_attr(test, derive(Debug, PartialEq, Eq))] @@ -208,7 +254,7 @@ impl BatchStatement<'_> { } impl BatchStatement<'_> { - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> { + pub fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> { match self { Self::Query { text } => { buf.put_u8(0); diff --git a/scylla/src/client/session.rs b/scylla/src/client/session.rs index bdea56f7d9..1d1c5a90d1 100644 --- a/scylla/src/client/session.rs +++ b/scylla/src/client/session.rs @@ -12,9 +12,9 @@ use crate::cluster::node::CloudEndpoint; use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef}; use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState}; use crate::errors::{ - BadQuery, BrokenConnectionError, ExecutionError, MetadataError, NewSessionError, - PagerExecutionError, PrepareError, RequestAttemptError, RequestError, SchemaAgreementError, - TracingError, UseKeyspaceError, + BrokenConnectionError, ExecutionError, MetadataError, NewSessionError, PagerExecutionError, + PrepareError, RequestAttemptError, RequestError, SchemaAgreementError, TracingError, + UseKeyspaceError, }; use crate::frame::response::result; use crate::network::tls::TlsProvider; @@ -36,7 +36,7 @@ use crate::response::{ }; use crate::routing::partitioner::PartitionerName; use crate::routing::{Shard, ShardAwarePortRange}; -use crate::statement::batch::batch_values; +use crate::statement::batch::BoundBatch; use crate::statement::batch::{Batch, BatchStatement}; use crate::statement::bound::BoundStatement; use crate::statement::prepared::{PartitionKeyError, PreparedStatement}; @@ -47,9 +47,10 @@ use futures::future::join_all; use futures::future::try_join_all; use itertools::Itertools; use scylla_cql::frame::response::NonErrorResponse; -use scylla_cql::serialize::batch::BatchValues; +use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator}; use scylla_cql::serialize::row::SerializeRow; -use std::borrow::Borrow; +use std::borrow::{Borrow, Cow}; +use std::collections::{HashMap, HashSet}; use std::future::Future; use std::net::{IpAddr, SocketAddr}; use std::num::NonZeroU32; @@ -809,7 +810,10 @@ impl Session { batch: &Batch, values: impl BatchValues, ) -> Result { - self.do_batch(batch, values).await + let batch = self.last_minute_prepare_batch(batch, &values).await?; + let batch = BoundBatch::from_batch(batch.as_ref(), values)?; + + self.do_batch(&batch).await } /// Estabilishes a CQL session with the database @@ -1185,7 +1189,10 @@ impl Session { // Making QueryPager::new_for_query work with values is too hard (if even possible) // so instead of sending one prepare to a specific connection on each iterator query, // we fully prepare a statement beforehand. - let bound = self.prepare_nongeneric(&statement).await?.into_bind(&values)?; + let bound = self + .prepare_nongeneric(&statement) + .await? + .into_bind(&values)?; QueryPager::new_for_prepared_statement(PreparedPagerConfig { bound, execution_profile, @@ -1509,22 +1516,9 @@ impl Session { .map_err(PagerExecutionError::NextPageError) } - async fn do_batch( - &self, - batch: &Batch, - values: impl BatchValues, - ) -> Result { + async fn do_batch(&self, batch: &BoundBatch) -> Result { // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard // If users batch statements by shard, they will be rewarded with full shard awareness - - // check to ensure that we don't send a batch statement with more than u16::MAX queries - let batch_statements_length = batch.statements.len(); - if batch_statements_length > u16::MAX as usize { - return Err(ExecutionError::BadQuery( - BadQuery::TooManyQueriesInBatchStatement(batch_statements_length), - )); - } - let execution_profile = batch .get_execution_profile_handle() .unwrap_or_else(|| self.get_default_execution_profile_handle()) @@ -1540,22 +1534,17 @@ impl Session { .serial_consistency .unwrap_or(execution_profile.serial_consistency); - let (first_value_token, values) = - batch_values::peek_first_token(values, batch.statements.first())?; - let values_ref = &values; - - let table_spec = - if let Some(BatchStatement::PreparedStatement(ps)) = batch.statements.first() { - ps.get_table_spec() - } else { - None - }; + let (table, token) = batch + .first_prepared + .as_ref() + .and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token))) + .unzip(); let statement_info = RoutingInfo { consistency, serial_consistency, - token: first_value_token, - table: table_spec, + token, + table, is_confirmed_lwt: false, }; @@ -1578,12 +1567,7 @@ impl Session { .unwrap_or(execution_profile.serial_consistency); async move { connection - .batch_with_consistency( - batch, - values_ref, - consistency, - serial_consistency, - ) + .batch_with_consistency(batch, consistency, serial_consistency) .await .and_then(QueryResponse::into_non_error_query_response) } @@ -1652,6 +1636,54 @@ impl Session { Ok(prepared_batch) } + async fn last_minute_prepare_batch<'b>( + &self, + init_batch: &'b Batch, + values: impl BatchValues, + ) -> Result, PrepareError> { + let mut to_prepare = HashSet::<&str>::new(); + + { + let mut values_iter = values.batch_values_iter(); + for stmt in &init_batch.statements { + if let BatchStatement::Query(query) = stmt { + if let Some(false) = values_iter.is_empty_next() { + to_prepare.insert(&query.contents); + } + } else { + values_iter.skip_next(); + } + } + } + + if to_prepare.is_empty() { + return Ok(Cow::Borrowed(init_batch)); + } + + let mut prepared_queries = HashMap::<&str, PreparedStatement>::new(); + + for query in to_prepare { + let prepared = self.prepare(query).await?; + prepared_queries.insert(query, prepared); + } + + let mut batch: Cow = Cow::Owned(Batch::new_from(init_batch)); + for stmt in &init_batch.statements { + match stmt { + BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str()) + { + Some(prepared) => batch.to_mut().append_statement(prepared.clone()), + None => batch.to_mut().append_statement(query.clone()), + }, + BatchStatement::PreparedStatement(prepared) => { + batch.to_mut().append_statement(prepared.clone()); + } + } + } + + Ok(batch) + } + /// Sends `USE ` request on all connections\ /// This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\ /// diff --git a/scylla/src/network/connection.rs b/scylla/src/network/connection.rs index ee04e23ef9..8043fa6a81 100644 --- a/scylla/src/network/connection.rs +++ b/scylla/src/network/connection.rs @@ -27,7 +27,7 @@ use crate::response::{ }; use crate::routing::locator::tablets::{RawTablet, TabletParsingError}; use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError}; -use crate::statement::batch::{Batch, BatchStatement}; +use crate::statement::batch::BoundBatch; use crate::statement::bound::BoundStatement; use crate::statement::prepared::PreparedStatement; use crate::statement::unprepared::Statement; @@ -42,12 +42,10 @@ use scylla_cql::frame::response::result::{ResultMetadata, TableSpec}; use scylla_cql::frame::response::Error; use scylla_cql::frame::response::{self, error}; use scylla_cql::frame::types::SerialConsistency; -use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator}; -use scylla_cql::serialize::raw_batch::RawBatchValuesAdapter; -use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues}; +use scylla_cql::serialize::row::SerializedValues; use socket2::{SockRef, TcpKeepalive}; use std::borrow::Cow; -use std::collections::{BTreeSet, HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap}; use std::convert::TryFrom; use std::net::{IpAddr, SocketAddr}; use std::num::NonZeroU64; @@ -1022,22 +1020,10 @@ impl Connection { pub(crate) async fn batch_with_consistency( &self, - init_batch: &Batch, - values: impl BatchValues, + batch: &BoundBatch, consistency: Consistency, serial_consistency: Option, ) -> Result { - let batch = self.prepare_batch(init_batch, &values).await?; - - let contexts = batch.statements.iter().map(|bs| match bs { - BatchStatement::Query(_) => RowSerializationContext::empty(), - BatchStatement::PreparedStatement(ps) => { - RowSerializationContext::from_prepared(ps.get_prepared_metadata()) - } - }); - - let values = RawBatchValuesAdapter::new(values, contexts); - let get_timestamp_from_gen = || { self.config .timestamp_generator @@ -1046,13 +1032,13 @@ impl Connection { }; let timestamp = batch.get_timestamp().or_else(get_timestamp_from_gen); - let batch_frame = batch::Batch { - statements: Cow::Borrowed(&batch.statements), - values, + let batch_frame = batch::BatchV2 { + statements_and_values: Cow::Borrowed(&batch.buffer), batch_type: batch.get_type(), consistency, serial_consistency, timestamp, + statements_len: batch.statements_len, }; loop { @@ -1065,13 +1051,8 @@ impl Connection { Response::Error(err) => match err.error { DbError::Unprepared { statement_id } => { debug!("Connection::batch: got DbError::Unprepared - repreparing statement with id {:?}", statement_id); - let prepared_statement = batch.statements.iter().find_map(|s| match s { - BatchStatement::PreparedStatement(s) if *s.get_id() == statement_id => { - Some(s) - } - _ => None, - }); - if let Some(p) = prepared_statement { + + if let Some(p) = batch.prepared.get(&statement_id) { self.reprepare(p.get_statement(), p).await?; continue; } else { @@ -1088,54 +1069,6 @@ impl Connection { } } - async fn prepare_batch<'b>( - &self, - init_batch: &'b Batch, - values: impl BatchValues, - ) -> Result, RequestAttemptError> { - let mut to_prepare = HashSet::<&str>::new(); - - { - let mut values_iter = values.batch_values_iter(); - for stmt in &init_batch.statements { - if let BatchStatement::Query(query) = stmt { - if let Some(false) = values_iter.is_empty_next() { - to_prepare.insert(&query.contents); - } - } else { - values_iter.skip_next(); - } - } - } - - if to_prepare.is_empty() { - return Ok(Cow::Borrowed(init_batch)); - } - - let mut prepared_queries = HashMap::<&str, PreparedStatement>::new(); - - for query in &to_prepare { - let prepared = self.prepare(&Statement::new(query.to_string())).await?; - prepared_queries.insert(query, prepared); - } - - let mut batch: Cow = Cow::Owned(Batch::new_from(init_batch)); - for stmt in &init_batch.statements { - match stmt { - BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str()) - { - Some(prepared) => batch.to_mut().append_statement(prepared.clone()), - None => batch.to_mut().append_statement(query.clone()), - }, - BatchStatement::PreparedStatement(prepared) => { - batch.to_mut().append_statement(prepared.clone()); - } - } - } - - Ok(batch) - } - pub(super) async fn use_keyspace( &self, keyspace_name: &VerifiedKeyspaceName, diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs index ed7e39e93d..99a0e43821 100644 --- a/scylla/src/statement/batch.rs +++ b/scylla/src/statement/batch.rs @@ -1,14 +1,27 @@ use std::borrow::Cow; +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; +use bytes::Bytes; +use scylla_cql::frame::frame_errors::{ + BatchSerializationError, BatchStatementSerializationError, CqlRequestSerializationError, +}; +use scylla_cql::frame::request; +use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator}; +use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues}; +use scylla_cql::serialize::{RowWriter, SerializationError}; + use crate::client::execution_profile::ExecutionProfileHandle; +use crate::errors::{BadQuery, ExecutionError, RequestAttemptError}; use crate::observability::history::HistoryListener; use crate::policies::load_balancing::LoadBalancingPolicy; use crate::policies::retry::RetryPolicy; -use crate::statement::prepared::PreparedStatement; +use crate::routing::Token; +use crate::statement::prepared::{PartitionKeyError, PreparedStatement}; use crate::statement::unprepared::Statement; +use super::bound::BoundStatement; use super::StatementConfig; use super::{Consistency, SerialConsistency}; pub use crate::frame::request::batch::BatchType; @@ -263,145 +276,186 @@ impl<'a: 'b, 'b> From<&'a BatchStatement> } } -pub(crate) mod batch_values { - use scylla_cql::serialize::batch::BatchValues; - use scylla_cql::serialize::batch::BatchValuesIterator; - use scylla_cql::serialize::row::RowSerializationContext; - use scylla_cql::serialize::row::SerializedValues; - use scylla_cql::serialize::{RowWriter, SerializationError}; +/// A batch with all of its statements bound to values +pub(crate) struct BoundBatch { + pub(crate) config: StatementConfig, + batch_type: BatchType, + pub(crate) buffer: Vec, + pub(crate) prepared: HashMap, + pub(crate) first_prepared: Option<(PreparedStatement, Token)>, + pub(crate) statements_len: u16, +} - use crate::errors::ExecutionError; - use crate::routing::Token; - use crate::statement::prepared::PartitionKeyError; +impl BoundBatch { + #[allow(clippy::result_large_err)] + pub(crate) fn from_batch( + batch: &Batch, + values: impl BatchValues, + ) -> Result { + let mut bound_batch = BoundBatch { + config: batch.config.clone(), + batch_type: batch.batch_type, + prepared: HashMap::new(), + buffer: vec![], + first_prepared: None, + statements_len: batch.statements.len().try_into().map_err(|_| { + ExecutionError::BadQuery(BadQuery::TooManyQueriesInBatchStatement( + batch.statements.len(), + )) + })?, + }; - use super::BatchStatement; + let mut values = values.batch_values_iter(); + let mut statements = batch.statements.iter().enumerate(); - /// Takes an optional reference to the first statement in the batch and - /// the batch values, and tries to compute the token for the statement. - /// Returns the (optional) token and batch values. If the function needed - /// to serialize values for the first statement, the returned batch values - /// will cache the results of the serialization. - /// - /// NOTE: Batch values returned by this function might not type check - /// the first statement when it is serialized! However, if they don't, - /// then the first row was already checked by the function. It is assumed - /// that `statement` holds the first prepared statement of the batch (if - /// there is one), and that it will be used later to serialize the values. - #[allow(clippy::result_large_err)] - pub(crate) fn peek_first_token<'bv>( - values: impl BatchValues + 'bv, - statement: Option<&BatchStatement>, - ) -> Result<(Option, impl BatchValues + 'bv), ExecutionError> { - let mut values_iter = values.batch_values_iter(); - let (token, first_values) = match statement { - Some(BatchStatement::PreparedStatement(ps)) => { - let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata()); - let (first_values, did_write) = SerializedValues::from_closure(|writer| { - values_iter - .serialize_next(&ctx, writer) - .transpose() - .map(|o| o.is_some()) - })?; - if did_write { - let token = ps - .calculate_token_untyped(&first_values) + if let Some((idx, statement)) = statements.next() { + match statement { + BatchStatement::Query(_) => { + bound_batch.serialize_from_batch_statement(statement, idx, |writer| { + let ctx = RowSerializationContext::empty(); + values.serialize_next(&ctx, writer).transpose() + })?; + } + BatchStatement::PreparedStatement(ps) => { + let values = + bound_batch.serialize_from_batch_statement(statement, idx, |writer| { + let ctx = + RowSerializationContext::from_prepared(ps.get_prepared_metadata()); + + let values = SerializedValues::from_closure(|writer| { + values.serialize_next(&ctx, writer).transpose() + }) + .map(|(values, opt)| opt.map(|_| values)); + + if let Ok(Some(values)) = &values { + writer.append_serialize_row(values); + } + + values + })?; + + let bound = BoundStatement::new_untyped(Cow::Borrowed(ps), values); + let token = bound + .token() .map_err(PartitionKeyError::into_execution_error)?; - (token, Some(first_values)) - } else { - (None, None) + + let prepared = bound.prepared.into_owned(); + bound_batch.first_prepared = token.map(|token| (prepared.clone(), token)); + bound_batch + .prepared + .insert(prepared.get_id().to_owned(), prepared); } } - _ => (None, None), - }; + } - // Need to do it explicitly, otherwise the next line will complain - // that `values_iter` still borrows `values`. - std::mem::drop(values_iter); + for (idx, statement) in statements { + bound_batch.serialize_from_batch_statement(statement, idx, |writer| { + let ctx = match statement { + BatchStatement::Query(_) => RowSerializationContext::empty(), + BatchStatement::PreparedStatement(ps) => { + RowSerializationContext::from_prepared(ps.get_prepared_metadata()) + } + }; + values.serialize_next(&ctx, writer).transpose() + })?; + + if let BatchStatement::PreparedStatement(ps) = statement { + if !bound_batch.prepared.contains_key(ps.get_id()) { + bound_batch + .prepared + .insert(ps.get_id().to_owned(), ps.clone()); + } + } + } - // Reuse the already serialized first value via `BatchValuesFirstSerialized`. - let values = BatchValuesFirstSerialized::new(values, first_values); + // At this point, we have all statements serialized. If any values are still left, we have a mismatch. + if values.skip_next().is_some() { + return Err(ExecutionError::LastAttemptError( + RequestAttemptError::CqlRequestSerialization( + CqlRequestSerializationError::BatchSerialization(counts_mismatch_err( + bound_batch.statements_len as usize + 1 /*skipped above*/ + values.count(), + bound_batch.statements_len, + )), + ), + )); + } - Ok((token, values)) + Ok(bound_batch) } - struct BatchValuesFirstSerialized { - // Contains the first value of BV in a serialized form. - // The first value in the iterator returned from `rest` should be skipped! - first: Option, - rest: BV, + /// Borrows the execution profile handle associated with this batch. + pub(crate) fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> { + self.config.execution_profile_handle.as_ref() } - impl BatchValuesFirstSerialized { - fn new(rest: BV, first: Option) -> Self { - Self { first, rest } - } + /// Gets the default timestamp for this batch in microseconds. + pub(crate) fn get_timestamp(&self) -> Option { + self.config.timestamp } - impl BatchValues for BatchValuesFirstSerialized - where - BV: BatchValues, - { - type BatchValuesIter<'r> - = BatchValuesFirstSerializedIterator<'r, BV::BatchValuesIter<'r>> - where - Self: 'r; - - fn batch_values_iter(&self) -> Self::BatchValuesIter<'_> { - BatchValuesFirstSerializedIterator { - first: self.first.as_ref(), - rest: self.rest.batch_values_iter(), - } - } + /// Gets type of batch. + pub(crate) fn get_type(&self) -> BatchType { + self.batch_type } - struct BatchValuesFirstSerializedIterator<'f, BVI> { - first: Option<&'f SerializedValues>, - rest: BVI, - } - - impl<'f, BVI> BatchValuesIterator<'f> for BatchValuesFirstSerializedIterator<'f, BVI> - where - BVI: BatchValuesIterator<'f>, - { - #[inline] - fn serialize_next( - &mut self, - ctx: &RowSerializationContext<'_>, - writer: &mut RowWriter, - ) -> Option> { - match self.first.take() { - Some(sr) => { - writer.append_serialize_row(sr); - self.rest.skip_next(); - Some(Ok(())) - } - None => self.rest.serialize_next(ctx, writer), - } - } - - #[inline] - fn is_empty_next(&mut self) -> Option { - match self.first.take() { - Some(s) => { - self.rest.skip_next(); - Some(s.is_empty()) - } - None => self.rest.is_empty_next(), - } - } + fn serialize_from_batch_statement( + &mut self, + statement: &BatchStatement, + statement_idx: usize, + serialize: impl FnOnce(&mut RowWriter<'_>) -> Result, SerializationError>, + ) -> Result { + serialize_statement( + request::batch::BatchStatement::from(statement), + &mut self.buffer, + serialize, + ) + .map_err(|error| BatchSerializationError::StatementSerialization { + statement_idx, + error, + }) + .transpose() + .unwrap_or_else(|| Err(counts_mismatch_err(statement_idx, self.statements_len))) + .map_err(|e| { + ExecutionError::LastAttemptError(RequestAttemptError::CqlRequestSerialization( + CqlRequestSerializationError::BatchSerialization(e), + )) + }) + } +} - #[inline] - fn skip_next(&mut self) -> Option<()> { - self.first = None; - self.rest.skip_next() - } +fn serialize_statement( + statement: request::batch::BatchStatement, + buffer: &mut Vec, + serialize: impl FnOnce(&mut RowWriter<'_>) -> Result, SerializationError>, +) -> Result, BatchStatementSerializationError> { + statement.serialize(buffer)?; + + // Reserve two bytes for length + let length_pos = buffer.len(); + buffer.extend_from_slice(&[0, 0]); + + // serialize the values + let mut writer = RowWriter::new(buffer); + let Some(res) = + serialize(&mut writer).map_err(BatchStatementSerializationError::ValuesSerialiation)? + else { + return Ok(None); + }; + + // Go back and put the length + let count: u16 = writer + .value_count() + .try_into() + .map_err(|_| BatchStatementSerializationError::TooManyValues(writer.value_count()))?; + + buffer[length_pos..length_pos + 2].copy_from_slice(&count.to_be_bytes()); + + Ok(Some(res)) +} - #[inline] - fn count(self) -> usize - where - Self: Sized, - { - self.rest.count() - } +fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializationError { + BatchSerializationError::ValuesAndStatementsLengthMismatch { + n_value_lists, + n_statements: n_statements as usize, } } diff --git a/scylla/src/statement/bound.rs b/scylla/src/statement/bound.rs index 3d74ec9ba0..42a8b1e4ec 100644 --- a/scylla/src/statement/bound.rs +++ b/scylla/src/statement/bound.rs @@ -27,6 +27,13 @@ impl<'p> BoundStatement<'p> { Ok(Self { prepared, values }) } + pub(crate) fn new_untyped( + prepared: Cow<'p, PreparedStatement>, + values: SerializedValues, + ) -> Self { + Self { prepared, values } + } + /// Determines which values constitute the partition key and puts them in order. /// /// This is a preparation step necessary for calculating token based on a prepared statement. diff --git a/scylla/src/statement/prepared.rs b/scylla/src/statement/prepared.rs index d075a74298..0611c8fdb2 100644 --- a/scylla/src/statement/prepared.rs +++ b/scylla/src/statement/prepared.rs @@ -237,22 +237,6 @@ impl PreparedStatement { bound.token() } - // A version of calculate_token which skips serialization and uses SerializedValues directly. - // Not type-safe, so not exposed to users. - pub(crate) fn calculate_token_untyped( - &self, - values: &SerializedValues, - ) -> Result, PartitionKeyError> { - if !self.is_token_aware() { - return Ok(None); - } - - let partition_key = PartitionKey::new(self.get_prepared_metadata(), values)?; - let token = partition_key.calculate_token(&self.partitioner_name)?; - - Ok(Some(token)) - } - /// Return keyspace name and table name this statement is operating on. pub fn get_table_spec(&self) -> Option<&TableSpec> { self.get_prepared_metadata() From 80308b5e02c3365996ad26a8ca76ff07ac3b5211 Mon Sep 17 00:00:00 2001 From: Andres Medina Date: Thu, 10 Apr 2025 18:02:41 -0700 Subject: [PATCH 3/6] switch scylla-cql's request::Batch to use new BatchV2 version this is the version that the top crate (scylla) will use to send batches --- scylla-cql/src/frame/request/batch.rs | 82 +++++++++++++++++++++++++++ scylla-cql/src/frame/request/mod.rs | 73 +++++++++++++++--------- scylla-cql/src/frame/types.rs | 28 +++++++++ scylla/src/statement/prepared.rs | 10 +++- 4 files changed, 163 insertions(+), 30 deletions(-) diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs index 88556ac6e3..c4add6433c 100644 --- a/scylla-cql/src/frame/request/batch.rs +++ b/scylla-cql/src/frame/request/batch.rs @@ -81,6 +81,65 @@ impl SerializableRequest for BatchV2<'_> { } } +impl DeserializableRequest for BatchV2<'static> { + fn deserialize(buf: &mut &[u8]) -> Result { + let batch_type = buf.get_u8().try_into()?; + let statements_len = types::read_short(buf)?; + + let statements_and_values = (0..statements_len).try_fold( + // technically allocating 3-13 bytes too many but that's OK + Vec::with_capacity(buf.len()), + |mut statements_and_values, _| { + BatchStatement::deserialize_to_buffer(buf, &mut statements_and_values)?; + // As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used. + let values = SerializedValues::new_from_frame(buf)?; + statements_and_values.extend_from_slice(&values.element_count().to_be_bytes()); + statements_and_values.extend_from_slice(values.get_contents()); + + Result::<_, RequestDeserializationError>::Ok(statements_and_values) + }, + )?; + + let consistency = types::read_consistency(buf)?; + + let flags = buf.get_u8(); + let unknown_flags = flags & (!ALL_FLAGS); + if unknown_flags != 0 { + return Err(RequestDeserializationError::UnknownFlags { + flags: unknown_flags, + }); + } + let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0; + let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0; + + let serial_consistency = serial_consistency_flag + .then(|| types::read_consistency(buf)) + .transpose()? + .map( + |consistency| match SerialConsistency::try_from(consistency) { + Ok(serial_consistency) => Ok(serial_consistency), + Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency( + consistency, + )), + }, + ) + .transpose()?; + + let timestamp = default_timestamp_flag + .then(|| types::read_long(buf)) + .transpose()?; + + Ok(Self { + batch_type, + consistency, + serial_consistency, + timestamp, + statements_len, + statements_and_values: Cow::Owned(statements_and_values), + }) + } +} + /// The type of a batch. #[derive(Clone, Copy)] #[cfg_attr(test, derive(Debug, PartialEq, Eq))] @@ -251,6 +310,29 @@ impl BatchStatement<'_> { )), } } + + fn deserialize_to_buffer( + input: &mut &[u8], + out: &mut Vec, + ) -> Result<(), RequestDeserializationError> { + match input.get_u8() { + 0 => { + out.put_u8(0); + types::read_long_string_to_buff(input, out)?; + + Ok(()) + } + 1 => { + out.put_u8(1); + types::read_short_bytes_to_buffer(input, out)?; + + Ok(()) + } + kind => Err(RequestDeserializationError::UnexpectedBatchStatementKind( + kind, + )), + } + } } impl BatchStatement<'_> { diff --git a/scylla-cql/src/frame/request/mod.rs b/scylla-cql/src/frame/request/mod.rs index 1a5256ca33..0bba3284db 100644 --- a/scylla-cql/src/frame/request/mod.rs +++ b/scylla-cql/src/frame/request/mod.rs @@ -16,6 +16,7 @@ use bytes::Bytes; pub use auth_response::AuthResponse; pub use batch::Batch; +pub use batch::BatchV2; pub use execute::Execute; pub use options::Options; pub use prepare::Prepare; @@ -138,6 +139,7 @@ pub enum Request<'r> { Query(Query<'r>), Execute(Execute<'r>), Batch(Batch<'r, BatchStatement<'r>, Vec>), + BatchV2(BatchV2<'r>), } impl Request<'_> { @@ -148,7 +150,7 @@ impl Request<'_> { match opcode { RequestOpcode::Query => Query::deserialize(buf).map(Self::Query), RequestOpcode::Execute => Execute::deserialize(buf).map(Self::Execute), - RequestOpcode::Batch => Batch::deserialize(buf).map(Self::Batch), + RequestOpcode::Batch => BatchV2::deserialize(buf).map(Self::BatchV2), _ => unimplemented!( "Deserialization of opcode {:?} is not yet supported", opcode @@ -162,6 +164,7 @@ impl Request<'_> { Request::Query(q) => Some(q.parameters.consistency), Request::Execute(e) => Some(e.parameters.consistency), Request::Batch(b) => Some(b.consistency), + Request::BatchV2(b) => Some(b.consistency), #[expect(unreachable_patterns)] // until other opcodes are supported _ => None, } @@ -173,6 +176,7 @@ impl Request<'_> { Request::Query(q) => Some(q.parameters.serial_consistency), Request::Execute(e) => Some(e.parameters.serial_consistency), Request::Batch(b) => Some(b.serial_consistency), + Request::BatchV2(b) => Some(b.serial_consistency), #[expect(unreachable_patterns)] // until other opcodes are supported _ => None, } @@ -181,7 +185,7 @@ impl Request<'_> { #[cfg(test)] mod tests { - use std::{borrow::Cow, ops::Deref}; + use std::borrow::Cow; use bytes::Bytes; @@ -189,7 +193,7 @@ mod tests { use crate::{ frame::{ request::{ - batch::{Batch, BatchStatement, BatchType}, + batch::{BatchStatement, BatchType, BatchV2}, execute::Execute, query::{Query, QueryParameters}, DeserializableRequest, SerializableRequest, @@ -261,32 +265,39 @@ mod tests { } // Batch - let statements = vec![ - BatchStatement::Query { - text: query.contents, - }, - BatchStatement::Prepared { - id: Cow::Borrowed(&execute.id), - }, - ]; - let batch = Batch { - statements: Cow::Owned(statements), + // Not execute's values, because named values are not supported in batches. + let mut statements_and_values = vec![]; + BatchStatement::Query { + text: query.contents, + } + .serialize(&mut statements_and_values) + .unwrap(); + statements_and_values + .extend_from_slice(&query.parameters.values.element_count().to_be_bytes()); + statements_and_values.extend_from_slice(query.parameters.values.get_contents()); + + BatchStatement::Prepared { + id: Cow::Borrowed(&execute.id), + } + .serialize(&mut statements_and_values) + .unwrap(); + statements_and_values + .extend_from_slice(&query.parameters.values.element_count().to_be_bytes()); + statements_and_values.extend_from_slice(query.parameters.values.get_contents()); + + let batch = BatchV2 { + statements_and_values: Cow::Owned(statements_and_values), batch_type: BatchType::Logged, consistency: Consistency::EachQuorum, serial_consistency: Some(SerialConsistency::LocalSerial), timestamp: Some(32432), - - // Not execute's values, because named values are not supported in batches. - values: vec![ - query.parameters.values.deref().clone(), - query.parameters.values.deref().clone(), - ], + statements_len: 2, }; { let mut buf = Vec::new(); batch.serialize(&mut buf).unwrap(); - let batch_deserialized = Batch::deserialize(&mut &buf[..]).unwrap(); + let batch_deserialized = BatchV2::deserialize(&mut &buf[..]).unwrap(); assert_eq!(&batch_deserialized, &batch); } } @@ -341,24 +352,30 @@ mod tests { } // Batch - let statements = vec![BatchStatement::Query { + let mut statements_and_values = vec![]; + BatchStatement::Query { text: query.contents, - }]; - let batch = Batch { - statements: Cow::Owned(statements), + } + .serialize(&mut statements_and_values) + .unwrap(); + statements_and_values + .extend_from_slice(&query.parameters.values.element_count().to_be_bytes()); + statements_and_values.extend_from_slice(query.parameters.values.get_contents()); + + let batch = BatchV2 { batch_type: BatchType::Logged, consistency: Consistency::EachQuorum, serial_consistency: None, timestamp: None, - - values: vec![query.parameters.values.deref().clone()], + statements_and_values: Cow::Owned(statements_and_values), + statements_len: 1, }; { let mut buf = Vec::new(); batch.serialize(&mut buf).unwrap(); // Sanity check: batch deserializes to the equivalent. - let batch_deserialized = Batch::deserialize(&mut &buf[..]).unwrap(); + let batch_deserialized = BatchV2::deserialize(&mut &buf[..]).unwrap(); assert_eq!(batch, batch_deserialized); // Now modify flags by adding an unknown one. @@ -370,7 +387,7 @@ mod tests { // Unknown flag should lead to frame rejection, as unknown flags can be new protocol extensions // leading to different semantics. - let _parse_error = Batch::deserialize(&mut &buf[..]).unwrap_err(); + let _parse_error = BatchV2::deserialize(&mut &buf[..]).unwrap_err(); } } } diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 0d760dadcb..3c1b46d764 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -288,6 +288,17 @@ pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDese Ok(v) } +pub fn read_short_bytes_to_buffer( + input: &mut &[u8], + out: &mut impl BufMut, +) -> Result<(), LowLevelDeserializationError> { + let len = read_short(input)?; + let v = read_raw_bytes(len.into(), input)?; + write_short(len, out); + out.put_slice(v); + Ok(()) +} + pub fn write_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> { write_int_length(v.len(), buf)?; buf.put_slice(v); @@ -394,6 +405,23 @@ pub fn write_long_string(v: &str, buf: &mut impl BufMut) -> Result<(), std::num: Ok(()) } +pub(crate) fn read_long_string_to_buff( + input: &mut &[u8], + out: &mut impl BufMut, +) -> Result<(), LowLevelDeserializationError> { + let len = read_int(input)?; + let raw = read_raw_bytes(len.try_into()?, input)?; + + // verify it is a valid string but ignore the out string; we already have the raw bytes + let _ = str::from_utf8(raw)?; + + // now write it out + write_int(len, out); + out.put_slice(raw); + + Ok(()) +} + #[test] fn type_long_string() { let vals = [String::from(""), String::from("hello, world!")]; diff --git a/scylla/src/statement/prepared.rs b/scylla/src/statement/prepared.rs index 0611c8fdb2..a1a8fbf148 100644 --- a/scylla/src/statement/prepared.rs +++ b/scylla/src/statement/prepared.rs @@ -462,14 +462,20 @@ impl PreparedStatement { self.config.execution_profile_handle.as_ref() } - pub(crate) fn bind( + /// Binds values with a reference to a prepared statement + /// + /// This method will serialize the values and thus type erase them on return + pub fn bind( &self, values: &impl SerializeRow, ) -> Result, SerializationError> { BoundStatement::new(Cow::Borrowed(self), values) } - pub(crate) fn into_bind( + /// Binds values with an owned prepared statement + /// + /// This method will serialize the values and thus type erase them on return + pub fn into_bind( self, values: &impl SerializeRow, ) -> Result, SerializationError> { From 89797029906a82234412ed550cf1bd6b4f4964a2 Mon Sep 17 00:00:00 2001 From: Andres Medina Date: Thu, 10 Apr 2025 18:14:43 -0700 Subject: [PATCH 4/6] expose creation of BoundBatch it implements Default same as `Batch`, and it also allows for override of the batch_type same as `Batch` --- scylla/src/statement/batch.rs | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs index 99a0e43821..4ef899d122 100644 --- a/scylla/src/statement/batch.rs +++ b/scylla/src/statement/batch.rs @@ -277,7 +277,7 @@ impl<'a: 'b, 'b> From<&'a BatchStatement> } /// A batch with all of its statements bound to values -pub(crate) struct BoundBatch { +pub struct BoundBatch { pub(crate) config: StatementConfig, batch_type: BatchType, pub(crate) buffer: Vec, @@ -287,6 +287,13 @@ pub(crate) struct BoundBatch { } impl BoundBatch { + pub fn new(batch_type: BatchType) -> Self { + Self { + batch_type, + ..Default::default() + } + } + #[allow(clippy::result_large_err)] pub(crate) fn from_batch( batch: &Batch, @@ -295,14 +302,12 @@ impl BoundBatch { let mut bound_batch = BoundBatch { config: batch.config.clone(), batch_type: batch.batch_type, - prepared: HashMap::new(), - buffer: vec![], - first_prepared: None, statements_len: batch.statements.len().try_into().map_err(|_| { ExecutionError::BadQuery(BadQuery::TooManyQueriesInBatchStatement( batch.statements.len(), )) })?, + ..Default::default() }; let mut values = values.batch_values_iter(); @@ -384,17 +389,17 @@ impl BoundBatch { } /// Borrows the execution profile handle associated with this batch. - pub(crate) fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> { + pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> { self.config.execution_profile_handle.as_ref() } /// Gets the default timestamp for this batch in microseconds. - pub(crate) fn get_timestamp(&self) -> Option { + pub fn get_timestamp(&self) -> Option { self.config.timestamp } /// Gets type of batch. - pub(crate) fn get_type(&self) -> BatchType { + pub fn get_type(&self) -> BatchType { self.batch_type } @@ -453,6 +458,19 @@ fn serialize_statement( Ok(Some(res)) } +impl Default for BoundBatch { + fn default() -> Self { + Self { + config: StatementConfig::default(), + batch_type: BatchType::Logged, + buffer: Vec::new(), + prepared: HashMap::new(), + first_prepared: None, + statements_len: 0, + } + } +} + fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializationError { BatchSerializationError::ValuesAndStatementsLengthMismatch { n_value_lists, From af882832f61df3b5d3223c5190250e17493b6560 Mon Sep 17 00:00:00 2001 From: Andres Medina Date: Fri, 11 Apr 2025 11:02:24 -0700 Subject: [PATCH 5/6] Allow adding statements to a boundbatch and executing it --- scylla/src/client/session.rs | 372 ++----------------- scylla/src/network/connection.rs | 2 +- scylla/src/statement/batch.rs | 247 +++++++++++- scylla/src/statement/bound.rs | 143 ++++++- scylla/src/statement/execute.rs | 72 ++++ scylla/src/statement/mod.rs | 1 + scylla/src/statement/unprepared.rs | 116 +++++- scylla/tests/integration/statements/batch.rs | 65 +++- 8 files changed, 653 insertions(+), 365 deletions(-) create mode 100644 scylla/src/statement/execute.rs diff --git a/scylla/src/client/session.rs b/scylla/src/client/session.rs index 1d1c5a90d1..06731b9f4c 100644 --- a/scylla/src/client/session.rs +++ b/scylla/src/client/session.rs @@ -16,7 +16,6 @@ use crate::errors::{ PrepareError, RequestAttemptError, RequestError, SchemaAgreementError, TracingError, UseKeyspaceError, }; -use crate::frame::response::result; use crate::network::tls::TlsProvider; use crate::network::{Connection, ConnectionConfig, PoolConfig, VerifiedKeyspaceName}; use crate::observability::driver_tracing::RequestSpan; @@ -31,22 +30,20 @@ use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession}; use crate::policies::speculative_execution; use crate::policies::timestamp_generator::TimestampGenerator; use crate::response::query_result::{MaybeFirstRowError, QueryResult, RowsError}; -use crate::response::{ - Coordinator, NonErrorQueryResponse, PagingState, PagingStateResponse, QueryResponse, -}; +use crate::response::{Coordinator, NonErrorQueryResponse, PagingState, PagingStateResponse}; use crate::routing::partitioner::PartitionerName; use crate::routing::{Shard, ShardAwarePortRange}; use crate::statement::batch::BoundBatch; use crate::statement::batch::{Batch, BatchStatement}; use crate::statement::bound::BoundStatement; -use crate::statement::prepared::{PartitionKeyError, PreparedStatement}; +use crate::statement::execute::{Execute, ExecutePageable}; +use crate::statement::prepared::PreparedStatement; use crate::statement::unprepared::Statement; -use crate::statement::{Consistency, PageSize, StatementConfig}; +use crate::statement::{Consistency, StatementConfig}; use arc_swap::ArcSwapOption; use futures::future::join_all; use futures::future::try_join_all; use itertools::Itertools; -use scylla_cql::frame::response::NonErrorResponse; use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator}; use scylla_cql::serialize::row::SerializeRow; use std::borrow::{Borrow, Cow}; @@ -59,7 +56,7 @@ use std::time::Duration; use tokio::time::timeout; #[cfg(feature = "unstable-cloud")] use tracing::warn; -use tracing::{debug, error, trace, trace_span, Instrument}; +use tracing::{debug, trace, trace_span, Instrument}; use uuid::Uuid; pub(crate) const TABLET_CHANNEL_SIZE: usize = 8192; @@ -486,7 +483,8 @@ impl Session { statement: impl Into, values: impl SerializeRow, ) -> Result { - self.do_query_unpaged(&statement.into(), values).await + let statement = statement.into(); + (&statement, values).execute(self).await } /// Queries a single page from the database, optionally continuing from a saved point. @@ -546,7 +544,9 @@ impl Session { values: impl SerializeRow, paging_state: PagingState, ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { - self.do_query_single_page(&statement.into(), values, paging_state) + let statement = statement.into(); + (&statement, values) + .execute_pageable::(self, paging_state) .await } @@ -811,9 +811,9 @@ impl Session { values: impl BatchValues, ) -> Result { let batch = self.last_minute_prepare_batch(batch, &values).await?; - let batch = BoundBatch::from_batch(batch.as_ref(), values)?; - - self.do_batch(&batch).await + BoundBatch::from_batch(batch.as_ref(), values)? + .execute(self) + .await } /// Estabilishes a CQL session with the database @@ -990,145 +990,7 @@ impl Session { Ok(session) } - async fn do_query_unpaged( - &self, - statement: &Statement, - values: impl SerializeRow, - ) -> Result { - let (result, paging_state_response) = self - .query(statement, values, None, PagingState::start()) - .await?; - if !paging_state_response.finished() { - error!("Unpaged unprepared query returned a non-empty paging state! This is a driver-side or server-side bug."); - return Err(ExecutionError::LastAttemptError( - RequestAttemptError::NonfinishedPagingState, - )); - } - Ok(result) - } - - async fn do_query_single_page( - &self, - statement: &Statement, - values: impl SerializeRow, - paging_state: PagingState, - ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { - self.query( - statement, - values, - Some(statement.get_validated_page_size()), - paging_state, - ) - .await - } - - /// Sends a request to the database. - /// Optionally continues fetching results from a saved point. - /// - /// This is now an internal method only. - /// - /// Tl;dr: use [Session::query_unpaged], [Session::query_single_page] or [Session::query_iter] instead. - /// - /// The rationale is that we believe that paging is so important concept (and it has shown to be error-prone as well) - /// that we need to require users to make a conscious decision to use paging or not. For that, we expose - /// the aforementioned 3 methods clearly differing in naming and API, so that no unconscious choices about paging - /// should be made. - async fn query( - &self, - statement: &Statement, - values: impl SerializeRow, - page_size: Option, - paging_state: PagingState, - ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { - let execution_profile = statement - .get_execution_profile_handle() - .unwrap_or_else(|| self.get_default_execution_profile_handle()) - .access(); - - let statement_info = RoutingInfo { - consistency: statement - .config - .consistency - .unwrap_or(execution_profile.consistency), - serial_consistency: statement - .config - .serial_consistency - .unwrap_or(execution_profile.serial_consistency), - ..Default::default() - }; - - let span = RequestSpan::new_query(&statement.contents); - let span_ref = &span; - let (run_request_result, coordinator): ( - RunRequestResult, - Coordinator, - ) = self - .run_request( - statement_info, - &statement.config, - execution_profile, - |connection: Arc, - consistency: Consistency, - execution_profile: &ExecutionProfileInner| { - let serial_consistency = statement - .config - .serial_consistency - .unwrap_or(execution_profile.serial_consistency); - // Needed to avoid moving query and values into async move block - let values_ref = &values; - let paging_state_ref = &paging_state; - async move { - if values_ref.is_empty() { - span_ref.record_request_size(0); - connection - .query_raw_with_consistency( - statement, - consistency, - serial_consistency, - page_size, - paging_state_ref.clone(), - ) - .await - .and_then(QueryResponse::into_non_error_query_response) - } else { - let prepared = connection.prepare(statement).await?; - let bound = prepared.bind(values_ref)?; - span_ref.record_request_size(bound.values.buffer_size()); - connection - .execute_raw_with_consistency( - &bound, - consistency, - serial_consistency, - page_size, - paging_state_ref.clone(), - ) - .await - .and_then(QueryResponse::into_non_error_query_response) - } - } - }, - &span, - ) - .instrument(span.span().clone()) - .await?; - - let response = match run_request_result { - RunRequestResult::IgnoredWriteError => NonErrorQueryResponse { - response: NonErrorResponse::Result(result::Result::Void), - tracing_id: None, - warnings: Vec::new(), - }, - RunRequestResult::Completed(response) => response, - }; - - let (result, paging_state_response) = - response.into_query_result_and_paging_state(coordinator)?; - span.record_result_fields(&result); - - Ok((result, paging_state_response)) - } - - async fn handle_set_keyspace_response( + pub(crate) async fn handle_set_keyspace_response( &self, response: &NonErrorQueryResponse, ) -> Result<(), UseKeyspaceError> { @@ -1365,14 +1227,7 @@ impl Session { &self, bound: &BoundStatement<'_>, ) -> Result { - let (result, paging_state) = self.execute(bound, None, PagingState::start()).await?; - if !paging_state.finished() { - error!("Unpaged prepared query returned a non-empty paging state! This is a driver-side or server-side bug."); - return Err(ExecutionError::LastAttemptError( - RequestAttemptError::NonfinishedPagingState, - )); - } - Ok(result) + bound.execute(self).await } async fn do_execute_single_page( @@ -1380,119 +1235,7 @@ impl Session { bound: &BoundStatement<'_>, paging_state: PagingState, ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { - let page_size = bound.prepared.get_validated_page_size(); - self.execute(bound, Some(page_size), paging_state).await - } - - /// Sends a prepared request to the database, optionally continuing from a saved point. - /// - /// This is now an internal method only. - /// - /// Tl;dr: use [Session::execute_unpaged], [Session::execute_single_page] or [Session::execute_iter] instead. - /// - /// The rationale is that we believe that paging is so important concept (and it has shown to be error-prone as well) - /// that we need to require users to make a conscious decision to use paging or not. For that, we expose - /// the aforementioned 3 methods clearly differing in naming and API, so that no unconscious choices about paging - /// should be made. - async fn execute( - &self, - bound: &BoundStatement<'_>, - page_size: Option, - paging_state: PagingState, - ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { - let paging_state_ref = &paging_state; - - let (partition_key, token) = bound - .pk_and_token() - .map_err(PartitionKeyError::into_execution_error)? - .unzip(); - - let execution_profile = bound - .prepared - .get_execution_profile_handle() - .unwrap_or_else(|| self.get_default_execution_profile_handle()) - .access(); - - let table_spec = bound.prepared.get_table_spec(); - - let statement_info = RoutingInfo { - consistency: bound - .prepared - .config - .consistency - .unwrap_or(execution_profile.consistency), - serial_consistency: bound - .prepared - .config - .serial_consistency - .unwrap_or(execution_profile.serial_consistency), - token, - table: table_spec, - is_confirmed_lwt: bound.prepared.is_confirmed_lwt(), - }; - - let span = RequestSpan::new_prepared( - partition_key.as_ref().map(|pk| pk.iter()), - token, - bound.values.buffer_size(), - ); - - if !span.span().is_disabled() { - if let (Some(table_spec), Some(token)) = (statement_info.table, token) { - let cluster_state = self.get_cluster_state(); - let replicas = cluster_state.get_token_endpoints_iter(table_spec, token); - span.record_replicas(replicas) - } - } - - let (run_request_result, coordinator): ( - RunRequestResult, - Coordinator, - ) = self - .run_request( - statement_info, - &bound.prepared.config, - execution_profile, - |connection: Arc, - consistency: Consistency, - execution_profile: &ExecutionProfileInner| { - let serial_consistency = bound - .prepared - .config - .serial_consistency - .unwrap_or(execution_profile.serial_consistency); - async move { - connection - .execute_raw_with_consistency( - bound, - consistency, - serial_consistency, - page_size, - paging_state_ref.clone(), - ) - .await - .and_then(QueryResponse::into_non_error_query_response) - } - }, - &span, - ) - .instrument(span.span().clone()) - .await?; - - let response = match run_request_result { - RunRequestResult::IgnoredWriteError => NonErrorQueryResponse { - response: NonErrorResponse::Result(result::Result::Void), - tracing_id: None, - warnings: Vec::new(), - }, - RunRequestResult::Completed(response) => response, - }; - - let (result, paging_state_response) = - response.into_query_result_and_paging_state(coordinator)?; - span.record_result_fields(&result); - - Ok((result, paging_state_response)) + bound.execute_pageable::(self, paging_state).await } async fn do_execute_iter( @@ -1516,79 +1259,6 @@ impl Session { .map_err(PagerExecutionError::NextPageError) } - async fn do_batch(&self, batch: &BoundBatch) -> Result { - // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard - // If users batch statements by shard, they will be rewarded with full shard awareness - let execution_profile = batch - .get_execution_profile_handle() - .unwrap_or_else(|| self.get_default_execution_profile_handle()) - .access(); - - let consistency = batch - .config - .consistency - .unwrap_or(execution_profile.consistency); - - let serial_consistency = batch - .config - .serial_consistency - .unwrap_or(execution_profile.serial_consistency); - - let (table, token) = batch - .first_prepared - .as_ref() - .and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token))) - .unzip(); - - let statement_info = RoutingInfo { - consistency, - serial_consistency, - token, - table, - is_confirmed_lwt: false, - }; - - let span = RequestSpan::new_batch(); - - let (run_request_result, coordinator): ( - RunRequestResult, - Coordinator, - ) = self - .run_request( - statement_info, - &batch.config, - execution_profile, - |connection: Arc, - consistency: Consistency, - execution_profile: &ExecutionProfileInner| { - let serial_consistency = batch - .config - .serial_consistency - .unwrap_or(execution_profile.serial_consistency); - async move { - connection - .batch_with_consistency(batch, consistency, serial_consistency) - .await - .and_then(QueryResponse::into_non_error_query_response) - } - }, - &span, - ) - .instrument(span.span().clone()) - .await?; - - let result = match run_request_result { - RunRequestResult::IgnoredWriteError => QueryResult::mock_empty(coordinator), - RunRequestResult::Completed(non_error_query_response) => { - let result = non_error_query_response.into_query_result(coordinator)?; - span.record_result_fields(&result); - result - } - }; - - Ok(result) - } - /// Prepares all statements within the batch and returns a new batch where every /// statement is prepared. /// /// # Example @@ -1820,10 +1490,10 @@ impl Session { traces_events_query.config.consistency = consistency; traces_events_query.set_page_size(TRACING_QUERY_PAGE_SIZE); - let (traces_session_res, traces_events_res) = tokio::try_join!( - self.do_query_unpaged(&traces_session_query, (tracing_id,)), - self.do_query_unpaged(&traces_events_query, (tracing_id,)) - )?; + let session_query = (&traces_session_query, (tracing_id,)); + let events_query = (&traces_events_query, (tracing_id,)); + let (traces_session_res, traces_events_res) = + tokio::try_join!(session_query.execute(self), events_query.execute(self))?; // Get tracing info let maybe_tracing_info: Option = traces_session_res @@ -1873,7 +1543,7 @@ impl Session { /// On success, this request's result is returned. // I tried to make this closures take a reference instead of an Arc but failed // maybe once async closures get stabilized this can be fixed - async fn run_request<'a, QueryFut>( + pub(crate) async fn run_request<'a, QueryFut>( &'a self, statement_info: RoutingInfo<'a>, statement_config: &'a StatementConfig, diff --git a/scylla/src/network/connection.rs b/scylla/src/network/connection.rs index 8043fa6a81..fb28a164b3 100644 --- a/scylla/src/network/connection.rs +++ b/scylla/src/network/connection.rs @@ -1038,7 +1038,7 @@ impl Connection { consistency, serial_consistency, timestamp, - statements_len: batch.statements_len, + statements_len: batch.statements_len(), }; loop { diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs index 4ef899d122..bc234dfd3e 100644 --- a/scylla/src/statement/batch.rs +++ b/scylla/src/statement/batch.rs @@ -9,19 +9,28 @@ use scylla_cql::frame::frame_errors::{ }; use scylla_cql::frame::request; use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator}; -use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues}; +use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues}; use scylla_cql::serialize::{RowWriter, SerializationError}; +use thiserror::Error; +use tracing::Instrument; -use crate::client::execution_profile::ExecutionProfileHandle; +use crate::client::execution_profile::{ExecutionProfileHandle, ExecutionProfileInner}; +use crate::client::session::{RunRequestResult, Session}; use crate::errors::{BadQuery, ExecutionError, RequestAttemptError}; +use crate::network::Connection; +use crate::observability::driver_tracing::RequestSpan; use crate::observability::history::HistoryListener; use crate::policies::load_balancing::LoadBalancingPolicy; +use crate::policies::load_balancing::RoutingInfo; use crate::policies::retry::RetryPolicy; +use crate::response::query_result::QueryResult; +use crate::response::{Coordinator, NonErrorQueryResponse, QueryResponse}; use crate::routing::Token; use crate::statement::prepared::{PartitionKeyError, PreparedStatement}; use crate::statement::unprepared::Statement; use super::bound::BoundStatement; +use super::execute::Execute; use super::StatementConfig; use super::{Consistency, SerialConsistency}; pub use crate::frame::request::batch::BatchType; @@ -282,8 +291,8 @@ pub struct BoundBatch { batch_type: BatchType, pub(crate) buffer: Vec, pub(crate) prepared: HashMap, - pub(crate) first_prepared: Option<(PreparedStatement, Token)>, - pub(crate) statements_len: u16, + first_prepared: Option<(PreparedStatement, Token)>, + statements_len: u16, } impl BoundBatch { @@ -294,6 +303,20 @@ impl BoundBatch { } } + /// Appends a new statement to the batch. + pub fn append_statement<'p, V: SerializeRow>( + &mut self, + statement: impl Into>, + ) -> Result<(), BoundBatchStatementError> { + let initial_len = self.buffer.len(); + self.raw_append_statement(statement).inspect_err(|_| { + // if we error'd at any point we should put the buffer back to its old length to not + // corrupt the buffer in case the user doesn't drop the boundbatch but instead skips and + // tries with a successful statement later + self.buffer.truncate(initial_len); + }) + } + #[allow(clippy::result_large_err)] pub(crate) fn from_batch( batch: &Batch, @@ -403,6 +426,96 @@ impl BoundBatch { self.batch_type } + pub fn statements_len(&self) -> u16 { + self.statements_len + } + + // **IMPORTANT NOTE**: It is OK for this function to append to the buffer even if it errors + // because the caller will fix the buffer, HOWEVER, it is *NOT OK* for *ANY* other field in + // `self` to be modified if an error occured because the caller will not reset them. + fn raw_append_statement<'p, V: SerializeRow>( + &mut self, + statement: impl Into>, + ) -> Result<(), BoundBatchStatementError> { + let mut statement = statement.into(); + let mut first_prepared = None; + + if self.statements_len == 0 { + // save it into a local variable for now in case a latter steps fails + first_prepared = match statement { + BoundBatchStatement::Bound(ref b) => b + .token()? + .map(|token| (b.prepared.clone().into_owned(), token)), + BoundBatchStatement::Prepared(ps, values) => { + let bound = ps + .into_bind(&values) + .map_err(BatchStatementSerializationError::ValuesSerialiation)?; + let first_prepared = bound + .token()? + .map(|token| (bound.prepared.clone().into_owned(), token)); + // we already serialized it so to avoid re-serializing it, modify the statement to a + // BoundStatement + statement = BoundBatchStatement::Bound(bound); + first_prepared + } + BoundBatchStatement::Query(_) => None, + }; + } + + let stmnt = match &statement { + BoundBatchStatement::Prepared(ps, _) => request::batch::BatchStatement::Prepared { + id: Cow::Borrowed(ps.get_id()), + }, + BoundBatchStatement::Bound(b) => request::batch::BatchStatement::Prepared { + id: Cow::Borrowed(b.prepared.get_id()), + }, + BoundBatchStatement::Query(q) => request::batch::BatchStatement::Query { + text: Cow::Borrowed(&q.contents), + }, + }; + + serialize_statement(stmnt, &mut self.buffer, |writer| match &statement { + BoundBatchStatement::Prepared(ps, values) => { + let ctx = RowSerializationContext::from_prepared(ps.get_prepared_metadata()); + values.serialize(&ctx, writer).map(Some) + } + BoundBatchStatement::Bound(b) => { + writer.append_serialize_row(&b.values); + Ok(Some(())) + } + // query has no values + BoundBatchStatement::Query(_) => Ok(Some(())), + })?; + + let new_statements_len = self + .statements_len + .checked_add(1) + .ok_or(BoundBatchStatementError::TooManyQueriesInBatchStatement)?; + + /*** at this point nothing else should be fallible as we are going to be modifying + * fields that do not get reset ***/ + + self.statements_len = new_statements_len; + + if let Some(first_prepared) = first_prepared { + self.first_prepared = Some(first_prepared); + } + + let prepared = match statement { + BoundBatchStatement::Prepared(ps, _) => Cow::Owned(ps), + BoundBatchStatement::Bound(b) => b.prepared, + BoundBatchStatement::Query(_) => return Ok(()), + }; + + if !self.prepared.contains_key(prepared.get_id()) { + self.prepared + .insert(prepared.get_id().to_owned(), prepared.into_owned()); + } + + Ok(()) + } + + #[allow(clippy::result_large_err)] fn serialize_from_batch_statement( &mut self, statement: &BatchStatement, @@ -477,3 +590,129 @@ fn counts_mismatch_err(n_value_lists: usize, n_statements: u16) -> BatchSerializ n_statements: n_statements as usize, } } + +/// This enum represents a CQL statement, that can be part of batch and its values +#[derive(Clone)] +#[non_exhaustive] +pub enum BoundBatchStatement<'p, V: SerializeRow> { + /// A prepared statement and its not-yet serialized values + Prepared(PreparedStatement, V), + /// A statement whose values have already been bound (and thus serialized) + Bound(BoundStatement<'p>), + /// An unprepared statement with no values + Query(Statement), +} + +impl<'p> From> for BoundBatchStatement<'p, ()> { + fn from(b: BoundStatement<'p>) -> Self { + BoundBatchStatement::Bound(b) + } +} + +impl From<(PreparedStatement, V)> for BoundBatchStatement<'static, V> { + fn from((p, v): (PreparedStatement, V)) -> Self { + BoundBatchStatement::Prepared(p, v) + } +} + +impl From for BoundBatchStatement<'static, ()> { + fn from(s: Statement) -> Self { + BoundBatchStatement::Query(s) + } +} + +impl From<&str> for BoundBatchStatement<'static, ()> { + fn from(s: &str) -> Self { + BoundBatchStatement::Query(Statement::from(s)) + } +} + +/// An error type returned when adding a statement to a bounded batch fails +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum BoundBatchStatementError { + /// Failed to serialize the batch statement + #[error(transparent)] + Statement(#[from] BatchStatementSerializationError), + /// Failed to serialize statement's bound values. + #[error("Failed to calculate partition key")] + PartitionKey(#[from] PartitionKeyError), + /// Too many statements in the batch statement. + #[error("Added statement goes over exceeded max value of 65,535")] + TooManyQueriesInBatchStatement, +} + +impl Execute for BoundBatch { + async fn execute(&self, session: &Session) -> Result { + // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard + // If users batch statements by shard, they will be rewarded with full shard awareness + let execution_profile = self + .get_execution_profile_handle() + .unwrap_or_else(|| session.get_default_execution_profile_handle()) + .access(); + + let consistency = self + .config + .consistency + .unwrap_or(execution_profile.consistency); + + let serial_consistency = self + .config + .serial_consistency + .unwrap_or(execution_profile.serial_consistency); + + let (table, token) = self + .first_prepared + .as_ref() + .and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token))) + .unzip(); + + let statement_info = RoutingInfo { + consistency, + serial_consistency, + token, + table, + is_confirmed_lwt: false, + }; + + let span = RequestSpan::new_batch(); + + let (run_request_result, coordinator): ( + RunRequestResult, + Coordinator, + ) = session + .run_request( + statement_info, + &self.config, + execution_profile, + |connection: Arc, + consistency: Consistency, + execution_profile: &ExecutionProfileInner| { + let serial_consistency = self + .config + .serial_consistency + .unwrap_or(execution_profile.serial_consistency); + async move { + connection + .batch_with_consistency(self, consistency, serial_consistency) + .await + .and_then(QueryResponse::into_non_error_query_response) + } + }, + &span, + ) + .instrument(span.span().clone()) + .await?; + + let result = match run_request_result { + RunRequestResult::IgnoredWriteError => QueryResult::mock_empty(coordinator), + RunRequestResult::Completed(non_error_query_response) => { + let result = non_error_query_response.into_query_result(coordinator)?; + span.record_result_fields(&result); + result + } + }; + + Ok(result) + } +} diff --git a/scylla/src/statement/bound.rs b/scylla/src/statement/bound.rs index 42a8b1e4ec..e06a7f5ce2 100644 --- a/scylla/src/statement/bound.rs +++ b/scylla/src/statement/bound.rs @@ -1,14 +1,35 @@ -use std::borrow::Cow; +use std::{borrow::Cow, sync::Arc}; -use scylla_cql::serialize::{ - row::{SerializeRow, SerializedValues}, - SerializationError, +use scylla_cql::{ + frame::{ + request::query::{PagingState, PagingStateResponse}, + response::NonErrorResponse, + }, + serialize::{ + row::{SerializeRow, SerializedValues}, + SerializationError, + }, + Consistency, }; +use tracing::Instrument; -use crate::routing::Token; +use crate::{ + client::{ + execution_profile::ExecutionProfileInner, + session::{RunRequestResult, Session}, + }, + errors::ExecutionError, + frame::response::result, + network::Connection, + observability::driver_tracing::RequestSpan, + policies::load_balancing::RoutingInfo, + response::{query_result::QueryResult, Coordinator, NonErrorQueryResponse, QueryResponse}, + routing::Token, +}; -use super::prepared::{ - PartitionKey, PartitionKeyError, PartitionKeyExtractionError, PreparedStatement, +use super::{ + execute::ExecutePageable, + prepared::{PartitionKey, PartitionKeyError, PartitionKeyExtractionError, PreparedStatement}, }; /// Represents a statement that already had all its values bound @@ -66,3 +87,111 @@ impl<'p> BoundStatement<'p> { &self.prepared } } + +impl ExecutePageable for BoundStatement<'_> { + async fn execute_pageable( + &self, + session: &Session, + paging_state: PagingState, + ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { + let page_size = if SINGLE_PAGE { + Some(self.prepared.get_validated_page_size()) + } else { + None + }; + + let paging_state_ref = &paging_state; + + let (partition_key, token) = self + .pk_and_token() + .map_err(PartitionKeyError::into_execution_error)? + .unzip(); + + let execution_profile = self + .prepared + .get_execution_profile_handle() + .unwrap_or_else(|| session.get_default_execution_profile_handle()) + .access(); + + let table_spec = self.prepared.get_table_spec(); + + let statement_info = RoutingInfo { + consistency: self + .prepared + .config + .consistency + .unwrap_or(execution_profile.consistency), + serial_consistency: self + .prepared + .config + .serial_consistency + .unwrap_or(execution_profile.serial_consistency), + token, + table: table_spec, + is_confirmed_lwt: self.prepared.is_confirmed_lwt(), + }; + + let span = RequestSpan::new_prepared( + partition_key.as_ref().map(|pk| pk.iter()), + token, + self.values.buffer_size(), + ); + + if !span.span().is_disabled() { + if let (Some(table_spec), Some(token)) = (statement_info.table, token) { + let cluster_state = session.get_cluster_state(); + let replicas = cluster_state.get_token_endpoints_iter(table_spec, token); + span.record_replicas(replicas) + } + } + + let (run_request_result, coordinator): ( + RunRequestResult, + Coordinator, + ) = session + .run_request( + statement_info, + &self.prepared.config, + execution_profile, + |connection: Arc, + consistency: Consistency, + execution_profile: &ExecutionProfileInner| { + let serial_consistency = self + .prepared + .config + .serial_consistency + .unwrap_or(execution_profile.serial_consistency); + async move { + connection + .execute_raw_with_consistency( + self, + consistency, + serial_consistency, + page_size, + paging_state_ref.clone(), + ) + .await + .and_then(QueryResponse::into_non_error_query_response) + } + }, + &span, + ) + .instrument(span.span().clone()) + .await?; + + let response = match run_request_result { + RunRequestResult::IgnoredWriteError => NonErrorQueryResponse { + response: NonErrorResponse::Result(result::Result::Void), + tracing_id: None, + warnings: Vec::new(), + }, + RunRequestResult::Completed(response) => response, + }; + + let (result, paging_state_response) = + response.into_query_result_and_paging_state(coordinator)?; + span.record_result_fields(&result); + + Ok((result, paging_state_response)) + } +} diff --git a/scylla/src/statement/execute.rs b/scylla/src/statement/execute.rs new file mode 100644 index 0000000000..99a9e600d4 --- /dev/null +++ b/scylla/src/statement/execute.rs @@ -0,0 +1,72 @@ +use scylla_cql::frame::request::query::{PagingState, PagingStateResponse}; +use tracing::error; + +use crate::{ + client::session::Session, + errors::{ExecutionError, RequestAttemptError}, + response::query_result::QueryResult, +}; + +// seals the trait to foreign implementations +mod private { + use scylla_cql::serialize::row::SerializeRow; + + use crate::statement::{batch::BoundBatch, bound::BoundStatement, Statement}; + + #[allow(unnameable_types)] + pub trait Sealed {} + + impl Sealed for BoundBatch {} + impl Sealed for BoundStatement<'_> {} + impl Sealed for (&Statement, V) {} +} + +/// A type that can be executed on a [`Session`] without any additional values +/// +/// In practice this means that the statement(s) all already had their values bound. +pub trait Execute: private::Sealed { + /// Executes on the session + fn execute( + &self, + session: &Session, + ) -> impl std::future::Future>; +} + +/// A type that can be executed on a [`Session`], optionally conitnuing froma saved point +/// +/// We believe that paging is such an important concept that we require users to make a conscious +/// decision to use paging or not. For that, we expose three different ways to execute pageable +/// requests: +/// +/// * `Execute::execute`: unpaged and from the start +/// * `ExecutePageable::execute_pageable::`: paginating from a saved point +/// * `ExecutePageable::execute_pageable::`: no pagination from a saved point +pub trait ExecutePageable { + /// Sends a request to the database, optionally continuing from a saved point. + /// + /// If SINGLE_PAGE is set to true then a single page is returned. If SINGLE_PAGE is set to + /// false, then all pages (starting at `paging_state`) are returned + fn execute_pageable( + &self, + session: &Session, + paging_state: PagingState, + ) -> impl std::future::Future>; +} + +impl Execute for T { + /// Executes the pageable type but getting all pages from the start + async fn execute(&self, session: &Session) -> Result { + let (result, paging_state) = self + .execute_pageable::(session, PagingState::start()) + .await?; + + if !paging_state.finished() { + error!("Unpaged query returned a non-empty paging state! This is a driver-side or server-side bug."); + return Err(ExecutionError::LastAttemptError( + RequestAttemptError::NonfinishedPagingState, + )); + } + + Ok(result) + } +} diff --git a/scylla/src/statement/mod.rs b/scylla/src/statement/mod.rs index a03c400f4d..9f6aa3ff44 100644 --- a/scylla/src/statement/mod.rs +++ b/scylla/src/statement/mod.rs @@ -16,6 +16,7 @@ use crate::policies::retry::RetryPolicy; pub mod batch; pub mod bound; +pub mod execute; pub mod prepared; pub mod unprepared; diff --git a/scylla/src/statement/unprepared.rs b/scylla/src/statement/unprepared.rs index b48f10fbe9..4a7fe5a775 100644 --- a/scylla/src/statement/unprepared.rs +++ b/scylla/src/statement/unprepared.rs @@ -1,9 +1,23 @@ +use scylla_cql::frame::request::query::{PagingState, PagingStateResponse}; +use scylla_cql::frame::response::NonErrorResponse; +use scylla_cql::serialize::row::SerializeRow; +use tracing::Instrument; + +use super::execute::ExecutePageable; use super::{PageSize, StatementConfig}; -use crate::client::execution_profile::ExecutionProfileHandle; +use crate::client::execution_profile::{ExecutionProfileHandle, ExecutionProfileInner}; +use crate::client::session::{RunRequestResult, Session}; +use crate::errors::ExecutionError; +use crate::frame::response::result; use crate::frame::types::{Consistency, SerialConsistency}; +use crate::network::Connection; +use crate::observability::driver_tracing::RequestSpan; use crate::observability::history::HistoryListener; use crate::policies::load_balancing::LoadBalancingPolicy; +use crate::policies::load_balancing::RoutingInfo; use crate::policies::retry::RetryPolicy; +use crate::response::query_result::QueryResult; +use crate::response::{Coordinator, NonErrorQueryResponse, QueryResponse}; use std::sync::Arc; use std::time::Duration; @@ -212,3 +226,103 @@ impl<'a> From<&'a str> for Statement { Statement::new(s.to_owned()) } } +impl ExecutePageable for (&Statement, V) { + async fn execute_pageable( + &self, + session: &Session, + paging_state: PagingState, + ) -> Result<(QueryResult, PagingStateResponse), ExecutionError> { + let (statement, values) = self; + let page_size = if SINGLE_PAGE { + Some(statement.get_validated_page_size()) + } else { + None + }; + + let execution_profile = statement + .get_execution_profile_handle() + .unwrap_or_else(|| session.get_default_execution_profile_handle()) + .access(); + + let statement_info = RoutingInfo { + consistency: statement + .config + .consistency + .unwrap_or(execution_profile.consistency), + serial_consistency: statement + .config + .serial_consistency + .unwrap_or(execution_profile.serial_consistency), + ..Default::default() + }; + + let span = RequestSpan::new_query(&statement.contents); + let span_ref = &span; + let (run_request_result, coordinator): ( + RunRequestResult, + Coordinator, + ) = session + .run_request( + statement_info, + &statement.config, + execution_profile, + |connection: Arc, + consistency: Consistency, + execution_profile: &ExecutionProfileInner| { + let serial_consistency = statement + .config + .serial_consistency + .unwrap_or(execution_profile.serial_consistency); + // Needed to avoid moving into async move block + let paging_state_ref = &paging_state; + async move { + if values.is_empty() { + span_ref.record_request_size(0); + connection + .query_raw_with_consistency( + statement, + consistency, + serial_consistency, + page_size, + paging_state_ref.clone(), + ) + .await + .and_then(QueryResponse::into_non_error_query_response) + } else { + let statement = + connection.prepare(statement).await?.into_bind(values)?; + span_ref.record_request_size(statement.values.buffer_size()); + connection + .execute_raw_with_consistency( + &statement, + consistency, + serial_consistency, + page_size, + paging_state_ref.clone(), + ) + .await + .and_then(QueryResponse::into_non_error_query_response) + } + } + }, + &span, + ) + .instrument(span.span().clone()) + .await?; + + let response = match run_request_result { + RunRequestResult::IgnoredWriteError => NonErrorQueryResponse { + response: NonErrorResponse::Result(result::Result::Void), + tracing_id: None, + warnings: Vec::new(), + }, + RunRequestResult::Completed(response) => response, + }; + + let (result, paging_state_response) = + response.into_query_result_and_paging_state(coordinator)?; + span.record_result_fields(&result); + + Ok((result, paging_state_response)) + } +} diff --git a/scylla/tests/integration/statements/batch.rs b/scylla/tests/integration/statements/batch.rs index 9495b08ec4..3e18b286b3 100644 --- a/scylla/tests/integration/statements/batch.rs +++ b/scylla/tests/integration/statements/batch.rs @@ -7,7 +7,8 @@ use scylla::client::session::Session; use scylla::errors::{BadQuery, ExecutionError, RequestAttemptError}; use scylla::frame::frame_errors::{BatchSerializationError, CqlRequestSerializationError}; use scylla::response::query_result::{QueryResult, QueryRowsResult}; -use scylla::statement::batch::{Batch, BatchStatement, BatchType}; +use scylla::statement::batch::{Batch, BatchStatement, BatchType, BoundBatch}; +use scylla::statement::execute::Execute; use scylla::statement::prepared::PreparedStatement; use scylla::statement::unprepared::Statement; use scylla::value::Counter; @@ -639,3 +640,65 @@ async fn test_batch_to_multiple_tables() { .await .unwrap(); } + +#[tokio::test] +async fn test_bound_batch() { + setup_tracing(); + let session = Arc::new(create_new_session_builder().build().await.unwrap()); + let ks = unique_keyspace_name(); + + session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {ks} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}")).await.unwrap(); + session + .ddl(format!( + "CREATE TABLE IF NOT EXISTS {ks}.t_batch (a int, b int, c text, primary key (a, b))", + )) + .await + .unwrap(); + + let prepared_statement = session + .prepare(format!( + "INSERT INTO {ks}.t_batch (a, b, c) VALUES (?, ?, ?)", + )) + .await + .unwrap(); + + let four_value: i32 = 4; + let hello_value: String = String::from("hello"); + + let bound_statement = prepared_statement + .clone() + .into_bind(&(1_i32, &four_value, hello_value.as_str())) + .unwrap(); + + let mut batch: BoundBatch = Default::default(); + batch + .append_statement((prepared_statement, (1_i32, 2_i32, "abc"))) + .unwrap(); + batch + .append_statement(&format!("INSERT INTO {ks}.t_batch (a, b, c) VALUES (7, 11, '')")[..]) + .unwrap(); + batch.append_statement(bound_statement).unwrap(); + + batch.execute(&session).await.unwrap(); + + let mut results: Vec<(i32, i32, String)> = session + .query_unpaged(format!("SELECT a, b, c FROM {ks}.t_batch"), &[]) + .await + .unwrap() + .into_rows_result() + .unwrap() + .rows::<(i32, i32, String)>() + .unwrap() + .collect::>() + .unwrap(); + + results.sort(); + assert_eq!( + results, + vec![ + (1, 2, String::from("abc")), + (1, 4, String::from("hello")), + (7, 11, String::from("")) + ] + ); +} From 2b5f6c6183ab20a1edf8c099d55f4a6554790f8f Mon Sep 17 00:00:00 2001 From: Andres Medina Date: Thu, 1 May 2025 17:52:57 -0700 Subject: [PATCH 6/6] add BoundStatement::into_owned allows users to transform an existing bound statement into one that doesn't borrow the prepared statement (by cloning it) --- scylla/src/statement/bound.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/scylla/src/statement/bound.rs b/scylla/src/statement/bound.rs index e06a7f5ce2..7831e435f4 100644 --- a/scylla/src/statement/bound.rs +++ b/scylla/src/statement/bound.rs @@ -74,6 +74,14 @@ impl<'p> BoundStatement<'p> { Ok(Some((partition_key, token))) } + /// Consumes this bound statement to return one with no borrowed data + pub fn into_owned(self) -> BoundStatement<'static> { + BoundStatement { + prepared: Cow::Owned(self.prepared.into_owned()), + values: self.values, + } + } + /// Calculates the token for the prepared statement and its bound values /// /// Returns the token that would be computed for executing the provided prepared statement with