Skip to content

Commit 58c5da5

Browse files
committed
Add an internal only BoundBatch
1 parent f2385d6 commit 58c5da5

File tree

6 files changed

+311
-255
lines changed

6 files changed

+311
-255
lines changed

scylla-cql/src/frame/request/batch.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,52 @@ where
3535
pub values: Values,
3636
}
3737

38+
/// A BATCH request whose statements and values are carried as one opaque byte buffer.
///
/// The `SerializableRequest` implementation for this type writes, in order: the batch
/// type, the statement count, the raw `statements_and_values` bytes, the consistency,
/// a flags byte, and then the optional serial consistency and timestamp.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
39+
pub struct BatchV2<'b> {
40+
/// Bytes copied verbatim into the frame right after the statement count.
/// NOTE(review): presumably the already-serialized statements together with their
/// bound values (going by the name) - confirm against the code that fills this buffer.
pub statements_and_values: Cow<'b, [u8]>,
41+
/// Batch type; written (cast to `u8`) as the first byte of the request body.
pub batch_type: BatchType,
42+
/// Consistency written right after the statement bytes.
pub consistency: types::Consistency,
43+
/// Optional serial consistency. When `Some`, `FLAG_WITH_SERIAL_CONSISTENCY` is set in
/// the flags byte and the value is written after the flags.
pub serial_consistency: Option<types::SerialConsistency>,
44+
/// Optional timestamp. When `Some`, `FLAG_WITH_DEFAULT_TIMESTAMP` is set in the flags
/// byte and the value is written as a long after the (optional) serial consistency.
pub timestamp: Option<i64>,
45+
/// Statement count written as a short in front of `statements_and_values`.
/// NOTE(review): this presumably has to match the number of statements encoded in the
/// buffer; serialization does not verify it.
pub statements_len: u16,
46+
}
47+
48+
/// Serializes a `BatchV2` as the body of a request with the `Batch` opcode.
impl SerializableRequest for BatchV2<'_> {
49+
const OPCODE: RequestOpcode = RequestOpcode::Batch;
50+
51+
/// Appends the request body to `buf`.
///
/// Write order: batch type (1 byte), statement count (short), the raw
/// `statements_and_values` bytes, consistency, flags (1 byte), then the serial
/// consistency and the timestamp if they are present.
///
/// Nothing in this body produces an error; it always returns `Ok(())`.
fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
52+
// Serializing type of batch
53+
buf.put_u8(self.batch_type as u8);
54+
55+
// Serializing queries: the statement count first, then the caller-provided bytes,
// which are copied as-is without any inspection or validation.
56+
types::write_short(self.statements_len, buf);
57+
buf.extend_from_slice(&self.statements_and_values);
58+
59+
// Serializing consistency
60+
types::write_consistency(self.consistency, buf);
61+
62+
// Serializing flags: one bit per optional trailing field that is present.
63+
let mut flags = 0;
64+
if self.serial_consistency.is_some() {
65+
flags |= FLAG_WITH_SERIAL_CONSISTENCY;
66+
}
67+
if self.timestamp.is_some() {
68+
flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
69+
}
70+
71+
buf.put_u8(flags);
72+
73+
// Optional fields follow the flags byte: serial consistency first, then timestamp.
if let Some(serial_consistency) = self.serial_consistency {
74+
types::write_serial_consistency(serial_consistency, buf);
75+
}
76+
if let Some(timestamp) = self.timestamp {
77+
types::write_long(timestamp, buf);
78+
}
79+
80+
Ok(())
81+
}
82+
}
83+
3884
/// The type of a batch.
3985
#[derive(Clone, Copy)]
4086
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
@@ -208,7 +254,7 @@ impl BatchStatement<'_> {
208254
}
209255

