Skip to content

Commit 47ab822

Browse files
committed
Add an internal only BoundBatch
1 parent 2061d9b commit 47ab822

File tree

6 files changed

+311
-255
lines changed

6 files changed

+311
-255
lines changed

scylla-cql/src/frame/request/batch.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,52 @@ where
3535
pub values: Values,
3636
}
3737

38+
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
39+
pub struct BatchV2<'b> {
40+
pub statements_and_values: Cow<'b, [u8]>,
41+
pub batch_type: BatchType,
42+
pub consistency: types::Consistency,
43+
pub serial_consistency: Option<types::SerialConsistency>,
44+
pub timestamp: Option<i64>,
45+
pub statements_len: u16,
46+
}
47+
48+
impl SerializableRequest for BatchV2<'_> {
49+
const OPCODE: RequestOpcode = RequestOpcode::Batch;
50+
51+
fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
52+
// Serializing type of batch
53+
buf.put_u8(self.batch_type as u8);
54+
55+
// Serializing queries
56+
types::write_short(self.statements_len, buf);
57+
buf.extend_from_slice(&self.statements_and_values);
58+
59+
// Serializing consistency
60+
types::write_consistency(self.consistency, buf);
61+
62+
// Serializing flags
63+
let mut flags = 0;
64+
if self.serial_consistency.is_some() {
65+
flags |= FLAG_WITH_SERIAL_CONSISTENCY;
66+
}
67+
if self.timestamp.is_some() {
68+
flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
69+
}
70+
71+
buf.put_u8(flags);
72+
73+
if let Some(serial_consistency) = self.serial_consistency {
74+
types::write_serial_consistency(serial_consistency, buf);
75+
}
76+
if let Some(timestamp) = self.timestamp {
77+
types::write_long(timestamp, buf);
78+
}
79+
80+
Ok(())
81+
}
82+
}
83+
3884
/// The type of a batch.
3985
#[derive(Clone, Copy)]
4086
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
@@ -208,7 +254,7 @@ impl BatchStatement<'_> {
208254
}
209255

210256
impl BatchStatement<'_> {
211-
fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
257+
pub fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
212258
match self {
213259
Self::Query { text } => {
214260
buf.put_u8(0);

scylla/src/client/session.rs

Lines changed: 72 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use crate::cluster::node::CloudEndpoint;
1212
use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef};
1313
use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState};
1414
use crate::errors::{
15-
BadQuery, BrokenConnectionError, ExecutionError, MetadataError, NewSessionError,
16-
PagerExecutionError, PrepareError, RequestAttemptError, RequestError, SchemaAgreementError,
17-
TracingError, UseKeyspaceError,
15+
BrokenConnectionError, ExecutionError, MetadataError, NewSessionError, PagerExecutionError,
16+
PrepareError, RequestAttemptError, RequestError, SchemaAgreementError, TracingError,
17+
UseKeyspaceError,
1818
};
1919
use crate::frame::response::result;
2020
use crate::network::tls::TlsProvider;
@@ -36,7 +36,7 @@ use crate::response::{
3636
};
3737
use crate::routing::partitioner::PartitionerName;
3838
use crate::routing::{Shard, ShardAwarePortRange};
39-
use crate::statement::batch::batch_values;
39+
use crate::statement::batch::BoundBatch;
4040
use crate::statement::batch::{Batch, BatchStatement};
4141
use crate::statement::bound::BoundStatement;
4242
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
@@ -47,9 +47,10 @@ use futures::future::join_all;
4747
use futures::future::try_join_all;
4848
use itertools::Itertools;
4949
use scylla_cql::frame::response::NonErrorResponse;
50-
use scylla_cql::serialize::batch::BatchValues;
50+
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
5151
use scylla_cql::serialize::row::SerializeRow;
52-
use std::borrow::Borrow;
52+
use std::borrow::{Borrow, Cow};
53+
use std::collections::{HashMap, HashSet};
5354
use std::future::Future;
5455
use std::net::{IpAddr, SocketAddr};
5556
use std::num::NonZeroU32;
@@ -809,7 +810,10 @@ impl Session {
809810
batch: &Batch,
810811
values: impl BatchValues,
811812
) -> Result<QueryResult, ExecutionError> {
812-
self.do_batch(batch, values).await
813+
let batch = self.last_minute_prepare_batch(batch, &values).await?;
814+
let batch = BoundBatch::from_batch(batch.as_ref(), values)?;
815+
816+
self.do_batch(&batch).await
813817
}
814818

