Skip to content

Commit 6a54fbe

Browse files
fix[array]: min_max return type non nullable (#5202)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 14841c0 commit 6a54fbe

File tree

7 files changed

+76
-57
lines changed

7 files changed

+76
-57
lines changed

encodings/sequence/src/compute/min_max.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use vortex_array::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult};
55
use vortex_array::register_kernel;
6+
use vortex_dtype::Nullability::NonNullable;
67
use vortex_error::VortexResult;
78
use vortex_scalar::Scalar;
89

@@ -19,8 +20,8 @@ impl MinMaxKernel for SequenceVTable {
1920
(last, base)
2021
};
2122
Ok(Some(MinMaxResult {
22-
min: Scalar::new(array.dtype().clone(), min.into()),
23-
max: Scalar::new(array.dtype().clone(), max.into()),
23+
min: Scalar::primitive_value(min, array.ptype(), NonNullable),
24+
max: Scalar::primitive_value(max, array.ptype(), NonNullable),
2425
}))
2526
}
2627
}

vortex-array/src/arrays/bool/compute/min_max.rs

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
use std::ops::BitAnd;
55

6+
use Nullability::NonNullable;
7+
use vortex_dtype::Nullability;
68
use vortex_error::VortexResult;
79
use vortex_mask::Mask;
810
use vortex_scalar::Scalar;
@@ -30,15 +32,15 @@ impl MinMaxKernel for BoolVTable {
3032
let Some(slice) = true_slices.next() else {
3133
// all false
3234
return Ok(Some(MinMaxResult {
33-
min: Scalar::bool(false, array.dtype().nullability()),
34-
max: Scalar::bool(false, array.dtype().nullability()),
35+
min: Scalar::bool(false, NonNullable),
36+
max: Scalar::bool(false, NonNullable),
3537
}));
3638
};
3739
if slice.0 == 0 && slice.1 == array.len() {
3840
// all true
3941
return Ok(Some(MinMaxResult {
40-
min: Scalar::bool(true, array.dtype().nullability()),
41-
max: Scalar::bool(true, array.dtype().nullability()),
42+
min: Scalar::bool(true, NonNullable),
43+
max: Scalar::bool(true, NonNullable),
4244
}));
4345
};
4446

@@ -53,16 +55,16 @@ impl MinMaxKernel for BoolVTable {
5355
let Some(_) = false_slices.next() else {
5456
// In this case we don't have any false values which means we are all true and null
5557
return Ok(Some(MinMaxResult {
56-
min: Scalar::bool(true, array.dtype().nullability()),
57-
max: Scalar::bool(true, array.dtype().nullability()),
58+
min: Scalar::bool(true, NonNullable),
59+
max: Scalar::bool(true, NonNullable),
5860
}));
5961
};
6062
}
6163
}
6264

6365
Ok(Some(MinMaxResult {
64-
min: Scalar::bool(false, array.dtype().nullability()),
65-
max: Scalar::bool(true, array.dtype().nullability()),
66+
min: Scalar::bool(false, NonNullable),
67+
max: Scalar::bool(true, NonNullable),
6668
}))
6769
}
6870
}
@@ -71,47 +73,47 @@ register_kernel!(MinMaxKernelAdapter(BoolVTable).lift());
7173

