Skip to content

Commit af43106

Browse files
authored
Merge pull request #143 from muzarski/typechecks
binding: typechecks
2 parents 9ac038f + 8dbe9d1 commit af43106

File tree

14 files changed

+1438
-120
lines changed

14 files changed

+1438
-120
lines changed

.github/pull_request_template.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@
99
- [ ] I have split my patch into logically separate commits.
1010
- [ ] All commit messages clearly explain what they change and why.
1111
- [ ] PR description sums up the changes and reasons why they should be introduced.
12+
- [ ] I have implemented Rust unit tests for the features/changes introduced.
1213
- [ ] I have enabled appropriate tests in `.github/workflows/build.yml` in `gtest_filter`.
1314
- [ ] I have enabled appropriate tests in `.github/workflows/cassandra.yml` in `gtest_filter`.

README.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,6 @@ The driver inherits almost all the features of C/C++ and Rust drivers, such as:
172172
<tr>
173173
<td colspan=2 align="center" style="font-weight:bold">Collection</td>
174174
</tr>
175-
<tr>
176-
<td>cass_collection_new_from_data_type</td>
177-
<td rowspan="2">Unimplemented</td>
178-
</tr>
179-
<tr>
180-
<td>cass_collection_data_type</td>
181-
</tr>
182175
<tr>
183176
<td>cass_collection_append_custom[_n]</td>
184177
<td>Unimplemented because of the same reasons as binding for statements.<br> <b>Note</b>: The driver does not check whether the type of the appended value is compatible with the type of the collection items.</td>

