Skip to content

Commit 91ca489

Browse files
committed
fix[array]: min_max return type non nullable
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 2f40787 commit 91ca489

File tree

7 files changed

+61
-47
lines changed

encodings/sequence/src/compute/min_max.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use vortex_array::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult};
55
use vortex_array::register_kernel;
6+
use vortex_dtype::Nullability::NonNullable;
67
use vortex_error::VortexResult;
78
use vortex_scalar::Scalar;
89

@@ -19,8 +20,8 @@ impl MinMaxKernel for SequenceVTable {
1920
(last, base)
2021
};
2122
Ok(Some(MinMaxResult {
22-
min: Scalar::new(array.dtype().clone(), min.into()),
23-
max: Scalar::new(array.dtype().clone(), max.into()),
23+
min: Scalar::primitive_value(min, array.ptype(), NonNullable),
24+
max: Scalar::primitive_value(max, array.ptype(), NonNullable),
2425
}))
2526
}
2627
}

vortex-array/src/arrays/bool/compute/min_max.rs

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
use std::ops::BitAnd;
55

6+
use Nullability::NonNullable;
7+
use vortex_dtype::Nullability;
68
use vortex_error::VortexResult;
79
use vortex_mask::Mask;
810
use vortex_scalar::Scalar;
@@ -30,15 +32,15 @@ impl MinMaxKernel for BoolVTable {
3032
let Some(slice) = true_slices.next() else {
3133
// all false
3234
return Ok(Some(MinMaxResult {
33-
min: Scalar::bool(false, array.dtype().nullability()),
34-
max: Scalar::bool(false, array.dtype().nullability()),
35+
min: Scalar::bool(false, NonNullable),
36+
max: Scalar::bool(false, NonNullable),
3537
}));
3638
};
3739
if slice.0 == 0 && slice.1 == array.len() {
3840
// all true
3941
return Ok(Some(MinMaxResult {
40-
min: Scalar::bool(true, array.dtype().nullability()),
41-
max: Scalar::bool(true, array.dtype().nullability()),
42+
min: Scalar::bool(true, NonNullable),
43+
max: Scalar::bool(true, NonNullable),
4244
}));
4345
};
4446

@@ -53,16 +55,16 @@ impl MinMaxKernel for BoolVTable {
5355
let Some(_) = false_slices.next() else {
5456
// In this case we don't have any false values which means we are all true and null
5557
return Ok(Some(MinMaxResult {
56-
min: Scalar::bool(true, array.dtype().nullability()),
57-
max: Scalar::bool(true, array.dtype().nullability()),
58+
min: Scalar::bool(true, NonNullable),
59+
max: Scalar::bool(true, NonNullable),
5860
}));
5961
};
6062
}
6163
}
6264

6365
Ok(Some(MinMaxResult {
64-
min: Scalar::bool(false, array.dtype().nullability()),
65-
max: Scalar::bool(true, array.dtype().nullability()),
66+
min: Scalar::bool(false, NonNullable),
67+
max: Scalar::bool(true, NonNullable),
6668
}))
6769
}
6870
}
@@ -71,47 +73,47 @@ register_kernel!(MinMaxKernelAdapter(BoolVTable).lift());
7173

