diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index beed77df0..250d3c01e 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -9,5 +9,6 @@ - [ ] I have split my patch into logically separate commits. - [ ] All commit messages clearly explain what they change and why. - [ ] PR description sums up the changes and reasons why they should be introduced. +- [ ] I have implemented Rust unit tests for the features/changes introduced. - [ ] I have enabled appropriate tests in `.github/workflows/build.yml` in `gtest_filter`. - [ ] I have enabled appropriate tests in `.github/workflows/cassandra.yml` in `gtest_filter`. \ No newline at end of file diff --git a/README.md b/README.md index 5ab1ddf44..5b37bfa61 100644 --- a/README.md +++ b/README.md @@ -172,13 +172,6 @@ The driver inherits almost all the features of C/C++ and Rust drivers, such as: Collection - - cass_collection_new_from_data_type - Unimplemented - - - cass_collection_data_type - cass_collection_append_custom[_n] Unimplemented because of the same reasons as binding for statements.
Note: The driver does not check whether the type of the appended value is compatible with the type of the collection items. diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/batch.rs index d4890402f..3cdf36eef 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/batch.rs @@ -165,7 +165,7 @@ pub unsafe extern "C" fn cass_batch_add_statement( match &statement.statement { Statement::Simple(q) => state.batch.append_statement(q.query.clone()), - Statement::Prepared(p) => state.batch.append_statement((**p).clone()), + Statement::Prepared(p) => state.batch.append_statement(p.statement.clone()), }; state.bound_values.push(statement.bound_values.clone()); diff --git a/scylla-rust-wrapper/src/binding.rs b/scylla-rust-wrapper/src/binding.rs index 4c1d37e4f..f23396047 100644 --- a/scylla-rust-wrapper/src/binding.rs +++ b/scylla-rust-wrapper/src/binding.rs @@ -47,12 +47,6 @@ //! It can be used for binding named parameter in CassStatement or field by name in CassUserType. //! * Functions from make_appender don't take any extra argument, as they are for use by CassCollection //! functions - values are appended to collection. -use crate::{cass_types::CassDataType, value::CassCqlValue}; - -pub fn is_compatible_type(_data_type: &CassDataType, _value: &Option) -> bool { - // TODO: cppdriver actually checks types. - true -} macro_rules! make_index_binder { ($this:ty, $consume_v:expr, $fn_by_idx:ident, $e:expr, [$($arg:ident @ $t:ty), *]) => { diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 6bd511926..f75540aef 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -15,7 +15,7 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_data_types.rs")); include!(concat!(env!("OUT_DIR"), "/cppdriver_data_query_error.rs")); include!(concat!(env!("OUT_DIR"), "/cppdriver_batch_types.rs")); -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct UDTDataType { // Vec to preserve the order of types pub field_types: Vec<(String, Arc)>, @@ -87,6 +87,42 @@ impl UDTDataType { pub fn get_field_by_index(&self, index: usize) -> Option<&Arc> { self.field_types.get(index).map(|(_, b)| b) } + + fn typecheck_equals(&self, other: &UDTDataType) -> bool { + // See: https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L354-L386 + + if !any_string_empty_or_both_equal(&self.keyspace, &other.keyspace) { + return false; + } + if !any_string_empty_or_both_equal(&self.name, &other.name) { + return false; + } + + // A comment from cpp-driver: + //// UDT's can be considered equal as long as the mutual first fields shared + //// between them are equal. UDT's are append only as far as fields go, so a + //// newer 'version' of the UDT data type after a schema change event should be + //// treated as equivalent in this scenario, by simply looking at the first N + //// mutual fields they should share. + // + // Iterator returned from zip() is perfect for checking the first mutual fields. + for (field, other_field) in self.field_types.iter().zip(other.field_types.iter()) { + // Compare field names. + if field.0 != other_field.0 { + return false; + } + // Compare field types. + if !field.1.typecheck_equals(&other_field.1) { + return false; + } + } + + true + } +} + +fn any_string_empty_or_both_equal(s1: &str, s2: &str) -> bool { + s1.is_empty() || s2.is_empty() || s1 == s2 } impl Default for UDTDataType { @@ -95,27 +131,106 @@ impl Default for UDTDataType { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] +pub enum MapDataType { + Untyped, + Key(Arc), + KeyAndValue(Arc, Arc), +} + +#[derive(Clone, Debug, PartialEq)] pub enum CassDataType { Value(CassValueType), UDT(UDTDataType), List { + // None stands for untyped list. typ: Option>, frozen: bool, }, Set { + // None stands for untyped set. typ: Option>, frozen: bool, }, Map { - key_type: Option>, - val_type: Option>, + typ: MapDataType, frozen: bool, }, + // Empty vector stands for untyped tuple. Tuple(Vec>), Custom(String), } +impl CassDataType { + /// Checks for equality during typechecks. + /// + /// This takes into account the fact that tuples/collections may be untyped. + pub fn typecheck_equals(&self, other: &CassDataType) -> bool { + match self { + CassDataType::Value(t) => *t == other.get_value_type(), + CassDataType::UDT(udt) => match other { + CassDataType::UDT(other_udt) => udt.typecheck_equals(other_udt), + _ => false, + }, + CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => match other { + CassDataType::List { typ: other_typ, .. } + | CassDataType::Set { typ: other_typ, .. } => { + // If one of them is list, and the other is set, fail the typecheck. + if self.get_value_type() != other.get_value_type() { + return false; + } + match (typ, other_typ) { + // One of them is untyped, skip the typecheck for subtype. + (None, _) | (_, None) => true, + (Some(typ), Some(other_typ)) => typ.typecheck_equals(other_typ), + } + } + _ => false, + }, + CassDataType::Map { typ: t, .. } => match other { + CassDataType::Map { typ: t_other, .. } => match (t, t_other) { + // See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218 + // In cpp-driver the types are held in a vector. + // The logic is following: + + // If either of vectors is empty, skip the typecheck. + (MapDataType::Untyped, _) => true, + (_, MapDataType::Untyped) => true, + + // Otherwise, the vectors should have equal length and we perform the typecheck for subtypes. + (MapDataType::Key(k), MapDataType::Key(k_other)) => k.typecheck_equals(k_other), + ( + MapDataType::KeyAndValue(k, v), + MapDataType::KeyAndValue(k_other, v_other), + ) => k.typecheck_equals(k_other) && v.typecheck_equals(v_other), + _ => false, + }, + _ => false, + }, + CassDataType::Tuple(sub) => match other { + CassDataType::Tuple(other_sub) => { + // If either of tuples is untyped, skip the typecheck for subtypes. + if sub.is_empty() || other_sub.is_empty() { + return true; + } + + // If both are non-empty, check for subtypes equality. + if sub.len() != other_sub.len() { + return false; + } + sub.iter() + .zip(other_sub.iter()) + .all(|(typ, other_typ)| typ.typecheck_equals(other_typ)) + } + _ => false, + }, + CassDataType::Custom(_) => { + unimplemented!("Cpp-rust-driver does not support custom types!") + } + } + } +} + impl From for CassValueType { fn from(native_type: NativeType) -> CassValueType { match native_type { @@ -160,16 +275,18 @@ pub fn get_column_type_from_cql_type( frozen: *frozen, }, CollectionType::Map(key, value) => CassDataType::Map { - key_type: Some(Arc::new(get_column_type_from_cql_type( - key, - user_defined_types, - keyspace_name, - ))), - val_type: Some(Arc::new(get_column_type_from_cql_type( - value, - user_defined_types, - keyspace_name, - ))), + typ: MapDataType::KeyAndValue( + Arc::new(get_column_type_from_cql_type( + key, + user_defined_types, + keyspace_name, + )), + Arc::new(get_column_type_from_cql_type( + value, + user_defined_types, + keyspace_name, + )), + ), frozen: *frozen, }, CollectionType::Set(set) => CassDataType::Set { @@ -222,10 +339,19 @@ impl CassDataType { } } CassDataType::Map { - key_type, val_type, .. + typ: MapDataType::Untyped, + .. + } => None, + CassDataType::Map { + typ: MapDataType::Key(k), + .. + } => (index == 0).then_some(k), + CassDataType::Map { + typ: MapDataType::KeyAndValue(k, v), + .. } => match index { - 0 => key_type.as_ref(), - 1 => val_type.as_ref(), + 0 => Some(k), + 1 => Some(v), _ => None, }, CassDataType::Tuple(v) => v.get(index), @@ -243,17 +369,28 @@ impl CassDataType { } }, CassDataType::Map { - key_type, val_type, .. + typ: MapDataType::KeyAndValue(_, _), + .. + } => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS), + CassDataType::Map { + typ: MapDataType::Key(k), + frozen, } => { - if key_type.is_some() && val_type.is_some() { - Err(CassError::CASS_ERROR_LIB_BAD_PARAMS) - } else if key_type.is_none() { - *key_type = Some(sub_type); - Ok(()) - } else { - *val_type = Some(sub_type); - Ok(()) - } + *self = CassDataType::Map { + typ: MapDataType::KeyAndValue(k.clone(), sub_type), + frozen: *frozen, + }; + Ok(()) + } + CassDataType::Map { + typ: MapDataType::Untyped, + frozen, + } => { + *self = CassDataType::Map { + typ: MapDataType::Key(sub_type), + frozen: *frozen, + }; + Ok(()) } CassDataType::Tuple(types) => { types.push(sub_type); @@ -305,8 +442,10 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { frozen: false, }, ColumnType::Map(key, value) => CassDataType::Map { - key_type: Some(Arc::new(get_column_type(key.as_ref()))), - val_type: Some(Arc::new(get_column_type(value.as_ref()))), + typ: MapDataType::KeyAndValue( + Arc::new(get_column_type(key.as_ref())), + Arc::new(get_column_type(value.as_ref())), + ), frozen: false, }, ColumnType::Set(boxed_type) => CassDataType::Set { @@ -357,8 +496,7 @@ pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const }, CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataType::Tuple(Vec::new()), CassValueType::CASS_VALUE_TYPE_MAP => CassDataType::Map { - key_type: None, - val_type: None, + typ: MapDataType::Untyped, frozen: false, }, CassValueType::CASS_VALUE_TYPE_UDT => CassDataType::UDT(UDTDataType::new()), @@ -555,9 +693,11 @@ pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDat CassDataType::Value(..) => 0, CassDataType::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t, CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => typ.is_some() as size_t, - CassDataType::Map { - key_type, val_type, .. - } => key_type.is_some() as size_t + val_type.is_some() as size_t, + CassDataType::Map { typ, .. } => match typ { + MapDataType::Untyped => 0, + MapDataType::Key(_) => 1, + MapDataType::KeyAndValue(_, _) => 2, + }, CassDataType::Tuple(v) => v.len() as size_t, CassDataType::Custom(..) => 0, } diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index 5964a38f9..ea0c2a795 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -1,21 +1,86 @@ -use crate::argconv::*; use crate::cass_error::CassError; +use crate::cass_types::{CassDataType, MapDataType}; use crate::types::*; use crate::value::CassCqlValue; +use crate::{argconv::*, value}; use std::convert::TryFrom; +use std::sync::Arc; include!(concat!(env!("OUT_DIR"), "/cppdriver_data_collection.rs")); +// These constants help us to save an allocation in case user calls `cass_collection_new` (untyped collection). +static UNTYPED_LIST_TYPE: CassDataType = CassDataType::List { + typ: None, + frozen: false, +}; +static UNTYPED_SET_TYPE: CassDataType = CassDataType::Set { + typ: None, + frozen: false, +}; +static UNTYPED_MAP_TYPE: CassDataType = CassDataType::Map { + typ: MapDataType::Untyped, + frozen: false, +}; + #[derive(Clone)] pub struct CassCollection { pub collection_type: CassCollectionType, + pub data_type: Option>, pub capacity: usize, pub items: Vec, } impl CassCollection { + fn typecheck_on_append(&self, value: &Option) -> CassError { + // See https://github.com/scylladb/cpp-driver/blob/master/src/collection.hpp#L100. + let index = self.items.len(); + + // Do validation only if it's a typed collection. + if let Some(data_type) = &self.data_type { + match data_type.as_ref() { + CassDataType::List { typ: subtype, .. } + | CassDataType::Set { typ: subtype, .. } => { + if let Some(subtype) = subtype { + if !value::is_type_compatible(value, subtype) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } + } + } + + CassDataType::Map { typ, .. } => { + // Cpp-driver does the typecheck only if both map types are present... + // However, we decided not to mimic this behaviour (which is probably a bug). + // We will do the typecheck if just the key type is defined as well (half-typed maps). + match typ { + MapDataType::Key(k_typ) => { + if index % 2 == 0 && !value::is_type_compatible(value, k_typ) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } + } + MapDataType::KeyAndValue(k_typ, v_typ) => { + if index % 2 == 0 && !value::is_type_compatible(value, k_typ) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } + if index % 2 != 0 && !value::is_type_compatible(value, v_typ) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } + } + // Skip the typecheck for untyped map. + MapDataType::Untyped => (), + } + } + _ => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } + } + + CassError::CASS_OK + } + pub fn append_cql_value(&mut self, value: Option) -> CassError { - // FIXME: Bounds check, type check + let err = self.typecheck_on_append(&value); + if err != CassError::CASS_OK { + return err; + } // There is no API to append null, so unwrap is safe self.items.push(value.unwrap()); CassError::CASS_OK @@ -26,10 +91,12 @@ impl TryFrom<&CassCollection> for CassCqlValue { type Error = (); fn try_from(collection: &CassCollection) -> Result { // FIXME: validate that collection items are correct + let data_type = collection.data_type.clone(); match collection.collection_type { - CassCollectionType::CASS_COLLECTION_TYPE_LIST => { - Ok(CassCqlValue::List(collection.items.clone())) - } + CassCollectionType::CASS_COLLECTION_TYPE_LIST => Ok(CassCqlValue::List { + data_type, + values: collection.items.clone(), + }), CassCollectionType::CASS_COLLECTION_TYPE_MAP => { let mut grouped_items = Vec::new(); // FIXME: validate even number of items @@ -40,11 +107,15 @@ impl TryFrom<&CassCollection> for CassCqlValue { grouped_items.push((key, value)); } - Ok(CassCqlValue::Map(grouped_items)) - } - CassCollectionType::CASS_COLLECTION_TYPE_SET => { - Ok(CassCqlValue::Set(collection.items.clone())) + Ok(CassCqlValue::Map { + data_type, + values: grouped_items, + }) } + CassCollectionType::CASS_COLLECTION_TYPE_SET => Ok(CassCqlValue::Set { + data_type, + values: collection.items.clone(), + }), _ => Err(()), } } @@ -57,18 +128,64 @@ pub unsafe extern "C" fn cass_collection_new( ) -> *mut CassCollection { let capacity = match collection_type { // Maps consist of a key and a value, so twice - // the number of CqlValue will be stored. + // the number of CassCqlValue will be stored. CassCollectionType::CASS_COLLECTION_TYPE_MAP => item_count * 2, _ => item_count, } as usize; Box::into_raw(Box::new(CassCollection { collection_type, + data_type: None, + capacity, + items: Vec::with_capacity(capacity), + })) +} + +#[no_mangle] +unsafe extern "C" fn cass_collection_new_from_data_type( + data_type: *const CassDataType, + item_count: size_t, +) -> *mut CassCollection { + let data_type = clone_arced(data_type); + let (capacity, collection_type) = match data_type.as_ref() { + CassDataType::List { .. } => (item_count, CassCollectionType::CASS_COLLECTION_TYPE_LIST), + CassDataType::Set { .. } => (item_count, CassCollectionType::CASS_COLLECTION_TYPE_SET), + // Maps consist of a key and a value, so twice + // the number of CassCqlValue will be stored. + CassDataType::Map { .. } => (item_count * 2, CassCollectionType::CASS_COLLECTION_TYPE_MAP), + _ => return std::ptr::null_mut(), + }; + let capacity = capacity as usize; + + Box::into_raw(Box::new(CassCollection { + collection_type, + data_type: Some(data_type), capacity, items: Vec::with_capacity(capacity), })) } +#[no_mangle] +unsafe extern "C" fn cass_collection_data_type( + collection: *const CassCollection, +) -> *const CassDataType { + let collection_ref = ptr_to_ref(collection); + + match &collection_ref.data_type { + Some(dt) => Arc::as_ptr(dt), + None => match collection_ref.collection_type { + CassCollectionType::CASS_COLLECTION_TYPE_LIST => &UNTYPED_LIST_TYPE, + CassCollectionType::CASS_COLLECTION_TYPE_SET => &UNTYPED_SET_TYPE, + CassCollectionType::CASS_COLLECTION_TYPE_MAP => &UNTYPED_MAP_TYPE, + // CassCollectionType is a C enum. Panic, if it's out of range. + _ => panic!( + "CassCollectionType enum value out of range: {}", + collection_ref.collection_type.0 + ), + }, + } +} + #[no_mangle] pub unsafe extern "C" fn cass_collection_free(collection: *mut CassCollection) { free_boxed(collection); @@ -93,3 +210,282 @@ make_binders!(decimal, cass_collection_append_decimal); make_binders!(collection, cass_collection_append_collection); make_binders!(tuple, cass_collection_append_tuple); make_binders!(user_type, cass_collection_append_user_type); + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + cass_error::CassError, + cass_types::{CassDataType, CassValueType, MapDataType}, + collection::{ + cass_collection_append_double, cass_collection_append_float, cass_collection_free, + }, + testing::assert_cass_error_eq, + }; + + use super::{ + cass_bool_t, cass_collection_append_bool, cass_collection_append_int16, + cass_collection_new, cass_collection_new_from_data_type, CassCollectionType, + }; + + #[test] + fn test_typecheck_on_append_to_collection() { + unsafe { + // untyped map (via cass_collection_new, Collection's data type is None). + { + let untyped_map = + cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_MAP, 2); + assert_cass_error_eq!( + cass_collection_append_bool(untyped_map, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_map, 42), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_double(untyped_map, 42.42), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(untyped_map, 42.42), + CassError::CASS_OK + ); + cass_collection_free(untyped_map); + } + + // untyped map (via cass_collection_new_from_data_type - collection's type is Some(untyped_map)). + { + let dt = Arc::new(CassDataType::Map { + typ: MapDataType::Untyped, + frozen: false, + }); + + let dt_ptr = Arc::into_raw(dt); + let untyped_map = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(untyped_map, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_map, 42), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_double(untyped_map, 42.42), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(untyped_map, 42.42), + CassError::CASS_OK + ); + cass_collection_free(untyped_map); + } + + // half-typed map (key-only) + { + let dt = Arc::new(CassDataType::Map { + typ: MapDataType::Key(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + ))), + frozen: false, + }); + + let dt_ptr = Arc::into_raw(dt); + let half_typed_map = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(half_typed_map, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(half_typed_map, 42), + CassError::CASS_OK + ); + + // Second entry -> key typecheck failed. + assert_cass_error_eq!( + cass_collection_append_double(half_typed_map, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + // Second entry -> typecheck succesful. + assert_cass_error_eq!( + cass_collection_append_bool(half_typed_map, true as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_double(half_typed_map, 42.42), + CassError::CASS_OK + ); + cass_collection_free(half_typed_map); + } + + // typed map + { + let dt = Arc::new(CassDataType::Map { + typ: MapDataType::KeyAndValue( + Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN)), + Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_SMALL_INT, + )), + ), + frozen: false, + }); + let dt_ptr = Arc::into_raw(dt); + let bool_to_i16_map = cass_collection_new_from_data_type(dt_ptr, 2); + + // First entry -> typecheck successful. + assert_cass_error_eq!( + cass_collection_append_bool(bool_to_i16_map, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(bool_to_i16_map, 42), + CassError::CASS_OK + ); + + // Second entry -> key typecheck failed. + assert_cass_error_eq!( + cass_collection_append_float(bool_to_i16_map, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + // Third entry -> value typecheck failed. + assert_cass_error_eq!( + cass_collection_append_bool(bool_to_i16_map, true as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(bool_to_i16_map, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + Arc::from_raw(dt_ptr); + cass_collection_free(bool_to_i16_map); + } + + // untyped set (via cass_collection_new, collection's type is None) + { + let untyped_set = + cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_SET, 2); + assert_cass_error_eq!( + cass_collection_append_bool(untyped_set, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_set, 42), + CassError::CASS_OK + ); + cass_collection_free(untyped_set); + } + + // untyped set (via cass_collection_new_from_data_type, collection's type is Some(untyped_set)) + { + let dt = Arc::new(CassDataType::Set { + typ: None, + frozen: false, + }); + + let dt_ptr = Arc::into_raw(dt); + let untyped_set = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(untyped_set, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_set, 42), + CassError::CASS_OK + ); + cass_collection_free(untyped_set); + } + + // typed set + { + let dt = Arc::new(CassDataType::Set { + typ: Some(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + ))), + frozen: false, + }); + let dt_ptr = Arc::into_raw(dt); + let bool_set = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(bool_set, true as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(bool_set, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + Arc::from_raw(dt_ptr); + cass_collection_free(bool_set); + } + + // untyped list (via cass_collection_new, collection's type is None) + { + let untyped_list = + cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_LIST, 2); + assert_cass_error_eq!( + cass_collection_append_bool(untyped_list, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_list, 42), + CassError::CASS_OK + ); + cass_collection_free(untyped_list); + } + + // untyped list (via cass_collection_new_from_data_type, collection's type is Some(untyped_list)) + { + let dt = Arc::new(CassDataType::Set { + typ: None, + frozen: false, + }); + + let dt_ptr = Arc::into_raw(dt); + let untyped_list = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(untyped_list, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_list, 42), + CassError::CASS_OK + ); + cass_collection_free(untyped_list); + } + + // typed list + { + let dt = Arc::new(CassDataType::Set { + typ: Some(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + ))), + frozen: false, + }); + let dt_ptr = Arc::into_raw(dt); + let bool_list = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(bool_list, true as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(bool_list, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + Arc::from_raw(dt_ptr); + cass_collection_free(bool_list); + } + } + } +} diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index b11e99b05..c579dd5f7 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -8,7 +8,6 @@ use crate::types::*; use crate::uuid::CassUuid; use crate::RUNTIME; use futures::future; -use scylla::prepared_statement::PreparedStatement; use std::future::Future; use std::mem; use std::os::raw::c_void; @@ -20,7 +19,7 @@ pub enum CassResultValue { Empty, QueryResult(Arc), QueryError(Arc), - Prepared(Arc), + Prepared(Arc), } type CassFutureError = (CassError, String); diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index 5094d6524..33fbcba6d 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -3,11 +3,32 @@ use std::sync::Arc; use crate::{ argconv::*, + cass_types::{get_column_type, CassDataType}, statement::{CassStatement, Statement}, }; use scylla::prepared_statement::PreparedStatement; -pub type CassPrepared = PreparedStatement; +#[derive(Debug, Clone)] +pub struct CassPrepared { + // Data types of columns from PreparedMetadata. + pub variable_col_data_types: Vec>, + pub statement: PreparedStatement, +} + +impl CassPrepared { + pub fn new_from_prepared_statement(statement: PreparedStatement) -> Self { + let variable_col_data_types = statement + .get_variable_col_specs() + .iter() + .map(|col_spec| Arc::new(get_column_type(&col_spec.typ))) + .collect(); + + Self { + variable_col_data_types, + statement, + } + } +} #[no_mangle] pub unsafe extern "C" fn cass_prepared_free(prepared_raw: *const CassPrepared) { @@ -19,7 +40,7 @@ pub unsafe extern "C" fn cass_prepared_bind( prepared_raw: *const CassPrepared, ) -> *mut CassStatement { let prepared: Arc<_> = clone_arced(prepared_raw); - let bound_values_size = prepared.get_variable_col_specs().len(); + let bound_values_size = prepared.statement.get_variable_col_specs().len(); // cloning prepared statement's arc, because creating CassStatement should not invalidate // the CassPrepared argument diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index dec575bb4..4a98b9914 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -1,6 +1,6 @@ use crate::argconv::*; use crate::cass_error::CassError; -use crate::cass_types::{cass_data_type_type, CassDataType, CassValueType}; +use crate::cass_types::{cass_data_type_type, CassDataType, CassValueType, MapDataType}; use crate::inet::CassInet; use crate::metadata::{ CassColumnMeta, CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta, CassTableMeta, @@ -1239,7 +1239,7 @@ pub unsafe extern "C" fn cass_value_primary_sub_type( } => list.get_value_type(), CassDataType::Set { typ: Some(set), .. } => set.get_value_type(), CassDataType::Map { - key_type: Some(key), + typ: MapDataType::Key(key) | MapDataType::KeyAndValue(key, _), .. } => key.get_value_type(), _ => CassValueType::CASS_VALUE_TYPE_UNKNOWN, @@ -1254,7 +1254,7 @@ pub unsafe extern "C" fn cass_value_secondary_sub_type( match val.value_type.as_ref() { CassDataType::Map { - val_type: Some(value), + typ: MapDataType::KeyAndValue(_, value), .. } => value.get_value_type(), _ => CassValueType::CASS_VALUE_TYPE_UNKNOWN, diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 9e54737aa..2f5f17e84 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -1,13 +1,14 @@ use crate::argconv::*; use crate::batch::CassBatch; use crate::cass_error::*; -use crate::cass_types::{get_column_type, CassDataType, UDTDataType}; +use crate::cass_types::{get_column_type, CassDataType, MapDataType, UDTDataType}; use crate::cluster::build_session_builder; use crate::cluster::CassCluster; use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProfile}; use crate::future::{CassFuture, CassFutureResult, CassResultValue}; use crate::metadata::create_table_metadata; use crate::metadata::{CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta}; +use crate::prepared::CassPrepared; use crate::query_result::Value::{CollectionValue, RegularValue}; use crate::query_result::{CassResult, CassResultData, CassRow, CassValue, Collection, Value}; use crate::statement::CassStatement; @@ -279,9 +280,9 @@ pub unsafe extern "C" fn cass_session_execute( match &mut statement { Statement::Simple(query) => query.query.set_execution_profile_handle(handle), - Statement::Prepared(prepared) => { - Arc::make_mut(prepared).set_execution_profile_handle(handle) - } + Statement::Prepared(prepared) => Arc::make_mut(prepared) + .statement + .set_execution_profile_handle(handle), } let query_res: Result<(QueryResult, PagingStateResponse), QueryError> = match statement { @@ -300,11 +301,11 @@ pub unsafe extern "C" fn cass_session_execute( Statement::Prepared(prepared) => { if paging_enabled { session - .execute_single_page(&prepared, bound_values, paging_state) + .execute_single_page(&prepared.statement, bound_values, paging_state) .await } else { session - .execute_unpaged(&prepared, bound_values) + .execute_unpaged(&prepared.statement, bound_values) .await .map(|result| (result, PagingStateResponse::NoMorePages)) } @@ -387,8 +388,7 @@ fn get_column_value(column: CqlValue, column_type: &Arc) -> Value ( CqlValue::Map(map), CassDataType::Map { - key_type: Some(key_typ), - val_type: Some(value_type), + typ: MapDataType::KeyAndValue(key_type, value_type), .. }, ) => CollectionValue(Collection::Map( @@ -396,8 +396,8 @@ fn get_column_value(column: CqlValue, column_type: &Arc) -> Value .map(|(key, val)| { ( CassValue { - value_type: key_typ.clone(), - value: Some(get_column_value(key, key_typ)), + value_type: key_type.clone(), + value: Some(get_column_value(key, key_type)), }, CassValue { value_type: value_type.clone(), @@ -499,7 +499,9 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing( .await .map_err(|err| (CassError::from(&err), err.msg()))?; - Ok(CassResultValue::Prepared(Arc::new(prepared))) + Ok(CassResultValue::Prepared(Arc::new( + CassPrepared::new_from_prepared_statement(prepared), + ))) }) } @@ -542,7 +544,9 @@ pub unsafe extern "C" fn cass_session_prepare_n( // Set Cpp Driver default configuration for queries: prepared.set_consistency(Consistency::One); - Ok(CassResultValue::Prepared(Arc::new(prepared))) + Ok(CassResultValue::Prepared(Arc::new( + CassPrepared::new_from_prepared_statement(prepared), + ))) }) } diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index b1129ad91..dc955d79c 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -1,15 +1,15 @@ -use crate::argconv::*; use crate::cass_error::CassError; use crate::exec_profile::PerStatementExecProfile; +use crate::prepared::CassPrepared; use crate::query_result::CassResult; use crate::retry_policy::CassRetryPolicy; use crate::types::*; use crate::value::CassCqlValue; +use crate::{argconv::*, value}; use scylla::frame::types::Consistency; use scylla::frame::value::MaybeUnset; use scylla::frame::value::MaybeUnset::{Set, Unset}; use scylla::query::Query; -use scylla::statement::prepared_statement::PreparedStatement; use scylla::statement::SerialConsistency; use scylla::transport::{PagingState, PagingStateResponse}; use std::collections::HashMap; @@ -24,7 +24,7 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_data_query_error.rs")); pub enum Statement { Simple(SimpleQuery), // Arc is needed, because PreparedStatement is passed by reference to session.execute - Prepared(Arc), + Prepared(Arc), } #[derive(Clone)] @@ -45,12 +45,35 @@ pub struct CassStatement { impl CassStatement { fn bind_cql_value(&mut self, index: usize, value: Option) -> CassError { - if index >= self.bound_values.len() { - CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS - } else { - self.bound_values[index] = Set(value); - CassError::CASS_OK + let (bound_value, maybe_data_type) = match &self.statement { + Statement::Simple(_) => match self.bound_values.get_mut(index) { + Some(v) => (v, None), + None => return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, + }, + Statement::Prepared(p) => match ( + self.bound_values.get_mut(index), + p.variable_col_data_types.get(index), + ) { + (Some(v), Some(dt)) => (v, Some(dt)), + (None, None) => return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, + // This indicates a length mismatch between col specs table and self.bound_values. + // + // It can only occur when user provides bad `count` value in `cass_statement_reset_parameters`. + // Cpp-driver does not verify that both of these values are equal. + // I believe returning CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS is best we can do here. + _ => return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, + }, + }; + + // Perform the typecheck. + if let Some(dt) = maybe_data_type { + if !value::is_type_compatible(&value, dt) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } } + + *bound_value = Set(value); + CassError::CASS_OK } fn bind_multiple_values_by_name( @@ -83,6 +106,7 @@ impl CassStatement { match &self.statement { Statement::Prepared(prepared) => { let indices: Vec = prepared + .statement .get_variable_col_specs() .iter() .enumerate() @@ -185,7 +209,9 @@ pub unsafe extern "C" fn cass_statement_set_consistency( if let Some(consistency) = consistency_opt { match &mut ptr_to_ref_mut(statement).statement { Statement::Simple(inner) => inner.query.set_consistency(consistency), - Statement::Prepared(inner) => Arc::make_mut(inner).set_consistency(consistency), + Statement::Prepared(inner) => { + Arc::make_mut(inner).statement.set_consistency(consistency) + } } } @@ -205,7 +231,7 @@ pub unsafe extern "C" fn cass_statement_set_paging_size( statement.paging_enabled = true; match &mut statement.statement { Statement::Simple(inner) => inner.query.set_page_size(page_size), - Statement::Prepared(inner) => Arc::make_mut(inner).set_page_size(page_size), + Statement::Prepared(inner) => Arc::make_mut(inner).statement.set_page_size(page_size), } } @@ -253,7 +279,9 @@ pub unsafe extern "C" fn cass_statement_set_is_idempotent( ) -> CassError { match &mut ptr_to_ref_mut(statement_raw).statement { Statement::Simple(inner) => inner.query.set_is_idempotent(is_idempotent != 0), - Statement::Prepared(inner) => Arc::make_mut(inner).set_is_idempotent(is_idempotent != 0), + Statement::Prepared(inner) => Arc::make_mut(inner) + .statement + .set_is_idempotent(is_idempotent != 0), } CassError::CASS_OK @@ -266,7 +294,7 @@ pub unsafe extern "C" fn cass_statement_set_tracing( ) -> CassError { match &mut ptr_to_ref_mut(statement_raw).statement { Statement::Simple(inner) => inner.query.set_tracing(enabled != 0), - Statement::Prepared(inner) => Arc::make_mut(inner).set_tracing(enabled != 0), + Statement::Prepared(inner) => Arc::make_mut(inner).statement.set_tracing(enabled != 0), } CassError::CASS_OK @@ -288,9 +316,9 @@ pub unsafe extern "C" fn cass_statement_set_retry_policy( match &mut ptr_to_ref_mut(statement).statement { Statement::Simple(inner) => inner.query.set_retry_policy(maybe_arced_retry_policy), - Statement::Prepared(inner) => { - Arc::make_mut(inner).set_retry_policy(maybe_arced_retry_policy) - } + Statement::Prepared(inner) => Arc::make_mut(inner) + .statement + .set_retry_policy(maybe_arced_retry_policy), } CassError::CASS_OK @@ -317,9 +345,9 @@ pub unsafe extern "C" fn cass_statement_set_serial_consistency( match &mut ptr_to_ref_mut(statement).statement { Statement::Simple(inner) => inner.query.set_serial_consistency(Some(consistency)), - Statement::Prepared(inner) => { - Arc::make_mut(inner).set_serial_consistency(Some(consistency)) - } + Statement::Prepared(inner) => Arc::make_mut(inner) + .statement + .set_serial_consistency(Some(consistency)), } CassError::CASS_OK @@ -349,7 +377,9 @@ pub unsafe extern "C" fn cass_statement_set_timestamp( ) -> CassError { match &mut ptr_to_ref_mut(statement).statement { Statement::Simple(inner) => inner.query.set_timestamp(Some(timestamp)), - Statement::Prepared(inner) => Arc::make_mut(inner).set_timestamp(Some(timestamp)), + Statement::Prepared(inner) => Arc::make_mut(inner) + .statement + .set_timestamp(Some(timestamp)), } CassError::CASS_OK diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index 941a2f4ca..df1f98892 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -1,12 +1,12 @@ use crate::argconv::*; -use crate::binding; use crate::cass_error::CassError; use crate::cass_types::CassDataType; use crate::types::*; +use crate::value; use crate::value::CassCqlValue; use std::sync::Arc; -static EMPTY_TUPLE_TYPE: CassDataType = CassDataType::Tuple(Vec::new()); +static UNTYPED_TUPLE_TYPE: CassDataType = CassDataType::Tuple(Vec::new()); #[derive(Clone)] pub struct CassTuple { @@ -37,7 +37,7 @@ impl CassTuple { } if let Some(inner_types) = self.get_types() { - if !binding::is_compatible_type(&inner_types[index], &v) { + if !value::is_type_compatible(&v, &inner_types[index]) { return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; } } @@ -50,7 +50,10 @@ impl CassTuple { impl From<&CassTuple> for CassCqlValue { fn from(tuple: &CassTuple) -> Self { - CassCqlValue::Tuple(tuple.items.clone()) + CassCqlValue::Tuple { + data_type: tuple.data_type.clone(), + fields: tuple.items.clone(), + } } } @@ -86,7 +89,7 @@ unsafe extern "C" fn cass_tuple_free(tuple: *mut CassTuple) { unsafe extern "C" fn cass_tuple_data_type(tuple: *const CassTuple) -> *const CassDataType { match &ptr_to_ref(tuple).data_type { Some(t) => Arc::as_ptr(t), - None => &EMPTY_TUPLE_TYPE, + None => &UNTYPED_TUPLE_TYPE, } } diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index 5082b6ece..e723c7b7a 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -1,9 +1,8 @@ -use crate::argconv::*; -use crate::binding::is_compatible_type; use crate::cass_error::CassError; use crate::cass_types::CassDataType; use crate::types::*; use crate::value::CassCqlValue; +use crate::{argconv::*, value}; use std::os::raw::c_char; use std::sync::Arc; @@ -20,7 +19,7 @@ impl CassUserType { if index >= self.field_values.len() { return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS; } - if !is_compatible_type(&self.data_type.get_udt_type().field_types[index].1, &value) { + if !value::is_type_compatible(&value, &self.data_type.get_udt_type().field_types[index].1) { return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; } self.field_values[index] = value; @@ -37,7 +36,7 @@ impl CassUserType { if index >= self.field_values.len() { return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS; } - if !is_compatible_type(field_type, &value) { + if !value::is_type_compatible(&value, field_type) { return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; } self.field_values[index].clone_from(&value); @@ -55,8 +54,7 @@ impl CassUserType { impl From<&CassUserType> for CassCqlValue { fn from(user_type: &CassUserType) -> Self { CassCqlValue::UserDefinedType { - keyspace: user_type.data_type.get_udt_type().keyspace.clone(), - type_name: user_type.data_type.get_udt_type().name.clone(), + data_type: user_type.data_type.clone(), fields: user_type .field_values .iter() diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index ffd02feb1..7636878f8 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -1,4 +1,4 @@ -use std::{convert::TryInto, net::IpAddr}; +use std::{convert::TryInto, net::IpAddr, sync::Arc}; use scylla::{ frame::{ @@ -17,6 +17,8 @@ use scylla::{ }; use uuid::Uuid; +use crate::cass_types::{CassDataType, CassValueType}; + /// A narrower version of rust driver's CqlValue. /// /// cpp-driver's API allows to map single rust type to @@ -45,13 +47,24 @@ pub enum CassCqlValue { Inet(IpAddr), Duration(CqlDuration), Decimal(CqlDecimal), - Tuple(Vec>), - List(Vec), - Map(Vec<(CassCqlValue, CassCqlValue)>), - Set(Vec), + Tuple { + data_type: Option>, + fields: Vec>, + }, + List { + data_type: Option>, + values: Vec, + }, + Map { + data_type: Option>, + values: Vec<(CassCqlValue, CassCqlValue)>, + }, + Set { + data_type: Option>, + values: Vec, + }, UserDefinedType { - keyspace: String, - type_name: String, + data_type: Arc, /// Order of `fields` vector must match the order of fields as defined in the UDT. The /// driver does not check it by itself, so incorrect data will be written if the order is /// wrong. @@ -60,13 +73,108 @@ pub enum CassCqlValue { // TODO: custom (?), duration and decimal } +pub fn is_type_compatible(value: &Option, typ: &CassDataType) -> bool { + match value { + Some(v) => v.is_type_compatible(typ), + None => true, + } +} + +impl CassCqlValue { + pub fn is_type_compatible(&self, typ: &CassDataType) -> bool { + match self { + CassCqlValue::TinyInt(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_TINY_INT + } + CassCqlValue::SmallInt(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_SMALL_INT + } + CassCqlValue::Int(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_INT, + CassCqlValue::BigInt(_) => { + matches!( + typ.get_value_type(), + CassValueType::CASS_VALUE_TYPE_BIGINT + | CassValueType::CASS_VALUE_TYPE_COUNTER + | CassValueType::CASS_VALUE_TYPE_TIMESTAMP + | CassValueType::CASS_VALUE_TYPE_TIME + ) + } + CassCqlValue::Float(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_FLOAT, + CassCqlValue::Double(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DOUBLE + } + CassCqlValue::Boolean(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_BOOLEAN + } + CassCqlValue::Text(_) => { + matches!( + typ.get_value_type(), + CassValueType::CASS_VALUE_TYPE_TEXT + | CassValueType::CASS_VALUE_TYPE_VARCHAR + | CassValueType::CASS_VALUE_TYPE_ASCII + | CassValueType::CASS_VALUE_TYPE_BLOB + | CassValueType::CASS_VALUE_TYPE_VARINT + ) + } + CassCqlValue::Blob(_) => matches!( + typ.get_value_type(), + CassValueType::CASS_VALUE_TYPE_BLOB | CassValueType::CASS_VALUE_TYPE_VARINT + ), + CassCqlValue::Uuid(_) => matches!( + typ.get_value_type(), + CassValueType::CASS_VALUE_TYPE_UUID | CassValueType::CASS_VALUE_TYPE_TIMEUUID + ), + CassCqlValue::Date(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DATE, + CassCqlValue::Inet(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_INET, + CassCqlValue::Duration(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION + } + CassCqlValue::Decimal(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DECIMAL + } + CassCqlValue::Tuple { data_type, .. } => { + if let Some(dt) = data_type { + return dt.typecheck_equals(typ); + } + // Untyped tuple. + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_TUPLE + } + CassCqlValue::List { data_type, .. } => { + if let Some(dt) = data_type { + dt.typecheck_equals(typ) + } else { + // Untyped list. + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_LIST + } + } + CassCqlValue::Map { data_type, .. } => { + if let Some(dt) = data_type { + dt.typecheck_equals(typ) + } else { + // Untyped map. + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_MAP + } + } + CassCqlValue::Set { data_type, .. } => { + if let Some(dt) = data_type { + dt.typecheck_equals(typ) + } else { + // Untyped set. + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_SET + } + } + CassCqlValue::UserDefinedType { data_type, .. } => data_type.typecheck_equals(typ), + } + } +} + impl SerializeValue for CassCqlValue { fn serialize<'b>( &self, _typ: &ColumnType, writer: CellWriter<'b>, ) -> Result, SerializationError> { - // _typ is not used, since we do the typechecks during binding (this is still a TODO, high priority). + // _typ is not used, since we do the typechecks during binding. // This is the same approach as cpp-driver. self.do_serialize(writer) } @@ -128,12 +236,16 @@ impl CassCqlValue { CassCqlValue::Decimal(v) => { ::serialize(v, &ColumnType::Decimal, writer) } - CassCqlValue::Tuple(fields) => serialize_tuple_like(fields.iter(), writer), - CassCqlValue::List(l) => serialize_sequence(l.len(), l.iter(), writer), - CassCqlValue::Map(m) => { - serialize_mapping(m.len(), m.iter().map(|p| (&p.0, &p.1)), writer) + CassCqlValue::Tuple { fields, .. } => serialize_tuple_like(fields.iter(), writer), + CassCqlValue::List { values, .. } => { + serialize_sequence(values.len(), values.iter(), writer) + } + CassCqlValue::Map { values, .. } => { + serialize_mapping(values.len(), values.iter().map(|p| (&p.0, &p.1)), writer) + } + CassCqlValue::Set { values, .. } => { + serialize_sequence(values.len(), values.iter(), writer) } - CassCqlValue::Set(s) => serialize_sequence(s.len(), s.iter(), writer), CassCqlValue::UserDefinedType { fields, .. } => serialize_udt(fields, writer), } } @@ -282,3 +394,630 @@ fn serialize_udt<'b>( .finish() .map_err(|_| mk_ser_err::(BuiltinSerializationErrorKind::SizeOverflow)) } + +#[cfg(test)] +mod tests { + use std::{net::Ipv4Addr, sync::Arc}; + + use scylla::frame::value::{CqlDate, CqlDecimal, CqlDuration}; + + use crate::{ + cass_types::{CassDataType, CassValueType, MapDataType, UDTDataType}, + value::{is_type_compatible, CassCqlValue}, + }; + + fn all_value_data_types() -> [CassDataType; 26] { + let from = |v_typ: CassValueType| CassDataType::Value(v_typ); + + [ + from(CassValueType::CASS_VALUE_TYPE_TINY_INT), + from(CassValueType::CASS_VALUE_TYPE_SMALL_INT), + from(CassValueType::CASS_VALUE_TYPE_INT), + from(CassValueType::CASS_VALUE_TYPE_BIGINT), + from(CassValueType::CASS_VALUE_TYPE_COUNTER), + from(CassValueType::CASS_VALUE_TYPE_TIME), + from(CassValueType::CASS_VALUE_TYPE_TIMESTAMP), + from(CassValueType::CASS_VALUE_TYPE_FLOAT), + from(CassValueType::CASS_VALUE_TYPE_DOUBLE), + from(CassValueType::CASS_VALUE_TYPE_BOOLEAN), + from(CassValueType::CASS_VALUE_TYPE_TEXT), + from(CassValueType::CASS_VALUE_TYPE_VARCHAR), + from(CassValueType::CASS_VALUE_TYPE_ASCII), + from(CassValueType::CASS_VALUE_TYPE_BLOB), + from(CassValueType::CASS_VALUE_TYPE_UUID), + from(CassValueType::CASS_VALUE_TYPE_TIMEUUID), + from(CassValueType::CASS_VALUE_TYPE_DATE), + from(CassValueType::CASS_VALUE_TYPE_INET), + from(CassValueType::CASS_VALUE_TYPE_DURATION), + from(CassValueType::CASS_VALUE_TYPE_DECIMAL), + from(CassValueType::CASS_VALUE_TYPE_VARINT), + from(CassValueType::CASS_VALUE_TYPE_TUPLE), + from(CassValueType::CASS_VALUE_TYPE_LIST), + from(CassValueType::CASS_VALUE_TYPE_SET), + from(CassValueType::CASS_VALUE_TYPE_MAP), + from(CassValueType::CASS_VALUE_TYPE_UDT), + ] + } + + #[test] + fn typecheck_simple_test() { + let from = |v_typ: CassValueType| CassDataType::Value(v_typ); + struct TestCase { + value: Option, + compatible_types: Vec, + } + + let test_cases = [ + // Null -> all types + TestCase { + value: None, + compatible_types: all_value_data_types().to_vec(), + }, + // i8 -> tinyint + TestCase { + value: Some(CassCqlValue::TinyInt(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_TINY_INT)], + }, + // i16 -> smallint + TestCase { + value: Some(CassCqlValue::SmallInt(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_SMALL_INT)], + }, + // i32 -> int + TestCase { + value: Some(CassCqlValue::Int(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_INT)], + }, + // i64 -> bigint/counter/time/timestamp + TestCase { + value: Some(CassCqlValue::BigInt(Default::default())), + compatible_types: vec![ + from(CassValueType::CASS_VALUE_TYPE_BIGINT), + from(CassValueType::CASS_VALUE_TYPE_COUNTER), + from(CassValueType::CASS_VALUE_TYPE_TIME), + from(CassValueType::CASS_VALUE_TYPE_TIMESTAMP), + ], + }, + // f32 -> float + TestCase { + value: Some(CassCqlValue::Float(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_FLOAT)], + }, + // f64 -> double + TestCase { + value: Some(CassCqlValue::Double(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DOUBLE)], + }, + // bool -> boolean + TestCase { + value: Some(CassCqlValue::Boolean(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_BOOLEAN)], + }, + TestCase { + value: Some(CassCqlValue::Text(Default::default())), + compatible_types: vec![ + from(CassValueType::CASS_VALUE_TYPE_TEXT), + from(CassValueType::CASS_VALUE_TYPE_VARCHAR), + from(CassValueType::CASS_VALUE_TYPE_ASCII), + from(CassValueType::CASS_VALUE_TYPE_BLOB), + from(CassValueType::CASS_VALUE_TYPE_VARINT), + ], + }, + // Vec -> blob/varint + TestCase { + value: Some(CassCqlValue::Blob(Default::default())), + compatible_types: vec![ + from(CassValueType::CASS_VALUE_TYPE_BLOB), + from(CassValueType::CASS_VALUE_TYPE_VARINT), + ], + }, + // uuid -> uuid/timeuuid + TestCase { + value: Some(CassCqlValue::Uuid(Default::default())), + compatible_types: vec![ + from(CassValueType::CASS_VALUE_TYPE_UUID), + from(CassValueType::CASS_VALUE_TYPE_TIMEUUID), + ], + }, + // u32 -> date + TestCase { + value: Some(CassCqlValue::Date(CqlDate(Default::default()))), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DATE)], + }, + // IpAddr -> inet + TestCase { + value: Some(CassCqlValue::Inet(std::net::IpAddr::V4( + Ipv4Addr::LOCALHOST, + ))), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_INET)], + }, + // CqlDuration -> duration + TestCase { + value: Some(CassCqlValue::Duration(CqlDuration { + months: 0, + days: 0, + nanoseconds: 0, + })), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DURATION)], + }, + // CqlDecimal -> decimal + TestCase { + value: Some(CassCqlValue::Decimal( + CqlDecimal::from_signed_be_bytes_slice_and_exponent(&[], 0), + )), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DECIMAL)], + }, + ]; + let all_simple_types = all_value_data_types(); + + for case in test_cases { + for typ in all_simple_types.iter() { + let result = is_type_compatible(&case.value, typ); + let expected = case.compatible_types.iter().any(|t| t == typ); + assert_eq!( + expected, result, + "Typecheck test for value {:?} and type {:?} failed. Expected result for the typecheck: {}", + case.value, typ, expected, + ); + } + } + } + + #[test] + fn typecheck_complex_test() { + struct TestCase { + value: CassCqlValue, + compatible_types: Vec>, + incompatible_types: Vec>, + } + + let run_test_cases = |test_cases: &[TestCase]| { + for case in test_cases { + for typ in case.compatible_types.iter() { + assert!( + case.value.is_type_compatible(typ), + "Typecheck failed, when it should pass. Value: {:?}, Type: {:?}", + case.value, + typ + ); + } + for typ in case.incompatible_types.iter() { + assert!( + !case.value.is_type_compatible(typ), + "Typecheck passed, when it should fail. Value: {:?}, Type: {:?}", + case.value, + typ + ) + } + } + }; + + // Let's make some types accessible for all test cases. + // To make sure that e.g. Tuple against UDT typecheck fails. + let data_type_float = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_FLOAT)); + let data_type_int = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_INT)); + let data_type_bool = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN)); + let data_type_tuple = Arc::new(CassDataType::Tuple(vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + ])); + + let simple_fields = vec![ + ("foo".to_owned(), data_type_float.clone()), + ("bar".to_owned(), data_type_bool.clone()), + ("baz".to_owned(), data_type_int.clone()), + ]; + let ks_keyspace_name = "ks".to_owned(); + let user_udt_name = "user".to_owned(); + let empty_str = "".to_owned(); + + let data_type_udt_simple = Arc::new(CassDataType::UDT(UDTDataType { + field_types: simple_fields.clone(), + keyspace: ks_keyspace_name.clone(), + name: user_udt_name.clone(), + frozen: false, + })); + + let data_type_int_list = Arc::new(CassDataType::List { + typ: Some(data_type_int.clone()), + frozen: false, + }); + + let data_type_int_set = Arc::new(CassDataType::Set { + typ: Some(data_type_int.clone()), + frozen: false, + }); + + let data_type_bool_float_map = Arc::new(CassDataType::Map { + typ: MapDataType::KeyAndValue(data_type_bool.clone(), data_type_float.clone()), + frozen: false, + }); + + // TUPLES + { + let data_type_untyped_tuple = Arc::new(CassDataType::Tuple(vec![])); + let data_type_small_tuple = Arc::new(CassDataType::Tuple(vec![data_type_bool.clone()])); + let data_type_nested_tuple = Arc::new(CassDataType::Tuple(vec![ + data_type_small_tuple.clone(), + data_type_int.clone(), + data_type_tuple.clone(), + ])); + let data_type_nested_untyped_tuple = Arc::new(CassDataType::Tuple(vec![ + data_type_untyped_tuple.clone(), + data_type_int.clone(), + data_type_untyped_tuple.clone(), + ])); + + let test_cases = &[ + // Untyped tuple -> created via `cass_tuple_new` + TestCase { + value: CassCqlValue::Tuple { + data_type: None, + fields: vec![], + }, + compatible_types: vec![ + data_type_untyped_tuple.clone(), + data_type_small_tuple.clone(), + data_type_tuple.clone(), + data_type_nested_tuple.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Untyped tuple -> used created an untyped tuple data type, and then + // created a tuple value via `cass_tuple_new_from_data_type`. + TestCase { + value: CassCqlValue::Tuple { + data_type: Some(data_type_untyped_tuple.clone()), + fields: vec![], + }, + compatible_types: vec![ + data_type_untyped_tuple.clone(), + data_type_small_tuple.clone(), + data_type_tuple.clone(), + data_type_nested_tuple.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Fully typed tuple. + TestCase { + value: CassCqlValue::Tuple { + data_type: Some(data_type_tuple.clone()), + fields: vec![], + }, + compatible_types: vec![ + data_type_tuple.clone(), + data_type_untyped_tuple.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_small_tuple.clone(), + data_type_nested_tuple.clone(), + data_type_nested_tuple.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Nested tuple. + TestCase { + value: CassCqlValue::Tuple { + data_type: Some(data_type_nested_tuple.clone()), + fields: vec![], + }, + compatible_types: vec![ + data_type_nested_tuple.clone(), + data_type_untyped_tuple.clone(), + data_type_nested_untyped_tuple.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_small_tuple.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), + ], + }, + ]; + + run_test_cases(test_cases); + } + + // UDT + { + let data_type_udt_simple_empty_keyspace = Arc::new(CassDataType::UDT(UDTDataType { + field_types: simple_fields.clone(), + keyspace: empty_str.to_owned(), + name: user_udt_name.clone(), + frozen: false, + })); + let data_type_udt_simple_empty_name = Arc::new(CassDataType::UDT(UDTDataType { + field_types: simple_fields.clone(), + keyspace: ks_keyspace_name.clone(), + name: empty_str.clone(), + frozen: false, + })); + + // A prefix of simple_fields. + let small_fields = vec![ + ("foo".to_owned(), data_type_float.clone()), + ("bar".to_owned(), data_type_bool.clone()), + ]; + let data_type_udt_small = Arc::new(CassDataType::UDT(UDTDataType { + field_types: small_fields.clone(), + keyspace: ks_keyspace_name.clone(), + name: user_udt_name.clone(), + frozen: false, + })); + + let test_cases = &[TestCase { + value: CassCqlValue::UserDefinedType { + data_type: data_type_udt_simple.clone(), + fields: vec![], + }, + compatible_types: vec![ + data_type_udt_simple.clone(), + data_type_udt_simple_empty_keyspace.clone(), + data_type_udt_simple_empty_name.clone(), + data_type_udt_small.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), + ], + }]; + + run_test_cases(test_cases); + } + + // COLLECTIONS + { + let data_type_untyped_list = Arc::new(CassDataType::List { + typ: None, + frozen: false, + }); + let data_type_float_list = Arc::new(CassDataType::List { + typ: Some(data_type_float.clone()), + frozen: false, + }); + + let data_type_untyped_set = Arc::new(CassDataType::Set { + typ: None, + frozen: false, + }); + let data_type_float_set = Arc::new(CassDataType::Set { + typ: Some(data_type_float.clone()), + frozen: false, + }); + + let data_type_untyped_map = Arc::new(CassDataType::Map { + typ: MapDataType::Untyped, + frozen: false, + }); + let data_type_typed_key_float_map = Arc::new(CassDataType::Map { + typ: MapDataType::Key(data_type_float.clone()), + + frozen: false, + }); + let data_type_float_int_map = Arc::new(CassDataType::Map { + typ: MapDataType::KeyAndValue(data_type_float.clone(), data_type_int.clone()), + frozen: false, + }); + + let test_cases = &[ + // Untyped list -> user created it via `cass_collection_new`. + TestCase { + value: CassCqlValue::List { + data_type: None, + values: vec![], + }, + compatible_types: vec![ + data_type_float_list.clone(), + data_type_int_list.clone(), + data_type_untyped_list.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Typed list. + TestCase { + value: CassCqlValue::List { + data_type: Some(data_type_float_list.clone()), + values: vec![], + }, + compatible_types: vec![ + data_type_float_list.clone(), + data_type_untyped_list.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Untyped set (via cass_collection_new). + TestCase { + value: CassCqlValue::Set { + data_type: None, + values: vec![], + }, + compatible_types: vec![ + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_float_list.clone(), + data_type_untyped_list.clone(), + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Typed set. + TestCase { + value: CassCqlValue::Set { + data_type: Some(data_type_float_set.clone()), + values: vec![], + }, + compatible_types: vec![ + data_type_untyped_set.clone(), + data_type_float_set.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_float_list.clone(), + data_type_untyped_list.clone(), + data_type_int_set.clone(), + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Untyped map (via cass_collection_new). + TestCase { + value: CassCqlValue::Map { + data_type: None, + values: vec![], + }, + compatible_types: vec![ + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_float_list.clone(), + data_type_int_list.clone(), + data_type_untyped_list.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + ], + }, + // Only key-typed map. + TestCase { + value: CassCqlValue::Map { + data_type: Some(data_type_typed_key_float_map.clone()), + values: vec![], + }, + compatible_types: vec![ + data_type_typed_key_float_map.clone(), + data_type_untyped_map.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_float_list.clone(), + data_type_int_list.clone(), + data_type_untyped_list.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Fully typed map + TestCase { + value: CassCqlValue::Map { + data_type: Some(data_type_float_int_map.clone()), + values: vec![], + }, + compatible_types: vec![ + data_type_float_int_map.clone(), + data_type_untyped_map.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_float_list.clone(), + data_type_int_list.clone(), + data_type_untyped_list.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + data_type_typed_key_float_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + ]; + + run_test_cases(test_cases) + } + } +}