scylla-rust-wrapper/src/batch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ pub unsafe extern "C" fn cass_batch_add_statement(
165165

166166
match &statement.statement {
167167
Statement::Simple(q) => state.batch.append_statement(q.query.clone()),
168-
Statement::Prepared(p) => state.batch.append_statement((**p).clone()),
168+
Statement::Prepared(p) => state.batch.append_statement(p.statement.clone()),
169169
};
170170

171171
state.bound_values.push(statement.bound_values.clone());

scylla-rust-wrapper/src/binding.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@
4747
//! It can be used for binding named parameter in CassStatement or field by name in CassUserType.
4848
//! * Functions from make_appender don't take any extra argument, as they are for use by CassCollection
4949
//! functions - values are appended to collection.
50-
use crate::{cass_types::CassDataType, value::CassCqlValue};
51-
52-
pub fn is_compatible_type(_data_type: &CassDataType, _value: &Option<CassCqlValue>) -> bool {
53-
// TODO: cppdriver actually checks types.
54-
true
55-
}
5650
5751
macro_rules! make_index_binder {
5852
($this:ty, $consume_v:expr, $fn_by_idx:ident, $e:expr, [$($arg:ident @ $t:ty), *]) => {

scylla-rust-wrapper/src/cass_types.rs

Lines changed: 174 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_data_types.rs"));
1515
include!(concat!(env!("OUT_DIR"), "/cppdriver_data_query_error.rs"));
1616
include!(concat!(env!("OUT_DIR"), "/cppdriver_batch_types.rs"));
1717

18-
#[derive(Clone, Debug)]
18+
#[derive(Clone, Debug, PartialEq)]
1919
pub struct UDTDataType {
2020
// Vec to preserve the order of types
2121
pub field_types: Vec<(String, Arc<CassDataType>)>,
@@ -87,6 +87,42 @@ impl UDTDataType {
8787
pub fn get_field_by_index(&self, index: usize) -> Option<&Arc<CassDataType>> {
8888
self.field_types.get(index).map(|(_, b)| b)
8989
}
90+
91+
fn typecheck_equals(&self, other: &UDTDataType) -> bool {
92+
// See: https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L354-L386
93+
94+
if !any_string_empty_or_both_equal(&self.keyspace, &other.keyspace) {
95+
return false;
96+
}
97+
if !any_string_empty_or_both_equal(&self.name, &other.name) {
98+
return false;
99+
}
100+
101+
// A comment from cpp-driver:
102+
//// UDT's can be considered equal as long as the mutual first fields shared
103+
//// between them are equal. UDT's are append only as far as fields go, so a
104+
//// newer 'version' of the UDT data type after a schema change event should be
105+
//// treated as equivalent in this scenario, by simply looking at the first N
106+
//// mutual fields they should share.
107+
//
108+
// Iterator returned from zip() is perfect for checking the first mutual fields.
109+
for (field, other_field) in self.field_types.iter().zip(other.field_types.iter()) {
110+
// Compare field names.
111+
if field.0 != other_field.0 {
112+
return false;
113+
}
114+
// Compare field types.
115+
if !field.1.typecheck_equals(&other_field.1) {
116+
return false;
117+
}
118+
}
119+
120+
true
121+
}
122+
}
123+
124+
fn any_string_empty_or_both_equal(s1: &str, s2: &str) -> bool {
125+
s1.is_empty() || s2.is_empty() || s1 == s2
90126
}
91127

92128
impl Default for UDTDataType {
@@ -95,27 +131,106 @@ impl Default for UDTDataType {
95131
}
96132
}
97133

98-
#[derive(Clone, Debug)]
134+
#[derive(Clone, Debug, PartialEq)]
135+
pub enum MapDataType {
136+
Untyped,
137+
Key(Arc<CassDataType>),
138+
KeyAndValue(Arc<CassDataType>, Arc<CassDataType>),
139+
}
140+
141+
#[derive(Clone, Debug, PartialEq)]
99142
pub enum CassDataType {
100143
Value(CassValueType),
101144
UDT(UDTDataType),
102145
List {
146+
// None stands for untyped list.
103147
typ: Option<Arc<CassDataType>>,
104148
frozen: bool,
105149
},
106150
Set {
151+
// None stands for untyped set.
107152
typ: Option<Arc<CassDataType>>,
108153
frozen: bool,
109154
},
110155
Map {
111-
key_type: Option<Arc<CassDataType>>,
112-
val_type: Option<Arc<CassDataType>>,
156+
typ: MapDataType,
113157
frozen: bool,
114158
},
159+
// Empty vector stands for untyped tuple.
115160
Tuple(Vec<Arc<CassDataType>>),
116161
Custom(String),
117162
}
118163

164+
impl CassDataType {
165+
/// Checks for equality during typechecks.
166+
///
167+
/// This takes into account the fact that tuples/collections may be untyped.
168+
pub fn typecheck_equals(&self, other: &CassDataType) -> bool {
169+
match self {
170+
CassDataType::Value(t) => *t == other.get_value_type(),
171+
CassDataType::UDT(udt) => match other {
172+
CassDataType::UDT(other_udt) => udt.typecheck_equals(other_udt),
173+
_ => false,
174+
},
175+
CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => match other {
176+
CassDataType::List { typ: other_typ, .. }
177+
| CassDataType::Set { typ: other_typ, .. } => {
178+
// If one of them is list, and the other is set, fail the typecheck.
179+
if self.get_value_type() != other.get_value_type() {
180+
return false;
181+
}
182+
match (typ, other_typ) {
183+
// One of them is untyped, skip the typecheck for subtype.
184+
(None, _) | (_, None) => true,
185+
(Some(typ), Some(other_typ)) => typ.typecheck_equals(other_typ),
186+
}
187+
}
188+
_ => false,
189+
},
190+
CassDataType::Map { typ: t, .. } => match other {
191+
CassDataType::Map { typ: t_other, .. } => match (t, t_other) {
192+
// See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218
193+
// In cpp-driver the types are held in a vector.
194+
// The logic is following:
195+
196+
// If either of vectors is empty, skip the typecheck.
197+
(MapDataType::Untyped, _) => true,
198+
(_, MapDataType::Untyped) => true,
199+
200+
// Otherwise, the vectors should have equal length and we perform the typecheck for subtypes.
201+
(MapDataType::Key(k), MapDataType::Key(k_other)) => k.typecheck_equals(k_other),
202+
(
203+
MapDataType::KeyAndValue(k, v),
204+
MapDataType::KeyAndValue(k_other, v_other),
205+
) => k.typecheck_equals(k_other) && v.typecheck_equals(v_other),
206+
_ => false,
207+
},
208+
_ => false,
209+
},
210+
CassDataType::Tuple(sub) => match other {
211+
CassDataType::Tuple(other_sub) => {
212+
// If either of tuples is untyped, skip the typecheck for subtypes.
213+
if sub.is_empty() || other_sub.is_empty() {
214+
return true;
215+
}
216+
217+
// If both are non-empty, check for subtypes equality.
218+
if sub.len() != other_sub.len() {
219+
return false;
220+
}
221+
sub.iter()
222+
.zip(other_sub.iter())
223+
.all(|(typ, other_typ)| typ.typecheck_equals(other_typ))
224+
}
225+
_ => false,
226+
},
227+
CassDataType::Custom(_) => {
228+
unimplemented!("Cpp-rust-driver does not support custom types!")
229+
}
230+
}
231+
}
232+
}
233+
119234
impl From<NativeType> for CassValueType {
120235
fn from(native_type: NativeType) -> CassValueType {
121236
match native_type {
@@ -160,16 +275,18 @@ pub fn get_column_type_from_cql_type(
160275
frozen: *frozen,
161276
},
162277
CollectionType::Map(key, value) => CassDataType::Map {
163-
key_type: Some(Arc::new(get_column_type_from_cql_type(
164-
key,
165-
user_defined_types,
166-
keyspace_name,
167-
))),
168-
val_type: Some(Arc::new(get_column_type_from_cql_type(
169-
value,
170-
user_defined_types,
171-
keyspace_name,
172-
))),
278+
typ: MapDataType::KeyAndValue(
279+
Arc::new(get_column_type_from_cql_type(
280+
key,
281+
user_defined_types,
282+
keyspace_name,
283+
)),
284+
Arc::new(get_column_type_from_cql_type(
285+
value,
286+
user_defined_types,
287+
keyspace_name,
288+
)),
289+
),
173290
frozen: *frozen,
174291
},
175292
CollectionType::Set(set) => CassDataType::Set {
@@ -222,10 +339,19 @@ impl CassDataType {
222339
}
223340
}
224341
CassDataType::Map {
225-
key_type, val_type, ..
342+
typ: MapDataType::Untyped,
343+
..
344+
} => None,
345+
CassDataType::Map {
346+
typ: MapDataType::Key(k),
347+
..
348+
} => (index == 0).then_some(k),
349+
CassDataType::Map {
350+
typ: MapDataType::KeyAndValue(k, v),
351+
..
226352
} => match index {
227-
0 => key_type.as_ref(),
228-
1 => val_type.as_ref(),
353+
0 => Some(k),
354+
1 => Some(v),
229355
_ => None,
230356
},
231357
CassDataType::Tuple(v) => v.get(index),
@@ -243,17 +369,28 @@ impl CassDataType {
243369
}
244370
},
245371
CassDataType::Map {
246-
key_type, val_type, ..
372+
typ: MapDataType::KeyAndValue(_, _),
373+
..
374+
} => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS),
375+
CassDataType::Map {
376+
typ: MapDataType::Key(k),
377+
frozen,
247378
} => {
248-
if key_type.is_some() && val_type.is_some() {
249-
Err(CassError::CASS_ERROR_LIB_BAD_PARAMS)
250-
} else if key_type.is_none() {
251-
*key_type = Some(sub_type);
252-
Ok(())
253-
} else {
254-
*val_type = Some(sub_type);
255-
Ok(())
256-
}
379+
*self = CassDataType::Map {
380+
typ: MapDataType::KeyAndValue(k.clone(), sub_type),
381+
frozen: *frozen,
382+
};
383+
Ok(())
384+
}
385+
CassDataType::Map {
386+
typ: MapDataType::Untyped,
387+
frozen,
388+
} => {
389+
*self = CassDataType::Map {
390+
typ: MapDataType::Key(sub_type),
391+
frozen: *frozen,
392+
};
393+
Ok(())
257394
}
258395
CassDataType::Tuple(types) => {
259396
types.push(sub_type);
@@ -305,8 +442,10 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
305442
frozen: false,
306443
},
307444
ColumnType::Map(key, value) => CassDataType::Map {
308-
key_type: Some(Arc::new(get_column_type(key.as_ref()))),
309-
val_type: Some(Arc::new(get_column_type(value.as_ref()))),
445+
typ: MapDataType::KeyAndValue(
446+
Arc::new(get_column_type(key.as_ref())),
447+
Arc::new(get_column_type(value.as_ref())),
448+
),
310449
frozen: false,
311450
},
312451
ColumnType::Set(boxed_type) => CassDataType::Set {
@@ -357,8 +496,7 @@ pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const
357496
},
358497
CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataType::Tuple(Vec::new()),
359498
CassValueType::CASS_VALUE_TYPE_MAP => CassDataType::Map {
360-
key_type: None,
361-
val_type: None,
499+
typ: MapDataType::Untyped,
362500
frozen: false,
363501
},
364502
CassValueType::CASS_VALUE_TYPE_UDT => CassDataType::UDT(UDTDataType::new()),
@@ -555,9 +693,11 @@ pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDat
555693
CassDataType::Value(..) => 0,
556694
CassDataType::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t,
557695
CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => typ.is_some() as size_t,
558-
CassDataType::Map {
559-
key_type, val_type, ..
560-
} => key_type.is_some() as size_t + val_type.is_some() as size_t,
696+
CassDataType::Map { typ, .. } => match typ {
697+
MapDataType::Untyped => 0,
698+
MapDataType::Key(_) => 1,
699+
MapDataType::KeyAndValue(_, _) => 2,
700+
},
561701
CassDataType::Tuple(v) => v.len() as size_t,
562702
CassDataType::Custom(..) => 0,
563703
}

0 commit comments

Comments
 (0)