7274
#[cfg(test)]
7375
mod tests {
74-
use vortex_dtype::{DType, Nullability};
76+
use Nullability::NonNullable;
77+
use vortex_dtype::Nullability;
7578
use vortex_scalar::Scalar;
7679

7780
use crate::arrays::BoolArray;
7881
use crate::compute::{MinMaxResult, min_max};
7982

8083
#[test]
8184
fn test_min_max_nulls() {
82-
let dtype = DType::Bool(Nullability::Nullable);
8385
assert_eq!(
8486
min_max(BoolArray::from_iter(vec![Some(true), Some(true), None, None]).as_ref())
8587
.unwrap(),
8688
Some(MinMaxResult {
87-
min: Scalar::new(dtype.clone(), true.into()),
88-
max: Scalar::new(dtype.clone(), true.into()),
89+
min: Scalar::bool(true, NonNullable),
90+
max: Scalar::bool(true, NonNullable),
8991
})
9092
);
9193

9294
assert_eq!(
9395
min_max(BoolArray::from_iter(vec![None, Some(true), Some(true)]).as_ref()).unwrap(),
9496
Some(MinMaxResult {
95-
min: Scalar::new(dtype.clone(), true.into()),
96-
max: Scalar::new(dtype.clone(), true.into()),
97+
min: Scalar::bool(true, NonNullable),
98+
max: Scalar::bool(true, NonNullable),
9799
})
98100
);
99101

100102
assert_eq!(
101103
min_max(BoolArray::from_iter(vec![None, Some(true), Some(true), None]).as_ref())
102104
.unwrap(),
103105
Some(MinMaxResult {
104-
min: Scalar::new(dtype.clone(), true.into()),
105-
max: Scalar::new(dtype.clone(), true.into()),
106+
min: Scalar::bool(true, NonNullable),
107+
max: Scalar::bool(true, NonNullable),
106108
})
107109
);
108110

109111
assert_eq!(
110112
min_max(BoolArray::from_iter(vec![Some(false), Some(false), None, None]).as_ref())
111113
.unwrap(),
112114
Some(MinMaxResult {
113-
min: Scalar::new(dtype.clone(), false.into()),
114-
max: Scalar::new(dtype, false.into()),
115+
min: Scalar::bool(false, NonNullable),
116+
max: Scalar::bool(false, NonNullable),
115117
})
116118
);
117119
}

vortex-array/src/arrays/decimal/compute/min_max.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,19 @@ fn compute_min_max<'a, T>(iter: impl Iterator<Item = &'a T>, dtype: &DType) -> O
4444
where
4545
T: Into<DecimalValue> + NativeDecimalType + Ord + Copy + 'a,
4646
{
47+
let non_nullable_dtype = dtype.as_nonnullable();
4748
match iter.minmax_by(|a, b| a.cmp(b)) {
4849
itertools::MinMaxResult::NoElements => None,
4950
itertools::MinMaxResult::OneElement(&x) => {
50-
let scalar = Scalar::new(dtype.clone(), ScalarValue::from(x.into()));
51+
let scalar = Scalar::new(non_nullable_dtype, ScalarValue::from(x.into()));
5152
Some(MinMaxResult {
5253
min: scalar.clone(),
5354
max: scalar,
5455
})
5556
}
5657
itertools::MinMaxResult::MinMax(&min, &max) => Some(MinMaxResult {
57-
min: Scalar::new(dtype.clone(), ScalarValue::from(min.into())),
58-
max: Scalar::new(dtype.clone(), ScalarValue::from(max.into())),
58+
min: Scalar::new(non_nullable_dtype.clone(), ScalarValue::from(min.into())),
59+
max: Scalar::new(non_nullable_dtype, ScalarValue::from(max.into())),
5960
}),
6061
}
6162
}
@@ -80,13 +81,14 @@ mod tests {
8081

8182
let min_max = min_max(decimal.as_ref()).unwrap();
8283

84+
let non_nullable_dtype = decimal.dtype().as_nonnullable();
8385
let expected = MinMaxResult {
8486
min: Scalar::new(
85-
decimal.dtype().clone(),
87+
non_nullable_dtype.clone(),
8688
ScalarValue::from(DecimalValue::from(100i32)),
8789
),
8890
max: Scalar::new(
89-
decimal.dtype().clone(),
91+
non_nullable_dtype,
9092
ScalarValue::from(DecimalValue::from(200i32)),
9193
),
9294
};

vortex-array/src/arrays/extension/compute/min_max.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::Arc;
5+
6+
use vortex_dtype::Nullability;
47
use vortex_error::VortexResult;
58
use vortex_scalar::Scalar;
69

@@ -10,10 +13,12 @@ use crate::register_kernel;
1013

1114
impl MinMaxKernel for ExtensionVTable {
1215
fn min_max(&self, array: &ExtensionArray) -> VortexResult<Option<MinMaxResult>> {
16+
let non_nullable_ext_dtype =
17+
Arc::new(array.ext_dtype().with_nullability(Nullability::NonNullable));
1318
Ok(
1419
compute::min_max(array.storage())?.map(|MinMaxResult { min, max }| MinMaxResult {
15-
min: Scalar::extension(array.ext_dtype().clone(), min),
16-
max: Scalar::extension(array.ext_dtype().clone(), max),
20+
min: Scalar::extension(non_nullable_ext_dtype.clone(), min),
21+
max: Scalar::extension(non_nullable_ext_dtype, max),
1722
}),
1823
)
1924
}

vortex-array/src/arrays/primitive/compute/min_max.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use itertools::Itertools;
5+
use vortex_dtype::Nullability::NonNullable;
56
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
67
use vortex_error::VortexResult;
78
use vortex_mask::Mask;
8-
use vortex_scalar::{Scalar, ScalarValue};
9+
use vortex_scalar::{PValue, Scalar, ScalarValue};
910

1011
use crate::arrays::{PrimitiveArray, PrimitiveVTable};
1112
use crate::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult};
@@ -24,7 +25,8 @@ register_kernel!(MinMaxKernelAdapter(PrimitiveVTable).lift());
2425
#[inline]
2526
fn compute_min_max_with_validity<T>(array: &PrimitiveArray) -> VortexResult<Option<MinMaxResult>>
2627
where
27-
T: Into<ScalarValue> + NativePType,
28+
T: NativePType,
29+
PValue: From<T>,
2830
{
2931
Ok(match array.validity_mask() {
3032
Mask::AllTrue(_) => compute_min_max(array.as_slice::<T>().iter(), array.dtype()),
@@ -42,7 +44,8 @@ where
4244

4345
fn compute_min_max<'a, T>(iter: impl Iterator<Item = &'a T>, dtype: &DType) -> Option<MinMaxResult>
4446
where
45-
T: Into<ScalarValue> + NativePType,
47+
T: NativePType,
48+
PValue: From<T>,
4649
{
4750
// `total_compare` function provides a total ordering (even for NaN values).
4851
// However, we exclude NaNs from min max as they're not useful for any purpose where min/max would be used
@@ -52,15 +55,15 @@ where
5255
{
5356
itertools::MinMaxResult::NoElements => None,
5457
itertools::MinMaxResult::OneElement(&x) => {
55-
let scalar = Scalar::new(dtype.clone(), x.into());
58+
let scalar = Scalar::primitive(x, NonNullable);
5659
Some(MinMaxResult {
5760
min: scalar.clone(),
5861
max: scalar,
5962
})
6063
}
6164
itertools::MinMaxResult::MinMax(&min, &max) => Some(MinMaxResult {
62-
min: Scalar::new(dtype.clone(), min.into()),
63-
max: Scalar::new(dtype.clone(), max.into()),
65+
min: Scalar::primitive(min, NonNullable),
66+
max: Scalar::primitive(max, NonNullable),
6467
}),
6568
}
6669
}

vortex-array/src/arrays/varbin/compute/min_max.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use itertools::Itertools;
55
use vortex_dtype::DType;
6+
use vortex_dtype::Nullability::NonNullable;
67
use vortex_error::{VortexResult, vortex_panic};
78
use vortex_scalar::Scalar;
89

@@ -45,12 +46,12 @@ pub(crate) fn varbin_compute_min_max<T: ArrayAccessor<[u8]>>(
4546
/// Helper function to make sure that min/max has the right [`Scalar`] type.
4647
fn make_scalar(dtype: &DType, value: &[u8]) -> Scalar {
4748
match dtype {
48-
DType::Binary(_) => Scalar::new(dtype.clone(), value.into()),
49+
DType::Binary(_) => Scalar::binary(value.to_vec(), NonNullable),
4950
DType::Utf8(_) => {
5051
// SAFETY: We only call `compute_min_max` within `varbin/`, in which we always validate
5152
// the arrays, and we always pass `array.dtype()` in as the `dtype` argument.
5253
let value = unsafe { str::from_utf8_unchecked(value) };
53-
Scalar::new(dtype.clone(), value.into())
54+
Scalar::utf8(value, NonNullable)
5455
}
5556
_ => vortex_panic!("cannot make Scalar from bytes with dtype {dtype}"),
5657
}
@@ -60,7 +61,7 @@ fn make_scalar(dtype: &DType, value: &[u8]) -> Scalar {
6061
mod tests {
6162
use vortex_buffer::BufferString;
6263
use vortex_dtype::DType::Utf8;
63-
use vortex_dtype::Nullability::Nullable;
64+
use vortex_dtype::Nullability::{NonNullable, Nullable};
6465
use vortex_scalar::Scalar;
6566

6667
use crate::arrays::VarBinArray;
@@ -83,14 +84,14 @@ mod tests {
8384
assert_eq!(
8485
min,
8586
Scalar::new(
86-
Utf8(Nullable),
87+
Utf8(NonNullable),
8788
BufferString::from("hello world".to_string()).into(),
8889
)
8990
);
9091
assert_eq!(
9192
max,
9293
Scalar::new(
93-
Utf8(Nullable),
94+
Utf8(NonNullable),
9495
BufferString::from("hello world this is a long string".to_string()).into()
9596
)
9697
);

vortex-array/src/compute/min_max.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,8 @@ fn min_max_impl(
136136
return Ok(None);
137137
}
138138

139-
if let Some(array) = array.as_opt::<ConstantVTable>()
140-
&& !array.scalar().is_null()
141-
{
142-
return Ok(Some(MinMaxResult {
143-
min: array.scalar().clone(),
144-
max: array.scalar().clone(),
145-
}));
139+
if let Some(array) = array.as_opt::<ConstantVTable>() {
140+
return ConstantVTable.min_max(array);
146141
}
147142

148143
let min = array
@@ -155,7 +150,11 @@ fn min_max_impl(
155150
.and_then(Precision::as_exact);
156151

157152
if let Some((min, max)) = min.zip(max) {
158-
return Ok(Some(MinMaxResult { min, max }));
153+
let non_nullable_dtype = array.dtype().as_nonnullable();
154+
return Ok(Some(MinMaxResult {
155+
min: min.cast(&non_nullable_dtype)?,
156+
max: max.cast(&non_nullable_dtype)?,
157+
}));
159158
}
160159

161160
let args = InvocationArgs {
@@ -202,10 +201,11 @@ impl<V: VTable + MinMaxKernel> Kernel for MinMaxKernelAdapter<V> {
202201
let Some(array) = inputs.array.as_opt::<V>() else {
203202
return Ok(None);
204203
};
204+
let non_nullable_dtype = array.dtype().as_nonnullable();
205205
let dtype = DType::Struct(
206206
StructFields::new(
207207
["min", "max"].into(),
208-
vec![array.dtype().clone(), array.dtype().clone()],
208+
vec![non_nullable_dtype.clone(), non_nullable_dtype],
209209
),
210210
Nullability::Nullable,
211211
);

0 commit comments

Comments
 (0)