Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 129 additions & 1 deletion scylla-cql/src/frame/request/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,111 @@ where
pub values: Values,
}

#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub struct BatchV2<'b> {
pub statements_and_values: Cow<'b, [u8]>,
pub batch_type: BatchType,
pub consistency: types::Consistency,
pub serial_consistency: Option<types::SerialConsistency>,
pub timestamp: Option<i64>,
pub statements_len: u16,
}

impl SerializableRequest for BatchV2<'_> {
const OPCODE: RequestOpcode = RequestOpcode::Batch;

fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
// Serializing type of batch
buf.put_u8(self.batch_type as u8);

// Serializing queries
types::write_short(self.statements_len, buf);
buf.extend_from_slice(&self.statements_and_values);

// Serializing consistency
types::write_consistency(self.consistency, buf);

// Serializing flags
let mut flags = 0;
if self.serial_consistency.is_some() {
flags |= FLAG_WITH_SERIAL_CONSISTENCY;
}
if self.timestamp.is_some() {
flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
}

buf.put_u8(flags);

if let Some(serial_consistency) = self.serial_consistency {
types::write_serial_consistency(serial_consistency, buf);
}
if let Some(timestamp) = self.timestamp {
types::write_long(timestamp, buf);
}

Ok(())
}
}

impl DeserializableRequest for BatchV2<'static> {
fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
let batch_type = buf.get_u8().try_into()?;
let statements_len = types::read_short(buf)?;

let statements_and_values = (0..statements_len).try_fold(
// technically allocating 3-13 bytes too many but that's OK
Vec::with_capacity(buf.len()),
|mut statements_and_values, _| {
BatchStatement::deserialize_to_buffer(buf, &mut statements_and_values)?;
// As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used.
let values = SerializedValues::new_from_frame(buf)?;
statements_and_values.extend_from_slice(&values.element_count().to_be_bytes());
statements_and_values.extend_from_slice(values.get_contents());

Result::<_, RequestDeserializationError>::Ok(statements_and_values)
},
)?;

let consistency = types::read_consistency(buf)?;

let flags = buf.get_u8();
let unknown_flags = flags & (!ALL_FLAGS);
if unknown_flags != 0 {
return Err(RequestDeserializationError::UnknownFlags {
flags: unknown_flags,
});
}
let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0;
let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0;

let serial_consistency = serial_consistency_flag
.then(|| types::read_consistency(buf))
.transpose()?
.map(
|consistency| match SerialConsistency::try_from(consistency) {
Ok(serial_consistency) => Ok(serial_consistency),
Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency(
consistency,
)),
},
)
.transpose()?;

let timestamp = default_timestamp_flag
.then(|| types::read_long(buf))
.transpose()?;

Ok(Self {
batch_type,
consistency,
serial_consistency,
timestamp,
statements_len,
statements_and_values: Cow::Owned(statements_and_values),
})
}
}

/// The type of a batch.
#[derive(Clone, Copy)]
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
Expand Down Expand Up @@ -205,10 +310,33 @@ impl BatchStatement<'_> {
)),
}
}

fn deserialize_to_buffer(
input: &mut &[u8],
out: &mut Vec<u8>,
) -> Result<(), RequestDeserializationError> {
match input.get_u8() {
0 => {
out.put_u8(0);
types::read_long_string_to_buff(input, out)?;

Ok(())
}
1 => {
out.put_u8(1);
types::read_short_bytes_to_buffer(input, out)?;

Ok(())
}
kind => Err(RequestDeserializationError::UnexpectedBatchStatementKind(
kind,
)),
}
}
}

