Skip to content

Commit 5f3b5bf

Browse files
authored
Merge pull request #1329 from wprzytula/construct-prepared-statement-only-where-needed
Construct `PreparedStatement` only where needed
2 parents 9721530 + 5dcd15b commit 5f3b5bf

File tree

4 files changed

+111
-39
lines changed

4 files changed

+111
-39
lines changed

scylla/src/client/session.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession};
3030
use crate::policies::speculative_execution;
3131
use crate::policies::timestamp_generator::TimestampGenerator;
3232
use crate::response::query_result::{MaybeFirstRowError, QueryResult, RowsError};
33-
use crate::response::{NonErrorQueryResponse, PagingState, PagingStateResponse, QueryResponse};
33+
use crate::response::{
34+
NonErrorQueryResponse, PagingState, PagingStateResponse, QueryResponse, RawPreparedStatement,
35+
};
3436
use crate::routing::partitioner::PartitionerName;
3537
use crate::routing::{Shard, ShardAwarePortRange};
3638
use crate::statement::batch::batch_values;
@@ -1064,15 +1066,14 @@ impl Session {
10641066
.serial_consistency
10651067
.unwrap_or(execution_profile.serial_consistency);
10661068
// Needed to avoid moving query and values into async move block
1067-
let statement_ref = &statement;
10681069
let values_ref = &values;
10691070
let paging_state_ref = &paging_state;
10701071
async move {
10711072
if values_ref.is_empty() {
10721073
span_ref.record_request_size(0);
10731074
connection
10741075
.query_raw_with_consistency(
1075-
statement_ref,
1076+
statement,
10761077
consistency,
10771078
serial_consistency,
10781079
page_size,
@@ -1081,7 +1082,7 @@ impl Session {
10811082
.await
10821083
.and_then(QueryResponse::into_non_error_query_response)
10831084
} else {
1084-
let prepared = connection.prepare(statement_ref).await?;
1085+
let prepared = connection.prepare(statement).await?;
10851086
let serialized = prepared.serialize_values(values_ref)?;
10861087
span_ref.record_request_size(serialized.buffer_size());
10871088
connection
@@ -1242,29 +1243,28 @@ impl Session {
12421243
let connections_iter = cluster_state.iter_working_connections()?;
12431244

12441245
// Prepare statements on all connections concurrently
1245-
let handles = connections_iter.map(|c| async move { c.prepare(statement_ref).await });
1246+
let handles = connections_iter.map(|c| async move { c.prepare_raw(statement_ref).await });
12461247
let mut results = join_all(handles).await.into_iter();
12471248

12481249
// If at least one prepare was successful, `prepare()` returns Ok.
12491250
// Find the first result that is Ok, or Err if all failed.
12501251

12511252
// Safety: there is at least one node in the cluster, and `Cluster::iter_working_connections()`
12521253
// returns either an error or an iterator with at least one connection, so there will be at least one result.
1253-
let first_ok: Result<PreparedStatement, RequestAttemptError> =
1254+
let first_ok: Result<RawPreparedStatement, RequestAttemptError> =
12541255
results.by_ref().find_or_first(Result::is_ok).unwrap();
1255-
let mut prepared: PreparedStatement =
1256-
first_ok.map_err(|first_attempt| PrepareError::AllAttemptsFailed { first_attempt })?;
1256+
let mut prepared: PreparedStatement = first_ok
1257+
.map_err(|first_attempt| PrepareError::AllAttemptsFailed { first_attempt })?
1258+
.into_prepared_statement();
12571259

12581260
// Validate prepared ids equality
12591261
for statement in results.flatten() {
1260-
if prepared.get_id() != statement.get_id() {
1262+
if prepared.get_id() != &statement.prepared_response.id {
12611263
return Err(PrepareError::PreparedStatementIdsMismatch);
12621264
}
12631265

12641266
// Collect all tracing ids from prepare() queries in the final result
1265-
prepared
1266-
.prepare_tracing_ids
1267-
.extend(statement.prepare_tracing_ids);
1267+
prepared.prepare_tracing_ids.extend(statement.tracing_id);
12681268
}
12691269

12701270
prepared.set_partitioner_name(

scylla/src/network/connection.rs

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::policies::timestamp_generator::TimestampGenerator;
2424
use crate::response::query_result::QueryResult;
2525
use crate::response::{
2626
NonErrorAuthResponse, NonErrorStartupResponse, PagingState, PagingStateResponse, QueryResponse,
27+
RawPreparedStatement,
2728
};
2829
use crate::routing::locator::tablets::{RawTablet, TabletParsingError};
2930
use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError};
@@ -632,46 +633,56 @@ impl Connection {
632633
Ok(supported)
633634
}
634635

635-
pub(crate) async fn prepare(
636+
/// Prepares the given statement and returns raw parts that can be used to construct
637+
/// a prepared statement.
638+
///
639+
/// Extracted in order to avoid needless allocations upon [PreparedStatement] construction
640+
/// in cases when the [PreparedStatement] is not used anyway.
641+
pub(crate) async fn prepare_raw<'statement>(
636642
&self,
637-
query: &Statement,
638-
) -> Result<PreparedStatement, RequestAttemptError> {
643+
statement: &'statement Statement,
644+
) -> Result<RawPreparedStatement<'statement>, RequestAttemptError> {
639645
let query_response = self
640646
.send_request(
641647
&request::Prepare {
642-
query: &query.contents,
648+
query: &statement.contents,
643649
},
644650
true,
645-
query.config.tracing,
651+
statement.config.tracing,
646652
None,
647653
)
648654
.await?;
649655

650-
let mut prepared_statement = match query_response.response {
656+
match query_response.response {
651657
Response::Error(error::Error { error, reason }) => {
652-
return Err(RequestAttemptError::DbError(error, reason))
658+
Err(RequestAttemptError::DbError(error, reason))
653659
}
654-
Response::Result(result::Result::Prepared(p)) => PreparedStatement::new(
655-
p.id,
656-
self.features
660+
Response::Result(result::Result::Prepared(p)) => {
661+
let is_lwt = self
662+
.features
657663
.protocol_features
658-
.prepared_flags_contain_lwt_mark(p.prepared_metadata.flags as u32),
659-
p.prepared_metadata,
660-
Arc::new(p.result_metadata),
661-
query.contents.clone(),
662-
query.get_validated_page_size(),
663-
query.config.clone(),
664-
),
665-
_ => {
666-
return Err(RequestAttemptError::UnexpectedResponse(
667-
query_response.response.to_response_kind(),
664+
.prepared_flags_contain_lwt_mark(p.prepared_metadata.flags as u32);
665+
Ok(RawPreparedStatement::new(
666+
statement,
667+
p,
668+
is_lwt,
669+
query_response.tracing_id,
668670
))
669671
}
670-
};
671-
672-
if let Some(tracing_id) = query_response.tracing_id {
673-
prepared_statement.prepare_tracing_ids.push(tracing_id);
672+
_ => Err(RequestAttemptError::UnexpectedResponse(
673+
query_response.response.to_response_kind(),
674+
)),
674675
}
676+
}
677+
678+
/// Prepares the given statement and returns [PreparedStatement].
679+
pub(crate) async fn prepare(
680+
&self,
681+
statement: &Statement,
682+
) -> Result<PreparedStatement, RequestAttemptError> {
683+
let raw_prepared_statement = self.prepare_raw(statement).await?;
684+
let prepared_statement = raw_prepared_statement.into_prepared_statement();
685+
675686
Ok(prepared_statement)
676687
}
677688

@@ -681,14 +692,15 @@ impl Connection {
681692
previous_prepared: &PreparedStatement,
682693
) -> Result<(), RequestAttemptError> {
683694
let reprepare_query: Statement = query.into();
684-
let reprepared = self.prepare(&reprepare_query).await?;
695+
let prepared_response = self.prepare_raw(&reprepare_query).await?.prepared_response;
696+
685697
// Reprepared statement should keep its id - it's the md5 sum
686698
// of statement contents
687-
if reprepared.get_id() != previous_prepared.get_id() {
699+
if prepared_response.id != previous_prepared.get_id() {
688700
Err(RequestAttemptError::RepreparedIdChanged {
701+
reprepared_id: prepared_response.id.into(),
689702
statement: reprepare_query.contents,
690703
expected_id: previous_prepared.get_id().clone().into(),
691-
reprepared_id: reprepared.get_id().clone().into(),
692704
})
693705
} else {
694706
Ok(())

scylla/src/response/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ mod request_response;
1212

1313
pub(crate) use request_response::{
1414
NonErrorAuthResponse, NonErrorQueryResponse, NonErrorStartupResponse, QueryResponse,
15+
RawPreparedStatement,
1516
};
1617
pub use scylla_cql::frame::request::query::{PagingState, PagingStateResponse};

scylla/src/response/request_response.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::HashMap;
2+
use std::sync::Arc;
23

34
use bytes::Bytes;
45
use scylla_cql::frame::request::query::PagingStateResponse;
@@ -9,6 +10,8 @@ use uuid::Uuid;
910
use crate::errors::RequestAttemptError;
1011
use crate::frame::response::{self, result};
1112
use crate::response::query_result::QueryResult;
13+
use crate::statement::prepared::PreparedStatement;
14+
use crate::statement::Statement;
1215

1316
pub(crate) struct QueryResponse {
1417
pub(crate) response: Response,
@@ -108,3 +111,59 @@ pub(crate) enum NonErrorAuthResponse {
108111
AuthChallenge(response::authenticate::AuthChallenge),
109112
AuthSuccess(response::authenticate::AuthSuccess),
110113
}
114+
115+
/// Parts which are needed to construct [PreparedStatement].
116+
///
117+
/// Kept separate for performance reasons, because constructing
118+
/// [PreparedStatement] involves allocations.
119+
pub(crate) struct RawPreparedStatement<'statement> {
120+
pub(crate) statement: &'statement Statement,
121+
pub(crate) prepared_response: result::Prepared,
122+
pub(crate) is_lwt: bool,
123+
pub(crate) tracing_id: Option<Uuid>,
124+
}
125+
126+
impl<'statement> RawPreparedStatement<'statement> {
127+
pub(crate) fn new(
128+
statement: &'statement Statement,
129+
prepared_response: result::Prepared,
130+
is_lwt: bool,
131+
tracing_id: Option<Uuid>,
132+
) -> Self {
133+
Self {
134+
statement,
135+
prepared_response,
136+
is_lwt,
137+
tracing_id,
138+
}
139+
}
140+
}
141+
142+
/// Constructs the fully-fledged [PreparedStatement].
143+
///
144+
/// This involves allocations.
145+
impl RawPreparedStatement<'_> {
146+
pub(crate) fn into_prepared_statement(self) -> PreparedStatement {
147+
let Self {
148+
statement,
149+
prepared_response,
150+
is_lwt,
151+
tracing_id,
152+
} = self;
153+
let mut prepared_statement = PreparedStatement::new(
154+
prepared_response.id,
155+
is_lwt,
156+
prepared_response.prepared_metadata,
157+
Arc::new(prepared_response.result_metadata),
158+
statement.contents.clone(),
159+
statement.get_validated_page_size(),
160+
statement.config.clone(),
161+
);
162+
163+
if let Some(tracing_id) = tracing_id {
164+
prepared_statement.prepare_tracing_ids.push(tracing_id);
165+
}
166+
167+
prepared_statement
168+
}
169+
}

0 commit comments

Comments
 (0)