210256
impl BatchStatement<'_> {
211-
fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
257+
pub fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
212258
match self {
213259
Self::Query { text } => {
214260
buf.put_u8(0);

scylla/src/client/session.rs

Lines changed: 72 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use crate::cluster::node::CloudEndpoint;
1212
use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef};
1313
use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState};
1414
use crate::errors::{
15-
BadQuery, BrokenConnectionError, ExecutionError, MetadataError, NewSessionError,
16-
PagerExecutionError, PrepareError, RequestAttemptError, RequestError, SchemaAgreementError,
17-
TracingError, UseKeyspaceError,
15+
BrokenConnectionError, ExecutionError, MetadataError, NewSessionError, PagerExecutionError,
16+
PrepareError, RequestAttemptError, RequestError, SchemaAgreementError, TracingError,
17+
UseKeyspaceError,
1818
};
1919
use crate::frame::response::result;
2020
use crate::network::tls::TlsProvider;
@@ -36,7 +36,7 @@ use crate::response::{
3636
};
3737
use crate::routing::partitioner::PartitionerName;
3838
use crate::routing::{Shard, ShardAwarePortRange};
39-
use crate::statement::batch::batch_values;
39+
use crate::statement::batch::BoundBatch;
4040
use crate::statement::batch::{Batch, BatchStatement};
4141
use crate::statement::bound::BoundStatement;
4242
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
@@ -47,9 +47,10 @@ use futures::future::join_all;
4747
use futures::future::try_join_all;
4848
use itertools::Itertools;
4949
use scylla_cql::frame::response::NonErrorResponse;
50-
use scylla_cql::serialize::batch::BatchValues;
50+
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
5151
use scylla_cql::serialize::row::SerializeRow;
52-
use std::borrow::Borrow;
52+
use std::borrow::{Borrow, Cow};
53+
use std::collections::{HashMap, HashSet};
5354
use std::future::Future;
5455
use std::net::{IpAddr, SocketAddr};
5556
use std::num::NonZeroU32;
@@ -809,7 +810,10 @@ impl Session {
809810
batch: &Batch,
810811
values: impl BatchValues,
811812
) -> Result<QueryResult, ExecutionError> {
812-
self.do_batch(batch, values).await
813+
let batch = self.last_minute_prepare_batch(batch, &values).await?;
814+
let batch = BoundBatch::from_batch(batch.as_ref(), values)?;
815+
816+
self.do_batch(&batch).await
813817
}
814818