7274
#[cfg(test)]
7375
mod tests {
74-
use vortex_dtype::{DType, Nullability};
76+
use Nullability::NonNullable;
77+
use vortex_dtype::Nullability;
7578
use vortex_scalar::Scalar;
7679

7780
use crate::arrays::BoolArray;
7881
use crate::compute::{MinMaxResult, min_max};
7982

8083
#[test]
8184
fn test_min_max_nulls() {
82-
let dtype = DType::Bool(Nullability::Nullable);
8385
assert_eq!(
8486
min_max(BoolArray::from_iter(vec![Some(true), Some(true), None, None]).as_ref())
8587
.unwrap(),
8688
Some(MinMaxResult {
87-
min: Scalar::new(dtype.clone(), true.into()),
88-
max: Scalar::new(dtype.clone(), true.into()),
89+
min: Scalar::bool(true, NonNullable),
90+
max: Scalar::bool(true, NonNullable),
8991
})
9092
);
9193

9294
assert_eq!(
9395
min_max(BoolArray::from_iter(vec![None, Some(true), Some(true)]).as_ref()).unwrap(),
9496
Some(MinMaxResult {
95-
min: Scalar::new(dtype.clone(), true.into()),
96-
max: Scalar::new(dtype.clone(), true.into()),
97+
min: Scalar::bool(true, NonNullable),
98+
max: Scalar::bool(true, NonNullable),
9799
})
98100
);
99101

100102
assert_eq!(
101103
min_max(BoolArray::from_iter(vec![None, Some(true), Some(true), None]).as_ref())
102104
.unwrap(),
103105
Some(MinMaxResult {
104-
min: Scalar::new(dtype.clone(), true.into()),
105-
max: Scalar::new(dtype.clone(), true.into()),
106+
min: Scalar::bool(true, NonNullable),
107+
max: Scalar::bool(true, NonNullable),
106108
})
107109
);
108110

109111
assert_eq!(
110112
min_max(BoolArray::from_iter(vec![Some(false), Some(false), None, None]).as_ref())
111113
.unwrap(),
112114
Some(MinMaxResult {
113-
min: Scalar::new(dtype.clone(), false.into()),
114-
max: Scalar::new(dtype, false.into()),
115+
min: Scalar::bool(false, NonNullable),
116+
max: Scalar::bool(false, NonNullable),
115117
})
116118
);
117119
}

vortex-array/src/arrays/decimal/compute/min_max.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use itertools::Itertools;
5-
use vortex_dtype::{DType, NativeDecimalType, match_each_decimal_value_type};
5+
use vortex_dtype::Nullability::NonNullable;
6+
use vortex_dtype::{DecimalDType, NativeDecimalType, match_each_decimal_value_type};
67
use vortex_error::VortexResult;
78
use vortex_mask::Mask;
8-
use vortex_scalar::{DecimalValue, Scalar, ScalarValue};
9+
use vortex_scalar::{DecimalValue, Scalar};
910

1011
use crate::arrays::{DecimalArray, DecimalVTable};
1112
use crate::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult};
@@ -27,35 +28,38 @@ where
2728
D: Into<DecimalValue> + NativeDecimalType,
2829
{
2930
Ok(match array.validity_mask() {
30-
Mask::AllTrue(_) => compute_min_max(array.buffer::<D>().iter(), array.dtype()),
31+
Mask::AllTrue(_) => compute_min_max(array.buffer::<D>().iter(), array.decimal_dtype()),
3132
Mask::AllFalse(_) => None,
3233
Mask::Values(v) => compute_min_max(
3334
array
3435
.buffer::<D>()
3536
.iter()
3637
.zip(v.bit_buffer().iter())
3738
.filter_map(|(v, m)| m.then_some(v)),
38-
array.dtype(),
39+
array.decimal_dtype(),
3940
),
4041
})
4142
}
4243

