Skip to content

Commit 4bd12ad

Browse files
committed
Add an internal only BoundBatch
1 parent 6bb3f88 commit 4bd12ad

File tree

6 files changed

+308
-257
lines changed

6 files changed

+308
-257
lines changed

scylla-cql/src/frame/request/batch.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,52 @@ where
3535
pub values: Values,
3636
}
3737

38+
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
39+
pub struct BatchV2<'b> {
40+
pub statements_and_values: Cow<'b, [u8]>,
41+
pub batch_type: BatchType,
42+
pub consistency: types::Consistency,
43+
pub serial_consistency: Option<types::SerialConsistency>,
44+
pub timestamp: Option<i64>,
45+
pub statements_len: u16,
46+
}
47+
48+
impl SerializableRequest for BatchV2<'_> {
49+
const OPCODE: RequestOpcode = RequestOpcode::Batch;
50+
51+
fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
52+
// Serializing type of batch
53+
buf.put_u8(self.batch_type as u8);
54+
55+
// Serializing queries
56+
types::write_short(self.statements_len, buf);
57+
buf.extend_from_slice(&self.statements_and_values);
58+
59+
// Serializing consistency
60+
types::write_consistency(self.consistency, buf);
61+
62+
// Serializing flags
63+
let mut flags = 0;
64+
if self.serial_consistency.is_some() {
65+
flags |= FLAG_WITH_SERIAL_CONSISTENCY;
66+
}
67+
if self.timestamp.is_some() {
68+
flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
69+
}
70+
71+
buf.put_u8(flags);
72+
73+
if let Some(serial_consistency) = self.serial_consistency {
74+
types::write_serial_consistency(serial_consistency, buf);
75+
}
76+
if let Some(timestamp) = self.timestamp {
77+
types::write_long(timestamp, buf);
78+
}
79+
80+
Ok(())
81+
}
82+
}
83+
3884
/// The type of a batch.
3985
#[derive(Clone, Copy)]
4086
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
@@ -208,7 +254,7 @@ impl BatchStatement<'_> {
208254
}
209255

210256
impl BatchStatement<'_> {
211-
fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
257+
pub fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
212258
match self {
213259
Self::Query { text } => {
214260
buf.put_u8(0);

scylla/src/client/session.rs

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::cluster::node::CloudEndpoint;
1212
use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef};
1313
use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState};
1414
use crate::errors::{
15-
BadQuery, ExecutionError, MetadataError, NewSessionError, PagerExecutionError, PrepareError,
15+
ExecutionError, MetadataError, NewSessionError, PagerExecutionError, PrepareError,
1616
RequestAttemptError, RequestError, SchemaAgreementError, TracingError, UseKeyspaceError,
1717
};
1818
use crate::frame::response::result;
@@ -35,7 +35,7 @@ use crate::response::{
3535
};
3636
use crate::routing::partitioner::PartitionerName;
3737
use crate::routing::{Shard, ShardAwarePortRange};
38-
use crate::statement::batch::batch_values;
38+
use crate::statement::batch::BoundBatch;
3939
use crate::statement::batch::{Batch, BatchStatement};
4040
use crate::statement::bound::BoundStatement;
4141
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
@@ -46,9 +46,10 @@ use futures::future::join_all;
4646
use futures::future::try_join_all;
4747
use itertools::Itertools;
4848
use scylla_cql::frame::response::NonErrorResponse;
49-
use scylla_cql::serialize::batch::BatchValues;
49+
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
5050
use scylla_cql::serialize::row::SerializeRow;
51-
use std::borrow::Borrow;
51+
use std::borrow::{Borrow, Cow};
52+
use std::collections::{HashMap, HashSet};
5253
use std::future::Future;
5354
use std::net::{IpAddr, SocketAddr};
5455
use std::num::NonZeroU32;
@@ -806,7 +807,10 @@ impl Session {
806807
batch: &Batch,
807808
values: impl BatchValues,
808809
) -> Result<QueryResult, ExecutionError> {
809-
self.do_batch(batch, values).await
810+
let batch = self.last_minute_prepare_batch(batch, &values).await?;
811+
let batch = BoundBatch::from_batch(batch.as_ref(), values)?;
812+
813+
self.do_batch(&batch).await
810814
}
811815
}
812816

@@ -1445,22 +1449,9 @@ impl Session {
14451449
.map_err(PagerExecutionError::NextPageError)
14461450
}
14471451

