|
1 | 1 | use std::cmp::Ordering; |
2 | 2 | use std::hash::Hash; |
3 | | -use std::mem::discriminant; |
4 | 3 | use std::sync::Arc; |
5 | 4 |
|
6 | 5 | pub use scalar_type::ScalarType; |
@@ -42,8 +41,9 @@ use vortex_error::{vortex_bail, VortexExpect, VortexResult}; |
42 | 41 | /// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers |
43 | 42 | /// for example [`BoolScalar`], [`PrimitiveScalar`], etc. |
44 | 43 | /// |
45 | | -/// Note: [`PartialEq`] and [`PartialOrd`] are implemented only for an exact match of the scalar's |
46 | | -/// dtype, including nullability. |
| 44 | +/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype, |
| 45 | +/// including nullability. When the DType does match, ordering is nulls first (lowest), then the |
| 46 | +/// natural ordering of the scalar value. |
47 | 47 | #[derive(Debug, Clone)] |
48 | 48 | #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] |
49 | 49 | pub struct Scalar { |
@@ -203,28 +203,56 @@ impl Scalar { |
203 | 203 |
|
204 | 204 | impl PartialEq for Scalar { |
205 | 205 | fn eq(&self, other: &Self) -> bool { |
206 | | - self.dtype == other.dtype && self.value.0 == other.value.0 |
| 206 | + if self.dtype != other.dtype { |
| 207 | + return false; |
| 208 | + } |
| 209 | + |
| 210 | + match self.dtype() { |
| 211 | + DType::Null => true, |
| 212 | + DType::Bool(_) => self.as_bool() == other.as_bool(), |
| 213 | + DType::Primitive(..) => self.as_primitive() == other.as_primitive(), |
| 214 | + DType::Utf8(_) => self.as_utf8() == other.as_utf8(), |
| 215 | + DType::Binary(_) => self.as_binary() == other.as_binary(), |
| 216 | + DType::Struct(..) => self.as_struct() == other.as_struct(), |
| 217 | + DType::List(..) => self.as_list() == other.as_list(), |
| 218 | + DType::Extension(_) => self.as_extension() == other.as_extension(), |
| 219 | + } |
207 | 220 | } |
208 | 221 | } |
209 | 222 |
|
210 | 223 | impl Eq for Scalar {} |
211 | 224 |
|
212 | 225 | impl PartialOrd for Scalar { |
213 | 226 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { |
214 | | - // We check for DType equality, ignoring nullability, and allowing us to compare all |
215 | | - // primitive types to all other primitive types. |
216 | | - if discriminant(self.dtype()) == discriminant(other.dtype()) { |
217 | | - self.value.0.partial_cmp(&other.value.0) |
218 | | - } else { |
219 | | - None |
| 227 | + if self.dtype() != other.dtype() { |
| 228 | + return None; |
| 229 | + } |
| 230 | + |
| 231 | + match self.dtype() { |
| 232 | + DType::Null => Some(Ordering::Equal), |
| 233 | + DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()), |
| 234 | + DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()), |
| 235 | + DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()), |
| 236 | + DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()), |
| 237 | + DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()), |
| 238 | + DType::List(..) => self.as_list().partial_cmp(&other.as_list()), |
| 239 | + DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()), |
220 | 240 | } |
221 | 241 | } |
222 | 242 | } |
223 | 243 |
|
224 | 244 | impl Hash for Scalar { |
225 | 245 | fn hash<H: std::hash::Hasher>(&self, state: &mut H) { |
226 | | - discriminant(self.dtype()).hash(state); |
227 | | - self.value.0.hash(state); |
| 246 | + match self.dtype() { |
| 247 | + DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value |
| 248 | + DType::Bool(_) => self.as_bool().hash(state), |
| 249 | + DType::Primitive(..) => self.as_primitive().hash(state), |
| 250 | + DType::Utf8(_) => self.as_utf8().hash(state), |
| 251 | + DType::Binary(_) => self.as_binary().hash(state), |
| 252 | + DType::Struct(..) => self.as_struct().hash(state), |
| 253 | + DType::List(..) => self.as_list().hash(state), |
| 254 | + DType::Extension(_) => self.as_extension().hash(state), |
| 255 | + } |
228 | 256 | } |
229 | 257 | } |
230 | 258 |
|
|
0 commit comments