1- use crate :: frame:: { frame_errors:: ParseError , value:: BatchValuesIterator } ;
2- use bytes:: { BufMut , Bytes } ;
3- use std:: convert:: TryInto ;
1+ use bytes:: { Buf , BufMut } ;
2+ use std:: { borrow:: Cow , convert:: TryInto } ;
43
54use crate :: frame:: {
6- request:: { Request , RequestOpcode } ,
5+ frame_errors:: ParseError ,
6+ request:: { RequestOpcode , SerializableRequest } ,
77 types,
8- value:: BatchValues ,
8+ value:: { BatchValues , BatchValuesIterator , SerializedValues } ,
99} ;
1010
11+ use super :: DeserializableRequest ;
12+
1113// Batch flags
1214const FLAG_WITH_SERIAL_CONSISTENCY : u8 = 0x10 ;
1315const FLAG_WITH_DEFAULT_TIMESTAMP : u8 = 0x20 ;
16+ const ALL_FLAGS : u8 = FLAG_WITH_SERIAL_CONSISTENCY | FLAG_WITH_DEFAULT_TIMESTAMP ;
1417
15- pub struct Batch < ' a , StatementsIter , Values >
18+ #[ cfg_attr( test, derive( Debug , PartialEq , Eq ) ) ]
19+ pub struct Batch < ' b , Statement , Values >
1620where
17- StatementsIter : Iterator < Item = BatchStatement < ' a > > + Clone ,
21+ BatchStatement < ' b > : From < & ' b Statement > ,
22+ Statement : Clone ,
1823 Values : BatchValues ,
1924{
20- pub statements : StatementsIter ,
21- pub statements_count : usize ,
25+ pub statements : Cow < ' b , [ Statement ] > ,
2226 pub batch_type : BatchType ,
2327 pub consistency : types:: Consistency ,
2428 pub serial_consistency : Option < types:: SerialConsistency > ,
@@ -28,21 +32,46 @@ where
2832
2933/// The type of a batch.
3034#[ derive( Clone , Copy ) ]
35+ #[ cfg_attr( test, derive( Debug , PartialEq , Eq ) ) ]
3136pub enum BatchType {
3237 Logged = 0 ,
3338 Unlogged = 1 ,
3439 Counter = 2 ,
3540}
3641
37- #[ derive( Debug , Clone , Copy , Eq , PartialEq , PartialOrd , Ord ) ]
42+ pub struct BatchTypeParseError {
43+ value : u8 ,
44+ }
45+
46+ impl From < BatchTypeParseError > for ParseError {
47+ fn from ( err : BatchTypeParseError ) -> Self {
48+ Self :: BadIncomingData ( format ! ( "Bad BatchType value: {}" , err. value) )
49+ }
50+ }
51+
52+ impl TryFrom < u8 > for BatchType {
53+ type Error = BatchTypeParseError ;
54+
55+ fn try_from ( value : u8 ) -> Result < Self , Self :: Error > {
56+ match value {
57+ 0 => Ok ( Self :: Logged ) ,
58+ 1 => Ok ( Self :: Unlogged ) ,
59+ 2 => Ok ( Self :: Counter ) ,
60+ _ => Err ( BatchTypeParseError { value } ) ,
61+ }
62+ }
63+ }
64+
65+ #[ derive( Debug , Clone , Eq , PartialEq , PartialOrd , Ord ) ]
3866pub enum BatchStatement < ' a > {
39- Query { text : & ' a str } ,
40- Prepared { id : & ' a Bytes } ,
67+ Query { text : Cow < ' a , str > } ,
68+ Prepared { id : Cow < ' a , [ u8 ] > } ,
4169}
4270
43- impl < ' a , StatementsIter , Values > Request for Batch < ' a , StatementsIter , Values >
71+ impl < Statement , Values > SerializableRequest for Batch < ' _ , Statement , Values >
4472where
45- StatementsIter : Iterator < Item = BatchStatement < ' a > > + Clone ,
73+ for < ' s > BatchStatement < ' s > : From < & ' s Statement > ,
74+ Statement : Clone ,
4675 Values : BatchValues ,
4776{
4877 const OPCODE : RequestOpcode = RequestOpcode :: Batch ;
5281 buf. put_u8 ( self . batch_type as u8 ) ;
5382
5483 // Serializing queries
55- types:: write_short ( self . statements_count . try_into ( ) ?, buf) ;
84+ types:: write_short ( self . statements . len ( ) . try_into ( ) ?, buf) ;
5685
5786 let counts_mismatch_err = |n_values : usize , n_statements : usize | {
5887 ParseError :: BadDataToSerialize ( format ! (
@@ -62,26 +91,27 @@ where
6291 } ;
6392 let mut n_serialized_statements = 0usize ;
6493 let mut value_lists = self . values . batch_values_iter ( ) ;
65- for ( idx, statement) in self . statements . clone ( ) . enumerate ( ) {
66- statement. serialize ( buf) ?;
94+ for ( idx, statement) in self . statements . iter ( ) . enumerate ( ) {
95+ BatchStatement :: from ( statement) . serialize ( buf) ?;
6796 value_lists
6897 . write_next_to_request ( buf)
69- . ok_or_else ( || counts_mismatch_err ( idx, self . statements . clone ( ) . count ( ) ) ) ??;
98+ . ok_or_else ( || counts_mismatch_err ( idx, self . statements . len ( ) ) ) ??;
7099 n_serialized_statements += 1 ;
71100 }
101+ // At this point, we have all statements serialized. If any values are still left, we have a mismatch.
72102 if value_lists. skip_next ( ) . is_some ( ) {
73103 return Err ( counts_mismatch_err (
74- std :: iter :: from_fn ( || value_lists. skip_next ( ) ) . count ( ) + 1 ,
104+ n_serialized_statements + 1 /*skipped above*/ + value_lists. count ( ) ,
75105 n_serialized_statements,
76106 ) ) ;
77107 }
78- if n_serialized_statements != self . statements_count {
108+ if n_serialized_statements != self . statements . len ( ) {
79109 // We want to check this to avoid propagating an invalid construction of self.statements_count as a
80110 // hard-to-debug silent fail
81111 return Err ( ParseError :: BadDataToSerialize ( format ! (
82112 "Invalid Batch constructed: not as many statements serialized as announced \
83113 (batch.statement_count: {announced_statement_count}, {n_serialized_statements}",
84- announced_statement_count = self . statements_count
114+ announced_statement_count = self . statements . len ( )
85115 ) ) ) ;
86116 }
87117
@@ -110,19 +140,115 @@ where
110140 }
111141}
112142
143+ impl BatchStatement < ' _ > {
144+ fn deserialize ( buf : & mut & [ u8 ] ) -> Result < Self , ParseError > {
145+ let kind = buf. get_u8 ( ) ;
146+ match kind {
147+ 0 => {
148+ let text = Cow :: Owned ( types:: read_long_string ( buf) ?. to_owned ( ) ) ;
149+ Ok ( BatchStatement :: Query { text } )
150+ }
151+ 1 => {
152+ let id = types:: read_short_bytes ( buf) ?. to_vec ( ) . into ( ) ;
153+ Ok ( BatchStatement :: Prepared { id } )
154+ }
155+ _ => Err ( ParseError :: BadIncomingData ( format ! (
156+ "Unexpected batch statement kind: {}" ,
157+ kind
158+ ) ) ) ,
159+ }
160+ }
161+ }
162+
113163impl BatchStatement < ' _ > {
114164 fn serialize ( & self , buf : & mut impl BufMut ) -> Result < ( ) , ParseError > {
115165 match self {
116- BatchStatement :: Query { text } => {
166+ Self :: Query { text } => {
117167 buf. put_u8 ( 0 ) ;
118168 types:: write_long_string ( text, buf) ?;
119169 }
120- BatchStatement :: Prepared { id } => {
170+ Self :: Prepared { id } => {
121171 buf. put_u8 ( 1 ) ;
122- types:: write_short_bytes ( & id [ .. ] , buf) ?;
172+ types:: write_short_bytes ( id , buf) ?;
123173 }
124174 }
125175
126176 Ok ( ( ) )
127177 }
128178}
179+
180+ impl < ' s , ' b > From < & ' s BatchStatement < ' b > > for BatchStatement < ' s > {
181+ fn from ( value : & ' s BatchStatement ) -> Self {
182+ match value {
183+ BatchStatement :: Query { text } => BatchStatement :: Query { text : text. clone ( ) } ,
184+ BatchStatement :: Prepared { id } => BatchStatement :: Prepared { id : id. clone ( ) } ,
185+ }
186+ }
187+ }
188+
189+ impl < ' b > DeserializableRequest for Batch < ' b , BatchStatement < ' b > , Vec < SerializedValues > > {
190+ fn deserialize ( buf : & mut & [ u8 ] ) -> Result < Self , ParseError > {
191+ let batch_type = buf. get_u8 ( ) . try_into ( ) ?;
192+
193+ let statements_count: usize = types:: read_short ( buf) ?. try_into ( ) ?;
194+ let statements_with_values = ( 0 ..statements_count)
195+ . map ( |_| {
196+ let batch_statement = BatchStatement :: deserialize ( buf) ?;
197+
198+ // As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used.
199+ let values = SerializedValues :: new_from_frame ( buf, false ) ?;
200+
201+ Ok ( ( batch_statement, values) )
202+ } )
203+ . collect :: < Result < Vec < _ > , ParseError > > ( ) ?;
204+
205+ let consistency = match types:: read_consistency ( buf) ? {
206+ types:: LegacyConsistency :: Regular ( reg) => Ok ( reg) ,
207+ types:: LegacyConsistency :: Serial ( ser) => Err ( ParseError :: BadIncomingData ( format ! (
208+ "Expected regular Consistency, got SerialConsistency {}" ,
209+ ser
210+ ) ) ) ,
211+ } ?;
212+
213+ let flags = buf. get_u8 ( ) ;
214+ let unknown_flags = flags & ( !ALL_FLAGS ) ;
215+ if unknown_flags != 0 {
216+ return Err ( ParseError :: BadIncomingData ( format ! (
217+ "Specified flags are not recognised: {:02x}" ,
218+ unknown_flags
219+ ) ) ) ;
220+ }
221+ let serial_consistency_flag = ( flags & FLAG_WITH_SERIAL_CONSISTENCY ) != 0 ;
222+ let default_timestamp_flag = ( flags & FLAG_WITH_DEFAULT_TIMESTAMP ) != 0 ;
223+
224+ let serial_consistency = serial_consistency_flag
225+ . then ( || types:: read_consistency ( buf) )
226+ . transpose ( ) ?
227+ . map ( |legacy_consistency| match legacy_consistency {
228+ types:: LegacyConsistency :: Regular ( reg) => {
229+ Err ( ParseError :: BadIncomingData ( format ! (
230+ "Expected SerialConsistency, got regular Consistency {}" ,
231+ reg
232+ ) ) )
233+ }
234+ types:: LegacyConsistency :: Serial ( ser) => Ok ( ser) ,
235+ } )
236+ . transpose ( ) ?;
237+
238+ let timestamp = default_timestamp_flag
239+ . then ( || types:: read_long ( buf) )
240+ . transpose ( ) ?;
241+
242+ let ( statements, values) : ( Vec < BatchStatement > , Vec < SerializedValues > ) =
243+ statements_with_values. into_iter ( ) . unzip ( ) ;
244+
245+ Ok ( Self {
246+ batch_type,
247+ consistency,
248+ serial_consistency,
249+ timestamp,
250+ statements : Cow :: Owned ( statements) ,
251+ values,
252+ } )
253+ }
254+ }
0 commit comments