Skip to content

Commit a40f50a

Browse files
authored
Fix scalar eq, cmp, hash, etc. (#2091)
Previously we were being a little bit lazy and delegating to ScalarValue. We've been bitten by this in the past! It is never safe to delegate to ScalarValue since it doesn't have enough information to know what "equals" means, i.e. it's missing the surrounding scalar's DType. This PR removes PartialEq, PartialOrd, and Hash from ScalarValue, in favor of forcing us to wrap the value into a Scalar and then compare each dtype-specific scalar type.
1 parent ab5e71e commit a40f50a

File tree

15 files changed

+261
-53
lines changed

15 files changed

+261
-53
lines changed

pyvortex/src/scalar.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use pyo3::types::PyDict;
99
use vortex::buffer::{BufferString, ByteBuffer};
1010
use vortex::dtype::half::f16;
1111
use vortex::dtype::{DType, PType};
12+
use vortex::error::VortexExpect;
1213
use vortex::scalar::{ListScalar, Scalar, StructScalar};
1314

1415
pub fn scalar_into_py(py: Python, x: Scalar, copy_into_python: bool) -> PyResult<PyObject> {
@@ -190,6 +191,8 @@ impl PyVortexList {
190191
fn to_python_list(py: Python, scalar: ListScalar<'_>, recursive: bool) -> PyResult<PyObject> {
191192
Ok(scalar
192193
.elements()
194+
.vortex_expect("non-null")
195+
.into_iter()
193196
.map(|x| scalar_into_py(py, x, recursive))
194197
.collect::<PyResult<Vec<_>>>()?
195198
.into_py(py))

vortex-array/src/array/constant/canonical.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,12 @@ fn canonical_byte_view(
105105

106106
#[cfg(test)]
107107
mod tests {
108-
use vortex_dtype::{DType, Nullability};
108+
use vortex_dtype::half::f16;
109+
use vortex_dtype::{DType, Nullability, PType};
109110
use vortex_scalar::Scalar;
110111

111112
use crate::array::ConstantArray;
113+
use crate::canonical::IntoArrayVariant;
112114
use crate::compute::scalar_at;
113115
use crate::stats::{ArrayStatistics as _, StatsSet};
114116
use crate::{ArrayLen, IntoArrayData as _, IntoCanonical};
@@ -151,4 +153,17 @@ mod tests {
151153
assert_eq!(canonical_stats, StatsSet::constant(&scalar, 4));
152154
assert_eq!(canonical_stats, stats);
153155
}
156+
157+
/// Canonicalizing a constant array must produce a value that compares equal both to the scalar
/// the array was built from and to the same number constructed directly as an `f16` scalar.
#[test]
fn test_canonicalize_scalar_values() {
    // Build an F16-typed scalar whose value comes from a `u8` scalar rather than from an f16.
    // 5.722046e-6 is 96 * 2^-24, i.e. the f16 whose raw bits are 0x0060.
    // NOTE(review): presumably the stored 96 is read back as those raw bits — confirm.
    let half_dtype = DType::Primitive(PType::F16, Nullability::NonNullable);
    let stored_value = Scalar::primitive(96u8, Nullability::NonNullable).into_value();
    let scalar = Scalar::new(half_dtype, stored_value);

    // The same number, constructed directly as an f16 scalar.
    let f16_scalar = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);

    let canonical_const = ConstantArray::new(scalar.clone(), 1)
        .into_array()
        .into_primitive()
        .unwrap();

    assert_eq!(scalar_at(&canonical_const, 0).unwrap(), scalar);
    assert_eq!(scalar_at(&canonical_const, 0).unwrap(), f16_scalar);
}
154169
}

vortex-array/src/builders/list.rs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,23 @@ where
4949
}
5050

5151
pub fn append_value(&mut self, value: ListScalar) -> VortexResult<()> {
52-
if value.is_null() {
53-
if self.nullability == Nullability::NonNullable {
54-
vortex_bail!("Cannot append null value to non-nullable list");
52+
match value.elements() {
53+
None => {
54+
if self.nullability == Nullability::NonNullable {
55+
vortex_bail!("Cannot append null value to non-nullable list");
56+
}
57+
self.append_null();
58+
Ok(())
5559
}
56-
self.append_null();
57-
Ok(())
58-
} else {
59-
for scalar in value.elements() {
60-
// TODO(joe): This is slow, we should be able to append multiple values at once,
61-
// or the list scalar should hold an ArrayData
62-
self.value_builder.append_scalar(&scalar)?;
60+
Some(elements) => {
61+
for scalar in elements {
62+
// TODO(joe): This is slow, we should be able to append multiple values at once,
63+
// or the list scalar should hold an ArrayData
64+
self.value_builder.append_scalar(&scalar)?;
65+
}
66+
self.validity.append_value(true);
67+
self.append_index(self.value_builder.len().as_())
6368
}
64-
self.validity.append_value(true);
65-
self.append_index(self.value_builder.len().as_())
6669
}
6770
}
6871

vortex-scalar/src/binary.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,27 @@ use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, Vort
77
use crate::value::{InnerScalarValue, ScalarValue};
88
use crate::Scalar;
99

10+
#[derive(Debug, Hash)]
1011
pub struct BinaryScalar<'a> {
1112
dtype: &'a DType,
1213
value: Option<ByteBuffer>,
1314
}
1415

16+
impl PartialEq for BinaryScalar<'_> {
17+
fn eq(&self, other: &Self) -> bool {
18+
self.dtype == other.dtype && self.value == other.value
19+
}
20+
}
21+
22+
impl Eq for BinaryScalar<'_> {}
23+
24+
/// Ord is not implemented since it's undefined for different nullability
25+
impl PartialOrd for BinaryScalar<'_> {
26+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
27+
self.value.partial_cmp(&other.value)
28+
}
29+
}
30+
1531
impl<'a> BinaryScalar<'a> {
1632
#[inline]
1733
pub fn dtype(&self) -> &'a DType {

vortex-scalar/src/bool.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,35 @@
1+
use std::cmp::Ordering;
2+
13
use vortex_dtype::Nullability::NonNullable;
24
use vortex_dtype::{DType, Nullability};
35
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult};
46

57
use crate::value::ScalarValue;
68
use crate::{InnerScalarValue, Scalar};
79

10+
#[derive(Debug, Hash)]
811
pub struct BoolScalar<'a> {
912
dtype: &'a DType,
1013
value: Option<bool>,
1114
}
1215

16+
impl PartialEq for BoolScalar<'_> {
17+
fn eq(&self, other: &Self) -> bool {
18+
self.value == other.value
19+
}
20+
}
21+
22+
impl Eq for BoolScalar<'_> {}
23+
24+
impl PartialOrd for BoolScalar<'_> {
25+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
26+
if self.dtype != other.dtype {
27+
return None;
28+
}
29+
self.value.partial_cmp(&other.value)
30+
}
31+
}
32+
1333
impl<'a> BoolScalar<'a> {
1434
#[inline]
1535
pub fn dtype(&self) -> &'a DType {

vortex-scalar/src/display.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ impl Display for Scalar {
6161
}
6262
DType::List(..) => {
6363
let v = ListScalar::try_from(self).map_err(|_| std::fmt::Error)?;
64-
65-
if v.is_null() {
66-
write!(f, "null")
67-
} else {
68-
write!(f, "[{}]", v.elements().format(","))
64+
match v.elements() {
65+
None => write!(f, "null"),
66+
Some(elems) => {
67+
write!(f, "[{}]", elems.iter().format(","))
68+
}
6969
}
7070
}
7171
// Specialized handling for date/time/timestamp builtin extension types.

vortex-scalar/src/extension.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::hash::Hash;
12
use std::sync::Arc;
23

34
use vortex_dtype::{DType, ExtDType};
@@ -11,6 +12,30 @@ pub struct ExtScalar<'a> {
1112
value: &'a ScalarValue,
1213
}
1314

15+
impl PartialEq for ExtScalar<'_> {
16+
fn eq(&self, other: &Self) -> bool {
17+
self.ext_dtype == other.ext_dtype && self.storage() == other.storage()
18+
}
19+
}
20+
21+
impl Eq for ExtScalar<'_> {}
22+
23+
impl PartialOrd for ExtScalar<'_> {
24+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
25+
if self.ext_dtype != other.ext_dtype {
26+
return None;
27+
}
28+
self.storage().partial_cmp(&other.storage())
29+
}
30+
}
31+
32+
impl Hash for ExtScalar<'_> {
33+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
34+
self.ext_dtype.hash(state);
35+
self.storage().hash(state);
36+
}
37+
}
38+
1439
impl<'a> ExtScalar<'a> {
1540
pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
1641
let DType::Extension(ext_dtype) = dtype else {

vortex-scalar/src/lib.rs

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use std::cmp::Ordering;
22
use std::hash::Hash;
3-
use std::mem::discriminant;
43
use std::sync::Arc;
54

65
pub use scalar_type::ScalarType;
@@ -42,8 +41,9 @@ use vortex_error::{vortex_bail, VortexExpect, VortexResult};
4241
/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers
4342
/// for example [`BoolScalar`], [`PrimitiveScalar`], etc.
4443
///
45-
/// Note: [`PartialEq`] and [`PartialOrd`] are implemented only for an exact match of the scalar's
46-
/// dtype, including nullability.
44+
/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype,
45+
/// including nullability. When the DType does match, ordering is nulls first (lowest), then the
46+
/// natural ordering of the scalar value.
4747
#[derive(Debug, Clone)]
4848
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
4949
pub struct Scalar {
@@ -203,28 +203,56 @@ impl Scalar {
203203

204204
impl PartialEq for Scalar {
205205
fn eq(&self, other: &Self) -> bool {
206-
self.dtype == other.dtype && self.value.0 == other.value.0
206+
if self.dtype != other.dtype {
207+
return false;
208+
}
209+
210+
match self.dtype() {
211+
DType::Null => true,
212+
DType::Bool(_) => self.as_bool() == other.as_bool(),
213+
DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
214+
DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
215+
DType::Binary(_) => self.as_binary() == other.as_binary(),
216+
DType::Struct(..) => self.as_struct() == other.as_struct(),
217+
DType::List(..) => self.as_list() == other.as_list(),
218+
DType::Extension(_) => self.as_extension() == other.as_extension(),
219+
}
207220
}
208221
}
209222

210223
impl Eq for Scalar {}
211224

212225
impl PartialOrd for Scalar {
213226
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
214-
// We check for DType equality, ignoring nullability, and allowing us to compare all
215-
// primitive types to all other primitive types.
216-
if discriminant(self.dtype()) == discriminant(other.dtype()) {
217-
self.value.0.partial_cmp(&other.value.0)
218-
} else {
219-
None
227+
if self.dtype() != other.dtype() {
228+
return None;
229+
}
230+
231+
match self.dtype() {
232+
DType::Null => Some(Ordering::Equal),
233+
DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
234+
DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
235+
DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
236+
DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
237+
DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
238+
DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
239+
DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
220240
}
221241
}
222242
}
223243

224244
impl Hash for Scalar {
225245
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
226-
discriminant(self.dtype()).hash(state);
227-
self.value.0.hash(state);
246+
match self.dtype() {
247+
DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
248+
DType::Bool(_) => self.as_bool().hash(state),
249+
DType::Primitive(..) => self.as_primitive().hash(state),
250+
DType::Utf8(_) => self.as_utf8().hash(state),
251+
DType::Binary(_) => self.as_binary().hash(state),
252+
DType::Struct(..) => self.as_struct().hash(state),
253+
DType::List(..) => self.as_list().hash(state),
254+
DType::Extension(_) => self.as_extension().hash(state),
255+
}
228256
}
229257
}
230258

vortex-scalar/src/list.rs

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
use std::hash::Hash;
12
use std::ops::Deref;
23
use std::sync::Arc;
34

45
use itertools::Itertools as _;
56
use vortex_dtype::{DType, Nullability};
6-
use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect as _, VortexResult};
7+
use vortex_error::{
8+
vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect as _, VortexResult,
9+
};
710

811
use crate::value::{InnerScalarValue, ScalarValue};
912
use crate::Scalar;
@@ -14,6 +17,34 @@ pub struct ListScalar<'a> {
1417
elements: Option<Arc<[ScalarValue]>>,
1518
}
1619

20+
impl PartialEq for ListScalar<'_> {
21+
fn eq(&self, other: &Self) -> bool {
22+
if self.dtype != other.dtype {
23+
return false;
24+
}
25+
self.elements() == other.elements()
26+
}
27+
}
28+
29+
impl Eq for ListScalar<'_> {}
30+
31+
/// Ord is not implemented since it's undefined for different DTypes
32+
impl PartialOrd for ListScalar<'_> {
33+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
34+
if self.dtype() != other.dtype() {
35+
return None;
36+
}
37+
self.elements().partial_cmp(&other.elements())
38+
}
39+
}
40+
41+
impl Hash for ListScalar<'_> {
42+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
43+
self.dtype.hash(state);
44+
self.elements().hash(state);
45+
}
46+
}
47+
1748
impl<'a> ListScalar<'a> {
1849
#[inline]
1950
pub fn dtype(&self) -> &'a DType {
@@ -55,16 +86,13 @@ impl<'a> ListScalar<'a> {
5586
})
5687
}
5788

58-
pub fn elements(&self) -> impl Iterator<Item = Scalar> + '_ {
59-
self.elements
60-
.as_ref()
61-
.map(AsRef::as_ref)
62-
.unwrap_or_else(|| &[] as &[ScalarValue])
63-
.iter()
64-
.map(|e| Scalar {
65-
dtype: self.element_dtype(),
66-
value: e.clone(),
67-
})
89+
pub fn elements(&self) -> Option<Vec<Scalar>> {
90+
self.elements.as_ref().map(|elems| {
91+
elems
92+
.iter()
93+
.map(|e| Scalar::new(self.element_dtype(), e.clone()))
94+
.collect_vec()
95+
})
6896
}
6997

7098
pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
@@ -143,7 +171,10 @@ impl<'a, T: for<'b> TryFrom<&'b Scalar, Error = VortexError>> TryFrom<&'a Scalar
143171
fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
144172
let value = ListScalar::try_from(value)?;
145173
let mut elems = vec![];
146-
for e in value.elements() {
174+
for e in value
175+
.elements()
176+
.ok_or_else(|| vortex_err!("Expected non-null list"))?
177+
{
147178
elems.push(T::try_from(&e)?);
148179
}
149180
Ok(elems)

0 commit comments

Comments
 (0)