43-
fn compute_min_max<'a, T>(iter: impl Iterator<Item = &'a T>, dtype: &DType) -> Option<MinMaxResult>
44+
fn compute_min_max<'a, T>(
45+
iter: impl Iterator<Item = &'a T>,
46+
decimal_dtype: DecimalDType,
47+
) -> Option<MinMaxResult>
4448
where
4549
T: Into<DecimalValue> + NativeDecimalType + Ord + Copy + 'a,
4650
{
4751
match iter.minmax_by(|a, b| a.cmp(b)) {
4852
itertools::MinMaxResult::NoElements => None,
4953
itertools::MinMaxResult::OneElement(&x) => {
50-
let scalar = Scalar::new(dtype.clone(), ScalarValue::from(x.into()));
54+
let scalar = Scalar::decimal(x.into(), decimal_dtype, NonNullable);
5155
Some(MinMaxResult {
5256
min: scalar.clone(),
5357
max: scalar,
5458
})
5559
}
5660
itertools::MinMaxResult::MinMax(&min, &max) => Some(MinMaxResult {
57-
min: Scalar::new(dtype.clone(), ScalarValue::from(min.into())),
58-
max: Scalar::new(dtype.clone(), ScalarValue::from(max.into())),
61+
min: Scalar::decimal(min.into(), decimal_dtype, NonNullable),
62+
max: Scalar::decimal(max.into(), decimal_dtype, NonNullable),
5963
}),
6064
}
6165
}
@@ -80,13 +84,14 @@ mod tests {
8084

8185
let min_max = min_max(decimal.as_ref()).unwrap();
8286

87+
let non_nullable_dtype = decimal.dtype().as_nonnullable();
8388
let expected = MinMaxResult {
8489
min: Scalar::new(
85-
decimal.dtype().clone(),
90+
non_nullable_dtype.clone(),
8691
ScalarValue::from(DecimalValue::from(100i32)),
8792
),
8893
max: Scalar::new(
89-
decimal.dtype().clone(),
94+
non_nullable_dtype,
9095
ScalarValue::from(DecimalValue::from(200i32)),
9196
),
9297
};

vortex-array/src/arrays/extension/compute/min_max.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::Arc;
5+
6+
use vortex_dtype::Nullability;
47
use vortex_error::VortexResult;
58
use vortex_scalar::Scalar;
69

@@ -10,10 +13,12 @@ use crate::register_kernel;
1013

1114
impl MinMaxKernel for ExtensionVTable {
1215
fn min_max(&self, array: &ExtensionArray) -> VortexResult<Option<MinMaxResult>> {
16+
let non_nullable_ext_dtype =
17+
Arc::new(array.ext_dtype().with_nullability(Nullability::NonNullable));
1318
Ok(
1419
compute::min_max(array.storage())?.map(|MinMaxResult { min, max }| MinMaxResult {
15-
min: Scalar::extension(array.ext_dtype().clone(), min),
16-
max: Scalar::extension(array.ext_dtype().clone(), max),
20+
min: Scalar::extension(non_nullable_ext_dtype.clone(), min),
21+
max: Scalar::extension(non_nullable_ext_dtype, max),
1722
}),
1823
)
1924
}

vortex-array/src/arrays/primitive/compute/min_max.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use itertools::Itertools;
5-
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
5+
use vortex_dtype::Nullability::NonNullable;
6+
use vortex_dtype::{NativePType, match_each_native_ptype};
67
use vortex_error::VortexResult;
78
use vortex_mask::Mask;
8-
use vortex_scalar::{Scalar, ScalarValue};
9+
use vortex_scalar::{PValue, Scalar};
910

1011
use crate::arrays::{PrimitiveArray, PrimitiveVTable};
1112
use crate::compute::{MinMaxKernel, MinMaxKernelAdapter, MinMaxResult};
@@ -24,25 +25,26 @@ register_kernel!(MinMaxKernelAdapter(PrimitiveVTable).lift());
2425
#[inline]
2526
fn compute_min_max_with_validity<T>(array: &PrimitiveArray) -> VortexResult<Option<MinMaxResult>>
2627
where
27-
T: Into<ScalarValue> + NativePType,
28+
T: NativePType,
29+
PValue: From<T>,
2830
{
2931
Ok(match array.validity_mask() {
30-
Mask::AllTrue(_) => compute_min_max(array.as_slice::<T>().iter(), array.dtype()),
32+
Mask::AllTrue(_) => compute_min_max(array.as_slice::<T>().iter()),
3133
Mask::AllFalse(_) => None,
3234
Mask::Values(v) => compute_min_max(
3335
array
3436
.as_slice::<T>()
3537
.iter()
3638
.zip(v.bit_buffer().iter())
3739
.filter_map(|(v, m)| m.then_some(v)),
38-
array.dtype(),
3940
),
4041
})
4142
}
4243

43-
fn compute_min_max<'a, T>(iter: impl Iterator<Item = &'a T>, dtype: &DType) -> Option<MinMaxResult>
44+
fn compute_min_max<'a, T>(iter: impl Iterator<Item = &'a T>) -> Option<MinMaxResult>
4445
where
45-
T: Into<ScalarValue> + NativePType,
46+
T: NativePType,
47+
PValue: From<T>,
4648
{
4749
// `total_compare` function provides a total ordering (even for NaN values).
4850
// However, we exclude NaNs from min max as they're not useful for any purpose where min/max would be used
@@ -52,15 +54,15 @@ where
5254
{
5355
itertools::MinMaxResult::NoElements => None,
5456
itertools::MinMaxResult::OneElement(&x) => {
55-
let scalar = Scalar::new(dtype.clone(), x.into());
57+
let scalar = Scalar::primitive(x, NonNullable);
5658
Some(MinMaxResult {
5759
min: scalar.clone(),
5860
max: scalar,
5961
})
6062
}
6163
itertools::MinMaxResult::MinMax(&min, &max) => Some(MinMaxResult {
62-
min: Scalar::new(dtype.clone(), min.into()),
63-
max: Scalar::new(dtype.clone(), max.into()),
64+
min: Scalar::primitive(min, NonNullable),
65+
max: Scalar::primitive(max, NonNullable),
6466
}),
6567
}
6668
}

