Skip to content

Commit a71e3cf

Browse files
committed
collection: typecheck value on append
1 parent 7787468 commit a71e3cf

File tree

1 file changed

+212
-2
lines changed

1 file changed

+212
-2
lines changed

scylla-rust-wrapper/src/collection.rs

Lines changed: 212 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
use crate::argconv::*;
21
use crate::cass_error::CassError;
32
use crate::cass_types::CassDataType;
43
use crate::types::*;
54
use crate::value::CassCqlValue;
5+
use crate::{argconv::*, value};
66
use std::convert::TryFrom;
77
use std::sync::Arc;
88

@@ -32,8 +32,49 @@ pub struct CassCollection {
3232
}
3333

3434
impl CassCollection {
35+
fn typecheck_on_append(&self, value: &Option<CassCqlValue>) -> CassError {
36+
// See https://github.com/scylladb/cpp-driver/blob/master/src/collection.hpp#L100.
37+
let index = self.items.len();
38+
39+
// Do validation only if it's a typed collection.
40+
if let Some(data_type) = &self.data_type {
41+
match data_type.as_ref() {
42+
CassDataType::List { typ: subtype, .. }
43+
| CassDataType::Set { typ: subtype, .. } => match subtype {
44+
Some(subtype) => {
45+
if !value::is_type_compatible(value, subtype) {
46+
return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE;
47+
}
48+
}
49+
None => {}
50+
},
51+
CassDataType::Map {
52+
key_type: k_typ,
53+
val_type: v_typ,
54+
..
55+
} => {
56+
// Only do the typecheck if both map types are present.
57+
if let (Some(k_typ), Some(v_typ)) = (k_typ, v_typ) {
58+
if index % 2 == 0 && !value::is_type_compatible(value, k_typ) {
59+
return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE;
60+
}
61+
if index % 2 != 0 && !value::is_type_compatible(value, v_typ) {
62+
return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE;
63+
}
64+
}
65+
}
66+
_ => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE,
67+
}
68+
}
69+
70+
CassError::CASS_OK
71+
}
72+
3573
pub fn append_cql_value(&mut self, value: Option<CassCqlValue>) -> CassError {
36-
// FIXME: Bounds check, type check
74+
let err = self.typecheck_on_append(&value);
75+
if err != CassError::CASS_OK {
76+
return err;
77+
}
3778
// There is no API to append null, so unwrap is safe
3879
self.items.push(value.unwrap());
3980
CassError::CASS_OK
@@ -160,3 +201,172 @@ make_binders!(decimal, cass_collection_append_decimal);
160201
make_binders!(collection, cass_collection_append_collection);
161202
make_binders!(tuple, cass_collection_append_tuple);
162203
make_binders!(user_type, cass_collection_append_user_type);
204+
205+
#[cfg(test)]
206+
mod tests {
207+
use std::sync::Arc;
208+
209+
use crate::{
210+
cass_error::CassError,
211+
cass_types::{CassDataType, CassValueType},
212+
collection::{
213+
cass_collection_append_double, cass_collection_append_float, cass_collection_free,
214+
},
215+
testing::assert_cass_error_eq,
216+
};
217+
218+
use super::{
219+
cass_bool_t, cass_collection_append_bool, cass_collection_append_int16,
220+
cass_collection_new, cass_collection_new_from_data_type, CassCollectionType,
221+
};
222+
223+
#[test]
224+
fn test_typecheck_on_append_to_collection() {
225+
unsafe {
226+
// untyped map
227+
{
228+
let untyped_map =
229+
cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_MAP, 2);
230+
assert_cass_error_eq!(
231+
cass_collection_append_bool(untyped_map, false as cass_bool_t),
232+
CassError::CASS_OK
233+
);
234+
assert_cass_error_eq!(
235+
cass_collection_append_int16(untyped_map, 42),
236+
CassError::CASS_OK
237+
);
238+
assert_cass_error_eq!(
239+
cass_collection_append_double(untyped_map, 42.42),
240+
CassError::CASS_OK
241+
);
242+
assert_cass_error_eq!(
243+
cass_collection_append_float(untyped_map, 42.42),
244+
CassError::CASS_OK
245+
);
246+
cass_collection_free(untyped_map);
247+
}
248+
249+
// typed map
250+
{
251+
let dt = Arc::new(CassDataType::Map {
252+
key_type: Some(Arc::new(CassDataType::Value(
253+
CassValueType::CASS_VALUE_TYPE_BOOLEAN,
254+
))),
255+
val_type: Some(Arc::new(CassDataType::Value(
256+
CassValueType::CASS_VALUE_TYPE_SMALL_INT,
257+
))),
258+
frozen: false,
259+
});
260+
let dt_ptr = Arc::into_raw(dt);
261+
let bool_to_i16_map = cass_collection_new_from_data_type(dt_ptr, 2);
262+
263+
// First entry -> typecheck successful.
264+
assert_cass_error_eq!(
265+
cass_collection_append_bool(bool_to_i16_map, false as cass_bool_t),
266+
CassError::CASS_OK
267+
);
268+
assert_cass_error_eq!(
269+
cass_collection_append_int16(bool_to_i16_map, 42),
270+
CassError::CASS_OK
271+
);
272+
273+
// Second entry -> key typecheck failed.
274+
assert_cass_error_eq!(
275+
cass_collection_append_float(bool_to_i16_map, 42.42),
276+
CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE
277+
);
278+
279+
// Third entry -> value typecheck failed.
280+
assert_cass_error_eq!(
281+
cass_collection_append_bool(bool_to_i16_map, true as cass_bool_t),
282+
CassError::CASS_OK
283+
);
284+
assert_cass_error_eq!(
285+
cass_collection_append_float(bool_to_i16_map, 42.42),
286+
CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE
287+
);
288+
289+
Arc::from_raw(dt_ptr);
290+
cass_collection_free(bool_to_i16_map);
291+
}
292+
293+
// untyped set
294+
{
295+
let untyped_set =
296+
cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_SET, 2);
297+
assert_cass_error_eq!(
298+
cass_collection_append_bool(untyped_set, false as cass_bool_t),
299+
CassError::CASS_OK
300+
);
301+
assert_cass_error_eq!(
302+
cass_collection_append_int16(untyped_set, 42),
303+
CassError::CASS_OK
304+
);
305+
cass_collection_free(untyped_set);
306+
}
307+
308+
// typed set
309+
{
310+
let dt = Arc::new(CassDataType::Set {
311+
typ: Some(Arc::new(CassDataType::Value(
312+
CassValueType::CASS_VALUE_TYPE_BOOLEAN,
313+
))),
314+
frozen: false,
315+
});
316+
let dt_ptr = Arc::into_raw(dt);
317+
let bool_set = cass_collection_new_from_data_type(dt_ptr, 2);
318+
319+
assert_cass_error_eq!(
320+
cass_collection_append_bool(bool_set, true as cass_bool_t),
321+
CassError::CASS_OK
322+
);
323+
assert_cass_error_eq!(
324+
cass_collection_append_float(bool_set, 42.42),
325+
CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE
326+
);
327+
328+
Arc::from_raw(dt_ptr);
329+
cass_collection_free(bool_set);
330+
}
331+
332+
// untyped list
333+
{
334+
let untyped_list =
335+
cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_LIST, 2);
336+
assert_cass_error_eq!(
337+
cass_collection_append_bool(untyped_list, false as cass_bool_t),
338+
CassError::CASS_OK
339+
);
340+
assert_cass_error_eq!(
341+
cass_collection_append_int16(untyped_list, 42),
342+
CassError::CASS_OK
343+
);
344+
cass_collection_free(untyped_list);
345+
}
346+
347+
// typed list
348+
{
349+
let dt = Arc::new(CassDataType::Set {
350+
typ: Some(Arc::new(CassDataType::Value(
351+
CassValueType::CASS_VALUE_TYPE_BOOLEAN,
352+
))),
353+
frozen: false,
354+
});
355+
let dt_ptr = Arc::into_raw(dt);
356+
let bool_list = cass_collection_new_from_data_type(dt_ptr, 2);
357+
358+
assert_cass_error_eq!(
359+
cass_collection_append_bool(bool_list, true as cass_bool_t),
360+
CassError::CASS_OK
361+
);
362+
assert_cass_error_eq!(
363+
cass_collection_append_float(bool_list, 42.42),
364+
CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE
365+
);
366+
367+
Arc::from_raw(dt_ptr);
368+
cass_collection_free(bool_list);
369+
}
370+
}
371+
}
372+
}

0 commit comments

Comments
 (0)