1448-
async fn do_batch(
1449-
&self,
1450-
batch: &Batch,
1451-
values: impl BatchValues,
1452-
) -> Result<QueryResult, ExecutionError> {
1452+
async fn do_batch(&self, batch: &BoundBatch) -> Result<QueryResult, ExecutionError> {
14531453
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
14541454
// If users batch statements by shard, they will be rewarded with full shard awareness
1455-
1456-
// check to ensure that we don't send a batch statement with more than u16::MAX queries
1457-
let batch_statements_length = batch.statements.len();
1458-
if batch_statements_length > u16::MAX as usize {
1459-
return Err(ExecutionError::BadQuery(
1460-
BadQuery::TooManyQueriesInBatchStatement(batch_statements_length),
1461-
));
1462-
}
1463-
14641455
let execution_profile = batch
14651456
.get_execution_profile_handle()
14661457
.unwrap_or_else(|| self.get_default_execution_profile_handle())
@@ -1476,22 +1467,17 @@ impl Session {
14761467
.serial_consistency
14771468
.unwrap_or(execution_profile.serial_consistency);
14781469

1479-
let (first_value_token, values) =
1480-
batch_values::peek_first_token(values, batch.statements.first())?;
1481-
let values_ref = &values;
1482-
1483-
let table_spec =
1484-
if let Some(BatchStatement::PreparedStatement(ps)) = batch.statements.first() {
1485-
ps.get_table_spec()
1486-
} else {
1487-
None
1488-
};
1470+
let (table, token) = batch
1471+
.first_prepared
1472+
.as_ref()
1473+
.and_then(|(ps, token)| ps.get_table_spec().map(|table| (table, *token)))
1474+
.unzip();
14891475

14901476
let statement_info = RoutingInfo {
14911477
consistency,
14921478
serial_consistency,
1493-
token: first_value_token,
1494-
table: table_spec,
1479+
token,
1480+
table,
14951481
is_confirmed_lwt: false,
14961482
};
14971483

@@ -1511,12 +1497,7 @@ impl Session {
15111497
.unwrap_or(execution_profile.serial_consistency);
15121498
async move {
15131499
connection
1514-
.batch_with_consistency(
1515-
batch,
1516-
values_ref,
1517-
consistency,
1518-
serial_consistency,
1519-
)
1500+
.batch_with_consistency(batch, consistency, serial_consistency)
15201501
.await
15211502
.and_then(QueryResponse::into_non_error_query_response)
15221503
}
@@ -1585,6 +1566,54 @@ impl Session {
15851566
Ok(prepared_batch)
15861567
}
15871568

1569+
async fn last_minute_prepare_batch<'b>(
1570+
&self,
1571+
init_batch: &'b Batch,
1572+
values: impl BatchValues,
1573+
) -> Result<Cow<'b, Batch>, PrepareError> {
1574+
let mut to_prepare = HashSet::<&str>::new();
1575+
1576+
{
1577+
let mut values_iter = values.batch_values_iter();
1578+
for stmt in &init_batch.statements {
1579+
if let BatchStatement::Query(query) = stmt {
1580+
if let Some(false) = values_iter.is_empty_next() {
1581+
to_prepare.insert(&query.contents);
1582+
}
1583+
} else {
1584+
values_iter.skip_next();
1585+
}
1586+
}
1587+
}
1588+
1589+
if to_prepare.is_empty() {
1590+
return Ok(Cow::Borrowed(init_batch));
1591+
}
1592+
1593+
let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
1594+
1595+
for query in to_prepare {
1596+
let prepared = self.prepare(query).await?;
1597+
prepared_queries.insert(query, prepared);
1598+
}
1599+
1600+
let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
1601+
for stmt in &init_batch.statements {
1602+
match stmt {
1603+
BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
1604+
{
1605+
Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
1606+
None => batch.to_mut().append_statement(query.clone()),
1607+
},
1608+
BatchStatement::PreparedStatement(prepared) => {
1609+
batch.to_mut().append_statement(prepared.clone());
1610+
}
1611+
}
1612+
}
1613+
1614+
Ok(batch)
1615+
}
1616+
15881617
/// Sends `USE <keyspace_name>` request on all connections\
15891618
/// This allows to write `SELECT * FROM table` instead of `SELECT * FROM keyspace.table`\
15901619
///

scylla/src/network/connection.rs

Lines changed: 10 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::response::{
2828
};
2929
use crate::routing::locator::tablets::{RawTablet, TabletParsingError};
3030
use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError};
31-
use crate::statement::batch::{Batch, BatchStatement};
31+
use crate::statement::batch::BoundBatch;
3232
use crate::statement::bound::BoundStatement;
3333
use crate::statement::prepared::PreparedStatement;
3434
use crate::statement::unprepared::Statement;
@@ -43,12 +43,10 @@ use scylla_cql::frame::response::result::{ResultMetadata, TableSpec};
4343
use scylla_cql::frame::response::Error;
4444
use scylla_cql::frame::response::{self, error};
4545
use scylla_cql::frame::types::SerialConsistency;
46-
use scylla_cql::serialize::batch::{BatchValues, BatchValuesIterator};
47-
use scylla_cql::serialize::raw_batch::RawBatchValuesAdapter;
48-
use scylla_cql::serialize::row::{RowSerializationContext, SerializedValues};
46+
use scylla_cql::serialize::row::SerializedValues;
4947
use socket2::{SockRef, TcpKeepalive};
5048
use std::borrow::Cow;
51-
use std::collections::{BTreeSet, HashMap, HashSet};
49+
use std::collections::{BTreeSet, HashMap};
5250
use std::convert::TryFrom;
5351
use std::net::{IpAddr, SocketAddr};
5452
use std::num::NonZeroU64;
@@ -1079,12 +1077,10 @@ impl Connection {
10791077
#[allow(dead_code)]
10801078
pub(crate) async fn batch(
10811079
&self,
1082-
batch: &Batch,
1083-
values: impl BatchValues,
1080+
batch: &BoundBatch,
10841081
) -> Result<QueryResult, RequestAttemptError> {
10851082
self.batch_with_consistency(
10861083
batch,
1087-
values,
10881084
batch
10891085
.config
10901086
.determine_consistency(self.config.default_consistency),
@@ -1096,22 +1092,10 @@ impl Connection {
10961092

10971093
pub(crate) async fn batch_with_consistency(
10981094
&self,
1099-
init_batch: &Batch,
1100-
values: impl BatchValues,
1095+
batch: &BoundBatch,
11011096
consistency: Consistency,
11021097
serial_consistency: Option<SerialConsistency>,
11031098
) -> Result<QueryResponse, RequestAttemptError> {
1104-
let batch = self.prepare_batch(init_batch, &values).await?;
1105-
1106-
let contexts = batch.statements.iter().map(|bs| match bs {
1107-
BatchStatement::Query(_) => RowSerializationContext::empty(),
1108-
BatchStatement::PreparedStatement(ps) => {
1109-
RowSerializationContext::from_prepared(ps.get_prepared_metadata())
1110-
}
1111-
});
1112-
1113-
let values = RawBatchValuesAdapter::new(values, contexts);
1114-
11151099
let get_timestamp_from_gen = || {
11161100
self.config
11171101
.timestamp_generator
@@ -1120,13 +1104,13 @@ impl Connection {
11201104
};
11211105
let timestamp = batch.get_timestamp().or_else(get_timestamp_from_gen);
11221106

1123-
let batch_frame = batch::Batch {
1124-
statements: Cow::Borrowed(&batch.statements),
1125-
values,
1107+
let batch_frame = batch::BatchV2 {
1108+
statements_and_values: Cow::Borrowed(&batch.buffer),
11261109
batch_type: batch.get_type(),
11271110
consistency,
11281111
serial_consistency,
11291112
timestamp,
1113+
statements_len: batch.statements_len,
11301114
};
11311115

11321116
loop {
@@ -1139,13 +1123,8 @@ impl Connection {
11391123
Response::Error(err) => match err.error {
11401124
DbError::Unprepared { statement_id } => {
11411125
debug!("Connection::batch: got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
1142-
let prepared_statement = batch.statements.iter().find_map(|s| match s {
1143-
BatchStatement::PreparedStatement(s) if *s.get_id() == statement_id => {
1144-
Some(s)
1145-
}
1146-
_ => None,
1147-
});
1148-
if let Some(p) = prepared_statement {
1126+
1127+
if let Some(p) = batch.prepared.get(&statement_id) {
11491128
self.reprepare(p.get_statement(), p).await?;
11501129
continue;
11511130
} else {
@@ -1162,54 +1141,6 @@ impl Connection {
11621141
}
11631142
}
11641143

1165-
async fn prepare_batch<'b>(
1166-
&self,
1167-
init_batch: &'b Batch,
1168-
values: impl BatchValues,
1169-
) -> Result<Cow<'b, Batch>, RequestAttemptError> {
1170-
let mut to_prepare = HashSet::<&str>::new();
1171-
1172-
{
1173-
let mut values_iter = values.batch_values_iter();
1174-
for stmt in &init_batch.statements {
1175-
if let BatchStatement::Query(query) = stmt {
1176-
if let Some(false) = values_iter.is_empty_next() {
1177-
to_prepare.insert(&query.contents);
1178-
}
1179-
} else {
1180-
values_iter.skip_next();
1181-
}
1182-
}
1183-
}
1184-
1185-
if to_prepare.is_empty() {
1186-
return Ok(Cow::Borrowed(init_batch));
1187-
}
1188-
1189-
let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
1190-
1191-
for query in &to_prepare {
1192-
let prepared = self.prepare(&Statement::new(query.to_string())).await?;
1193-
prepared_queries.insert(query, prepared);
1194-
}
1195-
1196-
let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
1197-
for stmt in &init_batch.statements {
1198-
match stmt {
1199-
BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
1200-
{
1201-
Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
1202-
None => batch.to_mut().append_statement(query.clone()),
1203-
},
1204-
BatchStatement::PreparedStatement(prepared) => {
1205-
batch.to_mut().append_statement(prepared.clone());
1206-
}
1207-
}
1208-
}
1209-
1210-
Ok(batch)
1211-
}
1212-
12131144
pub(super) async fn use_keyspace(
12141145
&self,
12151146
keyspace_name: &VerifiedKeyspaceName,

0 commit comments

Comments
 (0)