815819
/// Establishes a CQL session with the database
@@ -1185,7 +1189,10 @@ impl Session {
11851189
// Making QueryPager::new_for_query work with values is too hard (if even possible)
11861190
// so instead of sending one prepare to a specific connection on each iterator query,
11871191
// we fully prepare a statement beforehand.
1188-
let bound = self.prepare_nongeneric(&statement).await?.into_bind(&values)?;
1192+
let bound = self
1193+
.prepare_nongeneric(&statement)
1194+
.await?
1195+
.into_bind(&values)?;
11891196
QueryPager::new_for_prepared_statement(PreparedPagerConfig {
11901197
bound,
11911198
execution_profile,
@@ -1509,22 +1516,9 @@ impl Session {
15091516
.map_err(PagerExecutionError::NextPageError)
15101517
}
15111518

1512-
async fn do_batch(
1513-
&self,
1514-
batch: &Batch,
1515-
values: impl BatchValues,
1516-
) -> Result<QueryResult, ExecutionError> {
1519+
async fn do_batch(&self, batch: &BoundBatch) -> Result<QueryResult, ExecutionError> {
15171520
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
15181521
// If users batch statements by shard, they will be rewarded with full shard awareness
1519-
1520-
// check to ensure that we don't send a batch statement with more than u16::MAX queries
1521-
let batch_statements_length = batch.statements.len();
1522-
if batch_statements_length > u16::MAX as usize {
1523-
return Err(ExecutionError::BadQuery(
1524-
BadQuery::TooManyQueriesInBatchStatement(batch_statements_length),
1525-
));
1526-
}
1527-
15281522
let execution_profile = batch
15291523
.get_execution_profile_handle()
15301524
.unwrap_or_else(|| self.get_default_execution_profile_handle())
@@ -1540,22 +1534,17 @@ impl Session {
15401534
.serial_consistency
15411535
.unwrap_or(execution_profile.serial_consistency);
15421536

1543-
let (first_value_token, values) =
1544-
batch_values::peek_first_token(values, batch.statements.first())?;
1545-
let values_ref = &values;
1546-
1547-
let table_spec =
1548-
if let Some(BatchStatement::PreparedStatement(ps)) = batch.statements.first() {
1549-
ps.get_table_spec()
1550-
} else {
1551-
None
1552-
};
1537+
let (table, token) = batch
1538+
.first_prepared
1539+
.as_ref()
1540+
.and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
1541+
.unzip();
15531542

15541543
let statement_info = RoutingInfo {
15551544
consistency,
15561545
serial_consistency,
1557-
token: first_value_token,
1558-
table: table_spec,
1546+
token,
1547+
table,
15591548
is_confirmed_lwt: false,
15601549
};
15611550

@@ -1578,12 +1567,7 @@ impl Session {
15781567
.unwrap_or(execution_profile.serial_consistency);
15791568
async move {
15801569
connection
1581-
.batch_with_consistency(
1582-
batch,
1583-
values_ref,
1584-
consistency,
1585-
serial_consistency,
1586-
)
1570+
.batch_with_consistency(batch, consistency, serial_consistency)
15871571
.await
15881572
.and_then(QueryResponse::into_non_error_query_response)
15891573
}
@@ -1652,6 +1636,54 @@ impl Session {
16521636
Ok(prepared_batch)
16531637
}
16541638

1639+
/// Prepares, just before execution, the unprepared statements of a batch that are
/// accompanied by values.
///
/// Walks `init_batch.statements` alongside the value iterator obtained from `values`.
/// Every `BatchStatement::Query` for which `is_empty_next()` yields `Some(false)`
/// (presumably: its value set is not empty - confirm against `BatchValuesIterator`)
/// has its query text collected for preparation.
///
/// Returns `Cow::Borrowed(init_batch)` when nothing was collected. Otherwise returns an
/// owned batch created with `Batch::new_from(init_batch)` to which the statements are
/// appended in their original order, with each collected query replaced by its
/// prepared statement.
///
/// # Errors
///
/// Returns the error of the first `self.prepare` call that fails.
async fn last_minute_prepare_batch<'b>(
1640+
&self,
1641+
init_batch: &'b Batch,
1642+
values: impl BatchValues,
1643+
) -> Result<Cow<'b, Batch>, PrepareError> {
1644+
// Query texts to prepare; a set, so a text repeated in the batch is prepared once.
let mut to_prepare = HashSet::<&str>::new();
1645+
1646+
{
1647+
let mut values_iter = values.batch_values_iter();
1648+
for stmt in &init_batch.statements {
1649+
if let BatchStatement::Query(query) = stmt {
1650+
// NOTE(review): `is_empty_next` presumably also advances the iterator,
// mirroring `skip_next` in the other branch - confirm.
if let Some(false) = values_iter.is_empty_next() {
1651+
to_prepare.insert(&query.contents);
1652+
}
1653+
} else {
1654+
// Already prepared: skip this statement's values entry.
values_iter.skip_next();
1655+
}
1656+
}
1657+
}
1658+
1659+
// Nothing to prepare: hand the caller's batch back without cloning it.
if to_prepare.is_empty() {
1660+
return Ok(Cow::Borrowed(init_batch));
1661+
}
1662+
1663+
// Maps query text to the prepared statement obtained for it.
let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
1664+
1665+
// Statements are prepared one at a time; the first failure aborts via `?`.
for query in to_prepare {
1666+
let prepared = self.prepare(query).await?;
1667+
prepared_queries.insert(query, prepared);
1668+
}
1669+
1670+
// NOTE(review): `Batch::new_from` presumably copies the batch configuration without
// its statements, since every statement is appended again below - confirm.
let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
1671+
for stmt in &init_batch.statements {
1672+
match stmt {
1673+
BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
1674+
{
1675+
Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
1676+
// Not collected above, so the query stays unprepared.
None => batch.to_mut().append_statement(query.clone()),
1677+
},
1678+
BatchStatement::PreparedStatement(prepared) => {
1679+
batch.to_mut().append_statement(prepared.clone());
1680+
}
1681+
}
1682+
}
1683+
1684+
Ok(batch)
1685+
}
1686+
16551687
/// Sends `USE <keyspace_name>` request on all connections\
16561688
/// This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\
16571689
///

scylla/src/network/connection.rs

Lines changed: 9 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use crate::response::{
2727
};
2828
use crate::routing::locator::tablets::{RawTablet, TabletParsingError};
2929
use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError};
30-
use crate::statement::batch::{Batch, BatchStatement};
30+
use crate::statement::batch::BoundBatch;
3131
use crate::statement::bound::BoundStatement;
3232
use crate::statement::prepared::PreparedStatement;
3333
use crate::statement::unprepared::Statement;
@@ -42,12 +42,10 @@ use scylla_cql::frame::response::result::{ResultMetadata, TableSpec};
4242
use scylla_cql::frame::response::Error;
4343
use scylla_cql::frame::response::{self, error};
4444
use scylla_cql::frame::types::SerialConsistency;
45-
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
46-
use scylla_cql::serialize::raw_batch::RawBatchValuesAdapter;
47-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
45+
use scylla_cql::serialize::row::SerializedValues;
4846
use socket2::{SockRef, TcpKeepalive};
4947
use std::borrow::Cow;
50-
use std::collections::{BTreeSet, HashMap, HashSet};
48+
use std::collections::{BTreeSet, HashMap};
5149
use std::convert::TryFrom;
5250
use std::net::{IpAddr, SocketAddr};
5351
use std::num::NonZeroU64;
@@ -1022,22 +1020,10 @@ impl Connection {
10221020

10231021
pub(crate) async fn batch_with_consistency(
10241022
&self,
1025-
init_batch: &Batch,
1026-
values: impl BatchValues,
1023+
batch: &BoundBatch,
10271024
consistency: Consistency,
10281025
serial_consistency: Option<SerialConsistency>,
10291026
) -> Result<QueryResponse, RequestAttemptError> {
1030-
let batch = self.prepare_batch(init_batch, &values).await?;
1031-
1032-
let contexts = batch.statements.iter().map(|bs| match bs {
1033-
BatchStatement::Query(_) => RowSerializationContext::empty(),
1034-
BatchStatement::PreparedStatement(ps) => {
1035-
RowSerializationContext::from_prepared(ps.get_prepared_metadata())
1036-
}
1037-
});
1038-
1039-
let values = RawBatchValuesAdapter::new(values, contexts);
1040-
10411027
let get_timestamp_from_gen = || {
10421028
self.config
10431029
.timestamp_generator
@@ -1046,13 +1032,13 @@ impl Connection {
10461032
};
10471033
let timestamp = batch.get_timestamp().or_else(get_timestamp_from_gen);
10481034

1049-
let batch_frame = batch::Batch {
1050-
statements: Cow::Borrowed(&batch.statements),
1051-
values,
1035+
let batch_frame = batch::BatchV2 {
1036+
statements_and_values: Cow::Borrowed(&batch.buffer),
10521037
batch_type: batch.get_type(),
10531038
consistency,
10541039
serial_consistency,
10551040
timestamp,
1041+
statements_len: batch.statements_len,
10561042
};
10571043

10581044
loop {
@@ -1065,13 +1051,8 @@ impl Connection {
10651051
Response::Error(err) => match err.error {
10661052
DbError::Unprepared { statement_id } => {
10671053
debug!("Connection::batch: got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
1068-
let prepared_statement = batch.statements.iter().find_map(|s| match s {
1069-
BatchStatement::PreparedStatement(s) if *s.get_id() == statement_id => {
1070-
Some(s)
1071-
}
1072-
_ => None,
1073-
});
1074-
if let Some(p) = prepared_statement {
1054+
1055+
if let Some(p) = batch.prepared.get(&statement_id) {
10751056
self.reprepare(p.get_statement(), p).await?;
10761057
continue;
10771058
} else {
@@ -1088,54 +1069,6 @@ impl Connection {
10881069
}
10891070
}
10901071

1091-
async fn prepare_batch<'b>(
1092-
&self,
1093-
init_batch: &'b Batch,
1094-
values: impl BatchValues,
1095-
) -> Result<Cow<'b, Batch>, RequestAttemptError> {
1096-
let mut to_prepare = HashSet::<&str>::new();
1097-
1098-
{
1099-
let mut values_iter = values.batch_values_iter();
1100-
for stmt in &init_batch.statements {
1101-
if let BatchStatement::Query(query) = stmt {
1102-
if let Some(false) = values_iter.is_empty_next() {
1103-
to_prepare.insert(&query.contents);
1104-
}
1105-
} else {
1106-
values_iter.skip_next();
1107-
}
1108-
}
1109-
}
1110-
1111-
if to_prepare.is_empty() {
1112-
return Ok(Cow::Borrowed(init_batch));
1113-
}
1114-
1115-
let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
1116-
1117-
for query in &to_prepare {
1118-
let prepared = self.prepare(&Statement::new(query.to_string())).await?;
1119-
prepared_queries.insert(query, prepared);
1120-
}
1121-
1122-
let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
1123-
for stmt in &init_batch.statements {
1124-
match stmt {
1125-
BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
1126-
{
1127-
Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
1128-
None => batch.to_mut().append_statement(query.clone()),
1129-
},
1130-
BatchStatement::PreparedStatement(prepared) => {
1131-
batch.to_mut().append_statement(prepared.clone());
1132-
}
1133-
}
1134-
}
1135-
1136-
Ok(batch)
1137-
}
1138-
11391072
pub(super) async fn use_keyspace(
11401073
&self,
11411074
keyspace_name: &VerifiedKeyspaceName,

0 commit comments

Comments
 (0)