Skip to content

Commit 1aaadba

Browse files
committed
switch scylla-cql's request::Batch to use new BatchV2 version
this is the version that the top crate (scylla) will use to send batches
1 parent 4bd12ad commit 1aaadba

File tree

4 files changed

+163
-30
lines changed

4 files changed

+163
-30
lines changed

scylla-cql/src/frame/request/batch.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,65 @@ impl SerializableRequest for BatchV2<'_> {
8181
}
8282
}
8383

84+
impl DeserializableRequest for BatchV2<'static> {
85+
fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
86+
let batch_type = buf.get_u8().try_into()?;
87+
let statements_len = types::read_short(buf)?;
88+
89+
let statements_and_values = (0..statements_len).try_fold(
90+
// technically allocating 3-13 bytes too many but that's OK
91+
Vec::with_capacity(buf.len()),
92+
|mut statements_and_values, _| {
93+
BatchStatement::deserialize_to_buffer(buf, &mut statements_and_values)?;
94+
// As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used.
95+
let values = SerializedValues::new_from_frame(buf)?;
96+
statements_and_values.extend_from_slice(&values.element_count().to_be_bytes());
97+
statements_and_values.extend_from_slice(values.get_contents());
98+
99+
Result::<_, RequestDeserializationError>::Ok(statements_and_values)
100+
},
101+
)?;
102+
103+
let consistency = types::read_consistency(buf)?;
104+
105+
let flags = buf.get_u8();
106+
let unknown_flags = flags & (!ALL_FLAGS);
107+
if unknown_flags != 0 {
108+
return Err(RequestDeserializationError::UnknownFlags {
109+
flags: unknown_flags,
110+
});
111+
}
112+
let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0;
113+
let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0;
114+
115+
let serial_consistency = serial_consistency_flag
116+
.then(|| types::read_consistency(buf))
117+
.transpose()?
118+
.map(
119+
|consistency| match SerialConsistency::try_from(consistency) {
120+
Ok(serial_consistency) => Ok(serial_consistency),
121+
Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency(
122+
consistency,
123+
)),
124+
},
125+
)
126+
.transpose()?;
127+
128+
let timestamp = default_timestamp_flag
129+
.then(|| types::read_long(buf))
130+
.transpose()?;
131+
132+
Ok(Self {
133+
batch_type,
134+
consistency,
135+
serial_consistency,
136+
timestamp,
137+
statements_len,
138+
statements_and_values: Cow::Owned(statements_and_values),
139+
})
140+
}
141+
}
142+
84143
/// The type of a batch.
85144
#[derive(Clone, Copy)]
86145
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
@@ -251,6 +310,29 @@ impl BatchStatement<'_> {
251310
)),
252311
}
253312
}
313+
314+
fn deserialize_to_buffer(
315+
input: &mut &[u8],
316+
out: &mut Vec<u8>,
317+
) -> Result<(), RequestDeserializationError> {
318+
match input.get_u8() {
319+
0 => {
320+
out.put_u8(0);
321+
types::read_long_string_to_buff(input, out)?;
322+
323+
Ok(())
324+
}
325+
1 => {
326+
out.put_u8(1);
327+
types::read_short_bytes_to_buffer(input, out)?;
328+
329+
Ok(())
330+
}
331+
kind => Err(RequestDeserializationError::UnexpectedBatchStatementKind(
332+
kind,
333+
)),
334+
}
335+
}
254336
}
255337

256338
impl BatchStatement<'_> {

scylla-cql/src/frame/request/mod.rs

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use bytes::Bytes;
1616