815819
/// Estabilishes a CQL session with the database
@@ -1189,7 +1193,10 @@ impl Session {
11891193
// Making QueryPager::new_for_query work with values is too hard (if even possible)
11901194
// so instead of sending one prepare to a specific connection on each iterator query,
11911195
// we fully prepare a statement beforehand.
1192-
let bound = self.prepare_nongeneric(&statement).await?.into_bind(&values)?;
1196+
let bound = self
1197+
.prepare_nongeneric(&statement)
1198+
.await?
1199+
.into_bind(&values)?;
11931200
QueryPager::new_for_prepared_statement(PreparedPagerConfig {
11941201
bound,
11951202
execution_profile,
@@ -1517,22 +1524,9 @@ impl Session {
15171524
.map_err(PagerExecutionError::NextPageError)
15181525
}
15191526

1520-
async fn do_batch(
1521-
&self,
1522-
batch: &Batch,
1523-
values: impl BatchValues,
1524-
) -> Result<QueryResult, ExecutionError> {
1527+
async fn do_batch(&self, batch: &BoundBatch) -> Result<QueryResult, ExecutionError> {
15251528
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
15261529
// If users batch statements by shard, they will be rewarded with full shard awareness
1527-
1528-
// check to ensure that we don't send a batch statement with more than u16::MAX queries
1529-
let batch_statements_length = batch.statements.len();
1530-
if batch_statements_length > u16::MAX as usize {
1531-
return Err(ExecutionError::BadQuery(
1532-
BadQuery::TooManyQueriesInBatchStatement(batch_statements_length),
1533-
));
1534-
}
1535-
15361530
let execution_profile = batch
15371531
.get_execution_profile_handle()
15381532
.unwrap_or_else(|| self.get_default_execution_profile_handle())
@@ -1548,22 +1542,17 @@ impl Session {
15481542
.serial_consistency
15491543
.unwrap_or(execution_profile.serial_consistency);
15501544

1551-
let (first_value_token, values) =
1552-
batch_values::peek_first_token(values, batch.statements.first())?;
1553-
let values_ref = &values;
1554-
1555-
let table_spec =
1556-
if let Some(BatchStatement::PreparedStatement(ps)) = batch.statements.first() {
1557-
ps.get_table_spec()
1558-
} else {
1559-
None
1560-
};
1545+
let (table, token) = batch
1546+
.first_prepared
1547+
.as_ref()
1548+
.and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
1549+
.unzip();
15611550

15621551
let statement_info = RoutingInfo {
15631552
consistency,
15641553
serial_consistency,
1565-
token: first_value_token,
1566-
table: table_spec,
1554+
token,
1555+
table,
15671556
is_confirmed_lwt: false,
15681557
};
15691558

@@ -1586,12 +1575,7 @@ impl Session {
15861575
.unwrap_or(execution_profile.serial_consistency);
15871576
async move {
15881577
connection
1589-
.batch_with_consistency(
1590-
batch,
1591-
values_ref,
1592-
consistency,
1593-
serial_consistency,
1594-
)
1578+
.batch_with_consistency(batch, consistency, serial_consistency)
15951579
.await
15961580
.and_then(QueryResponse::into_non_error_query_response)
15971581
}
@@ -1660,6 +1644,54 @@ impl Session {
16601644
Ok(prepared_batch)
16611645
}
16621646

1647+
async fn last_minute_prepare_batch<'b>(
1648+
&self,
1649+
init_batch: &'b Batch,
1650+
values: impl BatchValues,
1651+
) -> Result<Cow<'b, Batch>, PrepareError> {
1652+
let mut to_prepare = HashSet::<&str>::new();
1653+
1654+
{
1655+
let mut values_iter = values.batch_values_iter();
1656+
for stmt in &init_batch.statements {
1657+
if let BatchStatement::Query(query) = stmt {
1658+
if let Some(false) = values_iter.is_empty_next() {
1659+
to_prepare.insert(&query.contents);
1660+
}
1661+
} else {
1662+
values_iter.skip_next();
1663+
}
1664+
}
1665+
}
1666+
1667+
if to_prepare.is_empty() {
1668+
return Ok(Cow::Borrowed(init_batch));
1669+
}
1670+
1671+
let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
1672+
1673+
for query in to_prepare {
1674+
let prepared = self.prepare(query).await?;
1675+
prepared_queries.insert(query, prepared);
1676+
}
1677+
1678+
let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
1679+
for stmt in &init_batch.statements {
1680+
match stmt {
1681+
BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
1682+
{
1683+
Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
1684+
None => batch.to_mut().append_statement(query.clone()),
1685+
},
1686+
BatchStatement::PreparedStatement(prepared) => {
1687+
batch.to_mut().append_statement(prepared.clone());
1688+
}
1689+
}
1690+
}
1691+
1692+
Ok(batch)
1693+
}
1694+
16631695
/// Sends `USE <keyspace_name>` request on all connections\
16641696
/// This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\
16651697
///

scylla/src/network/connection.rs

Lines changed: 9 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use crate::response::{
2727
};
2828
use crate::routing::locator::tablets::{RawTablet, TabletParsingError};
2929
use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError};
30-
use crate::statement::batch::{Batch, BatchStatement};
30+
use crate::statement::batch::BoundBatch;
3131
use crate::statement::bound::BoundStatement;
3232
use crate::statement::prepared::PreparedStatement;
3333
use crate::statement::unprepared::Statement;
@@ -42,12 +42,10 @@ use scylla_cql::frame::response::result::{ResultMetadata, TableSpec};
4242
use scylla_cql::frame::response::Error;
4343
use scylla_cql::frame::response::{self, error};
4444
use scylla_cql::frame::types::SerialConsistency;
45-
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
46-
use scylla_cql::serialize::raw_batch::RawBatchValuesAdapter;
47-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
45+
use scylla_cql::serialize::row::SerializedValues;
4846
use socket2::{SockRef, TcpKeepalive};
4947
use std::borrow::Cow;
50-
use std::collections::{BTreeSet, HashMap, HashSet};
48+
use std::collections::{BTreeSet, HashMap};
5149
use std::convert::TryFrom;
5250
use std::net::{IpAddr, SocketAddr};
5351
use std::num::NonZeroU64;
@@ -1022,22 +1020,10 @@ impl Connection {
10221020

10231021
pub(crate) async fn batch_with_consistency(
10241022
&self,
1025-
init_batch: &Batch,
1026-
values: impl BatchValues,
1023+
batch: &BoundBatch,
10271024
consistency: Consistency,
10281025
serial_consistency: Option<SerialConsistency>,
10291026
) -> Result<QueryResponse, RequestAttemptError> {
1030-
let batch = self.prepare_batch(init_batch, &values).await?;
1031-
1032-
let contexts = batch.statements.iter().map(|bs| match bs {
1033-
BatchStatement::Query(_) => RowSerializationContext::empty(),
1034-
BatchStatement::PreparedStatement(ps) => {
1035-
RowSerializationContext::from_prepared(ps.get_prepared_metadata())
1036-
}
1037-
});
1038-
1039-
let values = RawBatchValuesAdapter::new(values, contexts);
1040-
10411027
let get_timestamp_from_gen = || {
10421028
self.config
10431029
.timestamp_generator
@@ -1046,13 +1032,13 @@ impl Connection {
10461032
};
10471033
let timestamp = batch.get_timestamp().or_else(get_timestamp_from_gen);
10481034

1049-
let batch_frame = batch::Batch {
1050-
statements: Cow::Borrowed(&batch.statements),
1051-
values,
1035+
let batch_frame = batch::BatchV2 {
1036+
statements_and_values: Cow::Borrowed(&batch.buffer),
10521037
batch_type: batch.get_type(),
10531038
consistency,
10541039
serial_consistency,
10551040
timestamp,
1041+
statements_len: batch.statements_len,
10561042
};
10571043

10581044
loop {
@@ -1065,13 +1051,8 @@ impl Connection {
10651051
Response::Error(err) => match err.error {
10661052
DbError::Unprepared { statement_id } => {
10671053
debug!("Connection::batch: got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
1068-
let prepared_statement = batch.statements.iter().find_map(|s| match s {
1069-
BatchStatement::PreparedStatement(s) if *s.get_id() == statement_id => {
1070-
Some(s)
1071-
}
1072-
_ => None,
1073-
});
1074-
if let Some(p) = prepared_statement {
1054+
1055+
if let Some(p) = batch.prepared.get(&statement_id) {
10751056
self.reprepare(p.get_statement(), p).await?;
10761057
continue;
10771058
} else {
@@ -1088,54 +1069,6 @@ impl Connection {
10881069
}
10891070
}
10901071

1091-
async fn prepare_batch<'b>(
1092-
&self,
1093-
init_batch: &'b Batch,
1094-
values: impl BatchValues,
1095-
) -> Result<Cow<'b, Batch>, RequestAttemptError> {
1096-
let mut to_prepare = HashSet::<&str>::new();
1097-
1098-
{
1099-
let mut values_iter = values.batch_values_iter();
1100-
for stmt in &init_batch.statements {
1101-
if let BatchStatement::Query(query) = stmt {
1102-
if let Some(false) = values_iter.is_empty_next() {
1103-
to_prepare.insert(&query.contents);
1104-
}
1105-
} else {
1106-
values_iter.skip_next();
1107-
}
1108-
}
1109-
}
1110-
1111-
if to_prepare.is_empty() {
1112-
return Ok(Cow::Borrowed(init_batch));
1113-
}
1114-
1115-
let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
1116-
1117-
for query in &to_prepare {
1118-
let prepared = self.prepare(&Statement::new(query.to_string())).await?;
1119-
prepared_queries.insert(query, prepared);
1120-
}
1121-
1122-
let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
1123-
for stmt in &init_batch.statements {
1124-
match stmt {
1125-
BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
1126-
{
1127-
Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
1128-
None => batch.to_mut().append_statement(query.clone()),
1129-
},
1130-
BatchStatement::PreparedStatement(prepared) => {
1131-
batch.to_mut().append_statement(prepared.clone());
1132-
}
1133-
}
1134-
}
1135-
1136-
Ok(batch)
1137-
}
1138-
11391072
pub(super) async fn use_keyspace(
11401073
&self,
11411074
keyspace_name: &VerifiedKeyspaceName,

0 commit comments

Comments
 (0)