@@ -15,7 +15,7 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_data_types.rs"));
15
15
include ! ( concat!( env!( "OUT_DIR" ) , "/cppdriver_data_query_error.rs" ) ) ;
16
16
include ! ( concat!( env!( "OUT_DIR" ) , "/cppdriver_batch_types.rs" ) ) ;
17
17
18
- #[ derive( Clone , Debug ) ]
18
+ #[ derive( Clone , Debug , PartialEq ) ]
19
19
pub struct UDTDataType {
20
20
// Vec to preserve the order of types
21
21
pub field_types : Vec < ( String , Arc < CassDataType > ) > ,
@@ -87,6 +87,42 @@ impl UDTDataType {
87
87
pub fn get_field_by_index ( & self , index : usize ) -> Option < & Arc < CassDataType > > {
88
88
self . field_types . get ( index) . map ( |( _, b) | b)
89
89
}
90
+
91
+ fn typecheck_equals ( & self , other : & UDTDataType ) -> bool {
92
+ // See: https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L354-L386
93
+
94
+ if !any_string_empty_or_both_equal ( & self . keyspace , & other. keyspace ) {
95
+ return false ;
96
+ }
97
+ if !any_string_empty_or_both_equal ( & self . name , & other. name ) {
98
+ return false ;
99
+ }
100
+
101
+ // A comment from cpp-driver:
102
+ //// UDT's can be considered equal as long as the mutual first fields shared
103
+ //// between them are equal. UDT's are append only as far as fields go, so a
104
+ //// newer 'version' of the UDT data type after a schema change event should be
105
+ //// treated as equivalent in this scenario, by simply looking at the first N
106
+ //// mutual fields they should share.
107
+ //
108
+ // Iterator returned from zip() is perfect for checking the first mutual fields.
109
+ for ( field, other_field) in self . field_types . iter ( ) . zip ( other. field_types . iter ( ) ) {
110
+ // Compare field names.
111
+ if field. 0 != other_field. 0 {
112
+ return false ;
113
+ }
114
+ // Compare field types.
115
+ if !field. 1 . typecheck_equals ( & other_field. 1 ) {
116
+ return false ;
117
+ }
118
+ }
119
+
120
+ true
121
+ }
122
+ }
123
+
124
+ fn any_string_empty_or_both_equal ( s1 : & str , s2 : & str ) -> bool {
125
+ s1. is_empty ( ) || s2. is_empty ( ) || s1 == s2
90
126
}
91
127
92
128
impl Default for UDTDataType {
@@ -95,27 +131,106 @@ impl Default for UDTDataType {
95
131
}
96
132
}
97
133
98
- #[ derive( Clone , Debug ) ]
134
+ #[ derive( Clone , Debug , PartialEq ) ]
135
+ pub enum MapDataType {
136
+ Untyped ,
137
+ Key ( Arc < CassDataType > ) ,
138
+ KeyAndValue ( Arc < CassDataType > , Arc < CassDataType > ) ,
139
+ }
140
+
141
+ #[ derive( Clone , Debug , PartialEq ) ]
99
142
pub enum CassDataType {
100
143
Value ( CassValueType ) ,
101
144
UDT ( UDTDataType ) ,
102
145
List {
146
+ // None stands for untyped list.
103
147
typ : Option < Arc < CassDataType > > ,
104
148
frozen : bool ,
105
149
} ,
106
150
Set {
151
+ // None stands for untyped set.
107
152
typ : Option < Arc < CassDataType > > ,
108
153
frozen : bool ,
109
154
} ,
110
155
Map {
111
- key_type : Option < Arc < CassDataType > > ,
112
- val_type : Option < Arc < CassDataType > > ,
156
+ typ : MapDataType ,
113
157
frozen : bool ,
114
158
} ,
159
+ // Empty vector stands for untyped tuple.
115
160
Tuple ( Vec < Arc < CassDataType > > ) ,
116
161
Custom ( String ) ,
117
162
}
118
163
164
+ impl CassDataType {
165
+ /// Checks for equality during typechecks.
166
+ ///
167
+ /// This takes into account the fact that tuples/collections may be untyped.
168
+ pub fn typecheck_equals ( & self , other : & CassDataType ) -> bool {
169
+ match self {
170
+ CassDataType :: Value ( t) => * t == other. get_value_type ( ) ,
171
+ CassDataType :: UDT ( udt) => match other {
172
+ CassDataType :: UDT ( other_udt) => udt. typecheck_equals ( other_udt) ,
173
+ _ => false ,
174
+ } ,
175
+ CassDataType :: List { typ, .. } | CassDataType :: Set { typ, .. } => match other {
176
+ CassDataType :: List { typ : other_typ, .. }
177
+ | CassDataType :: Set { typ : other_typ, .. } => {
178
+ // If one of them is list, and the other is set, fail the typecheck.
179
+ if self . get_value_type ( ) != other. get_value_type ( ) {
180
+ return false ;
181
+ }
182
+ match ( typ, other_typ) {
183
+ // One of them is untyped, skip the typecheck for subtype.
184
+ ( None , _) | ( _, None ) => true ,
185
+ ( Some ( typ) , Some ( other_typ) ) => typ. typecheck_equals ( other_typ) ,
186
+ }
187
+ }
188
+ _ => false ,
189
+ } ,
190
+ CassDataType :: Map { typ : t, .. } => match other {
191
+ CassDataType :: Map { typ : t_other, .. } => match ( t, t_other) {
192
+ // See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218
193
+ // In cpp-driver the types are held in a vector.
194
+ // The logic is following:
195
+
196
+ // If either of vectors is empty, skip the typecheck.
197
+ ( MapDataType :: Untyped , _) => true ,
198
+ ( _, MapDataType :: Untyped ) => true ,
199
+
200
+ // Otherwise, the vectors should have equal length and we perform the typecheck for subtypes.
201
+ ( MapDataType :: Key ( k) , MapDataType :: Key ( k_other) ) => k. typecheck_equals ( k_other) ,
202
+ (
203
+ MapDataType :: KeyAndValue ( k, v) ,
204
+ MapDataType :: KeyAndValue ( k_other, v_other) ,
205
+ ) => k. typecheck_equals ( k_other) && v. typecheck_equals ( v_other) ,
206
+ _ => false ,
207
+ } ,
208
+ _ => false ,
209
+ } ,
210
+ CassDataType :: Tuple ( sub) => match other {
211
+ CassDataType :: Tuple ( other_sub) => {
212
+ // If either of tuples is untyped, skip the typecheck for subtypes.
213
+ if sub. is_empty ( ) || other_sub. is_empty ( ) {
214
+ return true ;
215
+ }
216
+
217
+ // If both are non-empty, check for subtypes equality.
218
+ if sub. len ( ) != other_sub. len ( ) {
219
+ return false ;
220
+ }
221
+ sub. iter ( )
222
+ . zip ( other_sub. iter ( ) )
223
+ . all ( |( typ, other_typ) | typ. typecheck_equals ( other_typ) )
224
+ }
225
+ _ => false ,
226
+ } ,
227
+ CassDataType :: Custom ( _) => {
228
+ unimplemented ! ( "Cpp-rust-driver does not support custom types!" )
229
+ }
230
+ }
231
+ }
232
+ }
233
+
119
234
impl From < NativeType > for CassValueType {
120
235
fn from ( native_type : NativeType ) -> CassValueType {
121
236
match native_type {
@@ -160,16 +275,18 @@ pub fn get_column_type_from_cql_type(
160
275
frozen : * frozen,
161
276
} ,
162
277
CollectionType :: Map ( key, value) => CassDataType :: Map {
163
- key_type : Some ( Arc :: new ( get_column_type_from_cql_type (
164
- key,
165
- user_defined_types,
166
- keyspace_name,
167
- ) ) ) ,
168
- val_type : Some ( Arc :: new ( get_column_type_from_cql_type (
169
- value,
170
- user_defined_types,
171
- keyspace_name,
172
- ) ) ) ,
278
+ typ : MapDataType :: KeyAndValue (
279
+ Arc :: new ( get_column_type_from_cql_type (
280
+ key,
281
+ user_defined_types,
282
+ keyspace_name,
283
+ ) ) ,
284
+ Arc :: new ( get_column_type_from_cql_type (
285
+ value,
286
+ user_defined_types,
287
+ keyspace_name,
288
+ ) ) ,
289
+ ) ,
173
290
frozen : * frozen,
174
291
} ,
175
292
CollectionType :: Set ( set) => CassDataType :: Set {
@@ -222,10 +339,19 @@ impl CassDataType {
222
339
}
223
340
}
224
341
CassDataType :: Map {
225
- key_type, val_type, ..
342
+ typ : MapDataType :: Untyped ,
343
+ ..
344
+ } => None ,
345
+ CassDataType :: Map {
346
+ typ : MapDataType :: Key ( k) ,
347
+ ..
348
+ } => ( index == 0 ) . then_some ( k) ,
349
+ CassDataType :: Map {
350
+ typ : MapDataType :: KeyAndValue ( k, v) ,
351
+ ..
226
352
} => match index {
227
- 0 => key_type . as_ref ( ) ,
228
- 1 => val_type . as_ref ( ) ,
353
+ 0 => Some ( k ) ,
354
+ 1 => Some ( v ) ,
229
355
_ => None ,
230
356
} ,
231
357
CassDataType :: Tuple ( v) => v. get ( index) ,
@@ -243,17 +369,28 @@ impl CassDataType {
243
369
}
244
370
} ,
245
371
CassDataType :: Map {
246
- key_type, val_type, ..
372
+ typ : MapDataType :: KeyAndValue ( _, _) ,
373
+ ..
374
+ } => Err ( CassError :: CASS_ERROR_LIB_BAD_PARAMS ) ,
375
+ CassDataType :: Map {
376
+ typ : MapDataType :: Key ( k) ,
377
+ frozen,
247
378
} => {
248
- if key_type. is_some ( ) && val_type. is_some ( ) {
249
- Err ( CassError :: CASS_ERROR_LIB_BAD_PARAMS )
250
- } else if key_type. is_none ( ) {
251
- * key_type = Some ( sub_type) ;
252
- Ok ( ( ) )
253
- } else {
254
- * val_type = Some ( sub_type) ;
255
- Ok ( ( ) )
256
- }
379
+ * self = CassDataType :: Map {
380
+ typ : MapDataType :: KeyAndValue ( k. clone ( ) , sub_type) ,
381
+ frozen : * frozen,
382
+ } ;
383
+ Ok ( ( ) )
384
+ }
385
+ CassDataType :: Map {
386
+ typ : MapDataType :: Untyped ,
387
+ frozen,
388
+ } => {
389
+ * self = CassDataType :: Map {
390
+ typ : MapDataType :: Key ( sub_type) ,
391
+ frozen : * frozen,
392
+ } ;
393
+ Ok ( ( ) )
257
394
}
258
395
CassDataType :: Tuple ( types) => {
259
396
types. push ( sub_type) ;
@@ -305,8 +442,10 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
305
442
frozen : false ,
306
443
} ,
307
444
ColumnType :: Map ( key, value) => CassDataType :: Map {
308
- key_type : Some ( Arc :: new ( get_column_type ( key. as_ref ( ) ) ) ) ,
309
- val_type : Some ( Arc :: new ( get_column_type ( value. as_ref ( ) ) ) ) ,
445
+ typ : MapDataType :: KeyAndValue (
446
+ Arc :: new ( get_column_type ( key. as_ref ( ) ) ) ,
447
+ Arc :: new ( get_column_type ( value. as_ref ( ) ) ) ,
448
+ ) ,
310
449
frozen : false ,
311
450
} ,
312
451
ColumnType :: Set ( boxed_type) => CassDataType :: Set {
@@ -357,8 +496,7 @@ pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const
357
496
} ,
358
497
CassValueType :: CASS_VALUE_TYPE_TUPLE => CassDataType :: Tuple ( Vec :: new ( ) ) ,
359
498
CassValueType :: CASS_VALUE_TYPE_MAP => CassDataType :: Map {
360
- key_type : None ,
361
- val_type : None ,
499
+ typ : MapDataType :: Untyped ,
362
500
frozen : false ,
363
501
} ,
364
502
CassValueType :: CASS_VALUE_TYPE_UDT => CassDataType :: UDT ( UDTDataType :: new ( ) ) ,
@@ -555,9 +693,11 @@ pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDat
555
693
CassDataType :: Value ( ..) => 0 ,
556
694
CassDataType :: UDT ( udt_data_type) => udt_data_type. field_types . len ( ) as size_t ,
557
695
CassDataType :: List { typ, .. } | CassDataType :: Set { typ, .. } => typ. is_some ( ) as size_t ,
558
- CassDataType :: Map {
559
- key_type, val_type, ..
560
- } => key_type. is_some ( ) as size_t + val_type. is_some ( ) as size_t ,
696
+ CassDataType :: Map { typ, .. } => match typ {
697
+ MapDataType :: Untyped => 0 ,
698
+ MapDataType :: Key ( _) => 1 ,
699
+ MapDataType :: KeyAndValue ( _, _) => 2 ,
700
+ } ,
561
701
CassDataType :: Tuple ( v) => v. len ( ) as size_t ,
562
702
CassDataType :: Custom ( ..) => 0 ,
563
703
}
0 commit comments