impl BatchStatement<'_> {
fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
pub fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> {
match self {
Self::Query { text } => {
buf.put_u8(0);
Expand Down
73 changes: 45 additions & 28 deletions scylla-cql/src/frame/request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use bytes::Bytes;

pub use auth_response::AuthResponse;
pub use batch::Batch;
pub use batch::BatchV2;
pub use execute::Execute;
pub use options::Options;
pub use prepare::Prepare;
Expand Down Expand Up @@ -138,6 +139,7 @@ pub enum Request<'r> {
Query(Query<'r>),
Execute(Execute<'r>),
Batch(Batch<'r, BatchStatement<'r>, Vec<SerializedValues>>),
BatchV2(BatchV2<'r>),
}

impl Request<'_> {
Expand All @@ -148,7 +150,7 @@ impl Request<'_> {
match opcode {
RequestOpcode::Query => Query::deserialize(buf).map(Self::Query),
RequestOpcode::Execute => Execute::deserialize(buf).map(Self::Execute),
RequestOpcode::Batch => Batch::deserialize(buf).map(Self::Batch),
RequestOpcode::Batch => BatchV2::deserialize(buf).map(Self::BatchV2),
_ => unimplemented!(
"Deserialization of opcode {:?} is not yet supported",
opcode
Expand All @@ -162,6 +164,7 @@ impl Request<'_> {
Request::Query(q) => Some(q.parameters.consistency),
Request::Execute(e) => Some(e.parameters.consistency),
Request::Batch(b) => Some(b.consistency),
Request::BatchV2(b) => Some(b.consistency),
#[expect(unreachable_patterns)] // until other opcodes are supported
_ => None,
}
Expand All @@ -173,6 +176,7 @@ impl Request<'_> {
Request::Query(q) => Some(q.parameters.serial_consistency),
Request::Execute(e) => Some(e.parameters.serial_consistency),
Request::Batch(b) => Some(b.serial_consistency),
Request::BatchV2(b) => Some(b.serial_consistency),
#[expect(unreachable_patterns)] // until other opcodes are supported
_ => None,
}
Expand All @@ -181,15 +185,15 @@ impl Request<'_> {

#[cfg(test)]
mod tests {
use std::{borrow::Cow, ops::Deref};
use std::borrow::Cow;

use bytes::Bytes;

use crate::serialize::row::SerializedValues;
use crate::{
frame::{
request::{
batch::{Batch, BatchStatement, BatchType},
batch::{BatchStatement, BatchType, BatchV2},
execute::Execute,
query::{Query, QueryParameters},
DeserializableRequest, SerializableRequest,
Expand Down Expand Up @@ -261,32 +265,39 @@ mod tests {
}

// Batch
let statements = vec![
BatchStatement::Query {
text: query.contents,
},
BatchStatement::Prepared {
id: Cow::Borrowed(&execute.id),
},
];
let batch = Batch {
statements: Cow::Owned(statements),
// Not execute's values, because named values are not supported in batches.
let mut statements_and_values = vec![];
BatchStatement::Query {
text: query.contents,
}
.serialize(&mut statements_and_values)
.unwrap();
statements_and_values
.extend_from_slice(&query.parameters.values.element_count().to_be_bytes());
statements_and_values.extend_from_slice(query.parameters.values.get_contents());

BatchStatement::Prepared {
id: Cow::Borrowed(&execute.id),
}
.serialize(&mut statements_and_values)
.unwrap();
statements_and_values
.extend_from_slice(&query.parameters.values.element_count().to_be_bytes());
statements_and_values.extend_from_slice(query.parameters.values.get_contents());

let batch = BatchV2 {
statements_and_values: Cow::Owned(statements_and_values),
batch_type: BatchType::Logged,
consistency: Consistency::EachQuorum,
serial_consistency: Some(SerialConsistency::LocalSerial),
timestamp: Some(32432),

// Not execute's values, because named values are not supported in batches.
values: vec![
query.parameters.values.deref().clone(),
query.parameters.values.deref().clone(),
],
statements_len: 2,
};
{
let mut buf = Vec::new();
batch.serialize(&mut buf).unwrap();

let batch_deserialized = Batch::deserialize(&mut &buf[..]).unwrap();
let batch_deserialized = BatchV2::deserialize(&mut &buf[..]).unwrap();
assert_eq!(&batch_deserialized, &batch);
}
}
Expand Down Expand Up @@ -341,24 +352,30 @@ mod tests {
}

// Batch
let statements = vec![BatchStatement::Query {
let mut statements_and_values = vec![];
BatchStatement::Query {
text: query.contents,
}];
let batch = Batch {
statements: Cow::Owned(statements),
}
.serialize(&mut statements_and_values)
.unwrap();
statements_and_values
.extend_from_slice(&query.parameters.values.element_count().to_be_bytes());
statements_and_values.extend_from_slice(query.parameters.values.get_contents());

let batch = BatchV2 {
batch_type: BatchType::Logged,
consistency: Consistency::EachQuorum,
serial_consistency: None,
timestamp: None,

values: vec![query.parameters.values.deref().clone()],
statements_and_values: Cow::Owned(statements_and_values),
statements_len: 1,
};
{
let mut buf = Vec::new();
batch.serialize(&mut buf).unwrap();

// Sanity check: batch deserializes to the equivalent.
let batch_deserialized = Batch::deserialize(&mut &buf[..]).unwrap();
let batch_deserialized = BatchV2::deserialize(&mut &buf[..]).unwrap();
assert_eq!(batch, batch_deserialized);

// Now modify flags by adding an unknown one.
Expand All @@ -370,7 +387,7 @@ mod tests {

// Unknown flag should lead to frame rejection, as unknown flags can be new protocol extensions
// leading to different semantics.
let _parse_error = Batch::deserialize(&mut &buf[..]).unwrap_err();
let _parse_error = BatchV2::deserialize(&mut &buf[..]).unwrap_err();
}
}
}
28 changes: 28 additions & 0 deletions scylla-cql/src/frame/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,17 @@ pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDese
Ok(v)
}

pub fn read_short_bytes_to_buffer(
input: &mut &[u8],
out: &mut impl BufMut,
) -> Result<(), LowLevelDeserializationError> {
let len = read_short(input)?;
let v = read_raw_bytes(len.into(), input)?;
write_short(len, out);
out.put_slice(v);
Ok(())
}

pub fn write_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> {
write_int_length(v.len(), buf)?;
buf.put_slice(v);
Expand Down Expand Up @@ -394,6 +405,23 @@ pub fn write_long_string(v: &str, buf: &mut impl BufMut) -> Result<(), std::num:
Ok(())
}

pub(crate) fn read_long_string_to_buff(
input: &mut &[u8],
out: &mut impl BufMut,
) -> Result<(), LowLevelDeserializationError> {
let len = read_int(input)?;
let raw = read_raw_bytes(len.try_into()?, input)?;

// verify it is a valid string but ignore the out string; we already have the raw bytes
let _ = str::from_utf8(raw)?;

// now write it out
write_int(len, out);
out.put_slice(raw);

Ok(())
}

#[test]
fn type_long_string() {
let vals = [String::from(""), String::from("hello, world!")];
Expand Down
Loading
Loading