vortex-array/src/arrays/varbin/compute/min_max.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use itertools::Itertools;
55
use vortex_dtype::DType;
6+
use vortex_dtype::Nullability::NonNullable;
67
use vortex_error::{VortexResult, vortex_panic};
78
use vortex_scalar::Scalar;
89

@@ -45,12 +46,12 @@ pub(crate) fn varbin_compute_min_max<T: ArrayAccessor<[u8]>>(
4546
/// Helper function to make sure that min/max has the right [`Scalar`] type.
4647
fn make_scalar(dtype: &DType, value: &[u8]) -> Scalar {
4748
match dtype {
48-
DType::Binary(_) => Scalar::new(dtype.clone(), value.into()),
49+
DType::Binary(_) => Scalar::binary(value.to_vec(), NonNullable),
4950
DType::Utf8(_) => {
5051
// SAFETY: We only call `compute_min_max` within `varbin/`, in which we always validate
5152
// the arrays, and we always pass `array.dtype()` in as the `dtype` argument.
5253
let value = unsafe { str::from_utf8_unchecked(value) };
53-
Scalar::new(dtype.clone(), value.into())
54+
Scalar::utf8(value, NonNullable)
5455
}
5556
_ => vortex_panic!("cannot make Scalar from bytes with dtype {dtype}"),
5657
}
@@ -60,7 +61,7 @@ fn make_scalar(dtype: &DType, value: &[u8]) -> Scalar {
6061
mod tests {
6162
use vortex_buffer::BufferString;
6263
use vortex_dtype::DType::Utf8;
63-
use vortex_dtype::Nullability::Nullable;
64+
use vortex_dtype::Nullability::{NonNullable, Nullable};
6465
use vortex_scalar::Scalar;
6566

6667
use crate::arrays::VarBinArray;
@@ -83,14 +84,14 @@ mod tests {
8384
assert_eq!(
8485
min,
8586
Scalar::new(
86-
Utf8(Nullable),
87+
Utf8(NonNullable),
8788
BufferString::from("hello world".to_string()).into(),
8889
)
8990
);
9091
assert_eq!(
9192
max,
9293
Scalar::new(
93-
Utf8(Nullable),
94+
Utf8(NonNullable),
9495
BufferString::from("hello world this is a long string".to_string()).into()
9596
)
9697
);

vortex-array/src/compute/min_max.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,10 @@ impl ComputeFnVTable for MinMax {
110110
Ok(DType::Struct(
111111
StructFields::new(
112112
["min", "max"].into(),
113-
vec![array.dtype().clone(), array.dtype().clone()],
113+
vec![
114+
array.dtype().as_nonnullable(),
115+
array.dtype().as_nonnullable(),
116+
],
114117
),
115118
Nullability::Nullable,
116119
))
@@ -133,13 +136,8 @@ fn min_max_impl(
133136
return Ok(None);
134137
}
135138

136-
if let Some(array) = array.as_opt::<ConstantVTable>()
137-
&& !array.scalar().is_null()
138-
{
139-
return Ok(Some(MinMaxResult {
140-
min: array.scalar().clone(),
141-
max: array.scalar().clone(),
142-
}));
139+
if let Some(array) = array.as_opt::<ConstantVTable>() {
140+
return ConstantVTable.min_max(array);
143141
}
144142

145143
let min = array
@@ -152,7 +150,11 @@ fn min_max_impl(
152150
.and_then(Precision::as_exact);
153151

154152
if let Some((min, max)) = min.zip(max) {
155-
return Ok(Some(MinMaxResult { min, max }));
153+
let non_nullable_dtype = array.dtype().as_nonnullable();
154+
return Ok(Some(MinMaxResult {
155+
min: min.cast(&non_nullable_dtype)?,
156+
max: max.cast(&non_nullable_dtype)?,
157+
}));
156158
}
157159

158160
let args = InvocationArgs {
@@ -199,10 +201,11 @@ impl<V: VTable + MinMaxKernel> Kernel for MinMaxKernelAdapter<V> {
199201
let Some(array) = inputs.array.as_opt::<V>() else {
200202
return Ok(None);
201203
};
204+
let non_nullable_dtype = array.dtype().as_nonnullable();
202205
let dtype = DType::Struct(
203206
StructFields::new(
204207
["min", "max"].into(),
205-
vec![array.dtype().clone(), array.dtype().clone()],
208+
vec![non_nullable_dtype.clone(), non_nullable_dtype],
206209
),
207210
Nullability::Nullable,
208211
);

0 commit comments

Comments
 (0)