1717
pub use auth_response::AuthResponse;
1818
pub use batch::Batch;
19+
pub use batch::BatchV2;
1920
pub use execute::Execute;
2021
pub use options::Options;
2122
pub use prepare::Prepare;
@@ -138,6 +139,7 @@ pub enum Request<'r> {
138139
Query(Query<'r>),
139140
Execute(Execute<'r>),
140141
Batch(Batch<'r, BatchStatement<'r>, Vec<SerializedValues>>),
142+
BatchV2(BatchV2<'r>),
141143
}
142144

143145
impl Request<'_> {
@@ -148,7 +150,7 @@ impl Request<'_> {
148150
match opcode {
149151
RequestOpcode::Query => Query::deserialize(buf).map(Self::Query),
150152
RequestOpcode::Execute => Execute::deserialize(buf).map(Self::Execute),
151-
RequestOpcode::Batch => Batch::deserialize(buf).map(Self::Batch),
153+
RequestOpcode::Batch => BatchV2::deserialize(buf).map(Self::BatchV2),
152154
_ => unimplemented!(
153155
"Deserialization of opcode {:?} is not yet supported",
154156
opcode
@@ -162,6 +164,7 @@ impl Request<'_> {
162164
Request::Query(q) => Some(q.parameters.consistency),
163165
Request::Execute(e) => Some(e.parameters.consistency),
164166
Request::Batch(b) => Some(b.consistency),
167+
Request::BatchV2(b) => Some(b.consistency),
165168
#[allow(unreachable_patterns)] // until other opcodes are supported
166169
_ => None,
167170
}
@@ -173,6 +176,7 @@ impl Request<'_> {
173176
Request::Query(q) => Some(q.parameters.serial_consistency),
174177
Request::Execute(e) => Some(e.parameters.serial_consistency),
175178
Request::Batch(b) => Some(b.serial_consistency),
179+
Request::BatchV2(b) => Some(b.serial_consistency),
176180
#[allow(unreachable_patterns)] // until other opcodes are supported
177181
_ => None,
178182
}
@@ -181,15 +185,15 @@ impl Request<'_> {
181185

182186
#[cfg(test)]
183187
mod tests {
184-
use std::{borrow::Cow, ops::Deref};
188+
use std::borrow::Cow;
185189

186190
use bytes::Bytes;
187191

188192
use crate::serialize::row::SerializedValues;
189193
use crate::{
190194
frame::{
191195
request::{
192-
batch::{Batch, BatchStatement, BatchType},
196+
batch::{BatchStatement, BatchType, BatchV2},
193197
execute::Execute,
194198
query::{Query, QueryParameters},
195199
DeserializableRequest, SerializableRequest,
@@ -261,32 +265,39 @@ mod tests {
261265
}
262266

263267
// Batch
264-
let statements = vec![
265-
BatchStatement::Query {
266-
text: query.contents,
267-
},
268-
BatchStatement::Prepared {
269-
id: Cow::Borrowed(&execute.id),
270-
},
271-
];
272-
let batch = Batch {
273-
statements: Cow::Owned(statements),
268+
// Not execute's values, because named values are not supported in batches.
269+
let mut statements_and_values = vec![];
270+
BatchStatement::Query {
271+
text: query.contents,
272+
}
273+
.serialize(&mut statements_and_values)
274+
.unwrap();
275+
statements_and_values
276+
.extend_from_slice(&query.parameters.values.element_count().to_be_bytes());
277+
statements_and_values.extend_from_slice(query.parameters.values.get_contents());
278+
279+
BatchStatement::Prepared {
280+
id: Cow::Borrowed(&execute.id),
281+
}
282+
.serialize(&mut statements_and_values)
283+
.unwrap();
284+
statements_and_values
285+
.extend_from_slice(&query.parameters.values.element_count().to_be_bytes());
286+
statements_and_values.extend_from_slice(query.parameters.values.get_contents());
287+
288+
let batch = BatchV2 {
289+
statements_and_values: Cow::Owned(statements_and_values),
274290
batch_type: BatchType::Logged,
275291
consistency: Consistency::EachQuorum,
276292
serial_consistency: Some(SerialConsistency::LocalSerial),
277293
timestamp: Some(32432),
278-
279-
// Not execute's values, because named values are not supported in batches.
280-
values: vec![
281-
query.parameters.values.deref().clone(),
282-
query.parameters.values.deref().clone(),
283-
],
294+
statements_len: 2,
284295
};
285296
{
286297
let mut buf = Vec::new();
287298
batch.serialize(&mut buf).unwrap();
288299

289-
let batch_deserialized = Batch::deserialize(&mut &buf[..]).unwrap();
300+
let batch_deserialized = BatchV2::deserialize(&mut &buf[..]).unwrap();
290301
assert_eq!(&batch_deserialized, &batch);
291302
}
292303
}
@@ -341,24 +352,30 @@ mod tests {
341352
}
342353

343354
// Batch
344-
let statements = vec![BatchStatement::Query {
355+
let mut statements_and_values = vec![];
356+
BatchStatement::Query {
345357
text: query.contents,
346-
}];
347-
let batch = Batch {
348-
statements: Cow::Owned(statements),
358+
}
359+
.serialize(&mut statements_and_values)
360+
.unwrap();
361+
statements_and_values
362+
.extend_from_slice(&query.parameters.values.element_count().to_be_bytes());
363+
statements_and_values.extend_from_slice(query.parameters.values.get_contents());
364+
365+
let batch = BatchV2 {
349366
batch_type: BatchType::Logged,
350367
consistency: Consistency::EachQuorum,
351368
serial_consistency: None,
352369
timestamp: None,
353-
354-
values: vec![query.parameters.values.deref().clone()],
370+
statements_and_values: Cow::Owned(statements_and_values),
371+
statements_len: 1,
355372
};
356373
{
357374
let mut buf = Vec::new();
358375
batch.serialize(&mut buf).unwrap();
359376

360377
// Sanity check: batch deserializes to the equivalent.
361-
let batch_deserialized = Batch::deserialize(&mut &buf[..]).unwrap();
378+
let batch_deserialized = BatchV2::deserialize(&mut &buf[..]).unwrap();
362379
assert_eq!(batch, batch_deserialized);
363380

364381
// Now modify flags by adding an unknown one.
@@ -370,7 +387,7 @@ mod tests {
370387

371388
// Unknown flag should lead to frame rejection, as unknown flags can be new protocol extensions
372389
// leading to different semantics.
373-
let _parse_error = Batch::deserialize(&mut &buf[..]).unwrap_err();
390+
let _parse_error = BatchV2::deserialize(&mut &buf[..]).unwrap_err();
374391
}
375392
}
376393
}

scylla-cql/src/frame/types.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,17 @@ pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDese
288288
Ok(v)
289289
}
290290

291+
pub fn read_short_bytes_to_buffer(
292+
input: &mut &[u8],
293+
out: &mut impl BufMut,
294+
) -> Result<(), LowLevelDeserializationError> {
295+
let len = read_short(input)?;
296+
let v = read_raw_bytes(len.into(), input)?;
297+
write_short(len, out);
298+
out.put_slice(v);
299+
Ok(())
300+
}
301+
291302
pub fn write_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> {
292303
write_int_length(v.len(), buf)?;
293304
buf.put_slice(v);
@@ -394,6 +405,23 @@ pub fn write_long_string(v: &str, buf: &mut impl BufMut) -> Result<(), std::num:
394405
Ok(())
395406
}
396407

408+
pub(crate) fn read_long_string_to_buff(
409+
input: &mut &[u8],
410+
out: &mut impl BufMut,
411+
) -> Result<(), LowLevelDeserializationError> {
412+
let len = read_int(input)?;
413+
let raw = read_raw_bytes(len.try_into()?, input)?;
414+
415+
// verify it is a valid string but ignore the out string; we already have the raw bytes
416+
let _ = str::from_utf8(raw)?;
417+
418+
// now write it out
419+
write_int(len, out);
420+
out.put_slice(raw);
421+
422+
Ok(())
423+
}
424+
397425
#[test]
398426
fn type_long_string() {
399427
let vals = [String::from(""), String::from("hello, world!")];

scylla/src/statement/prepared.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,14 +448,20 @@ impl PreparedStatement {
448448
self.config.execution_profile_handle.as_ref()
449449
}
450450

451-
pub(crate) fn bind(
451+
/// Binds values with a reference to a prepared statement
452+
///
453+
/// This method will serialize the values and thus type erase them on return
454+
pub fn bind(
452455
&self,
453456
values: &impl SerializeRow,
454457
) -> Result<BoundStatement<'_>, SerializationError> {
455458
BoundStatement::new(Cow::Borrowed(self), values)
456459
}
457460

458-
pub(crate) fn into_bind(
461+
/// Binds values with an owned prepared statement
462+
///
463+
/// This method will serialize the values and thus type erase them on return
464+
pub fn into_bind(
459465
self,
460466
values: &impl SerializeRow,
461467
) -> Result<BoundStatement<'static>, SerializationError> {

0 commit comments

Comments
 (0)