Skip to content

Commit f69f00c

Browse files
fix: return nullable type from between primitive comput fn (#2776)
1 parent 6cdaba2 commit f69f00c

File tree

4 files changed

+68
-16
lines changed

4 files changed

+68
-16
lines changed

encodings/dict/src/compute/compare.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ impl CompareFn<&DictArray> for DictEncoding {
2424
&ConstantArray::new(rhs, lhs.values().len()),
2525
operator,
2626
)?;
27-
2827
return if operator == Operator::Eq {
2928
let result_nullability =
3029
compare_result.dtype().nullability() | lhs.dtype().nullability();

vortex-array/src/arrays/primitive/compute/between.rs

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use arrow_buffer::BooleanBuffer;
2-
use vortex_dtype::{NativePType, match_each_native_ptype};
2+
use vortex_dtype::{NativePType, Nullability, match_each_native_ptype};
33
use vortex_error::VortexResult;
44

55
use crate::arrays::{BoolArray, PrimitiveArray, PrimitiveEncoding};
@@ -19,8 +19,14 @@ impl BetweenFn<&PrimitiveArray> for PrimitiveEncoding {
1919
return Ok(None);
2020
};
2121

22+
// Note, we know that have checked before that the lower and upper bounds are not constant
23+
// null values
24+
25+
let nullability =
26+
arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
27+
2228
Ok(Some(match_each_native_ptype!(arr.ptype(), |$P| {
23-
between_impl::<$P>(arr, $P::try_from(lower)?, $P::try_from(upper)?, options)
29+
between_impl::<$P>(arr, $P::try_from(lower)?, $P::try_from(upper)?, nullability, options)
2430
})))
2531
}
2632
}
@@ -29,22 +35,43 @@ fn between_impl<T: NativePType + Copy>(
2935
arr: &PrimitiveArray,
3036
lower: T,
3137
upper: T,
38+
nullability: Nullability,
3239
options: &BetweenOptions,
3340
) -> ArrayRef {
3441
match (options.lower_strict, options.upper_strict) {
3542
// Note: these comparisons are explicitly passed in to allow function impl inlining
36-
(StrictComparison::Strict, StrictComparison::Strict) => {
37-
between_impl_(arr, lower, NativePType::is_lt, upper, NativePType::is_lt)
38-
}
39-
(StrictComparison::Strict, StrictComparison::NonStrict) => {
40-
between_impl_(arr, lower, NativePType::is_lt, upper, NativePType::is_le)
41-
}
42-
(StrictComparison::NonStrict, StrictComparison::Strict) => {
43-
between_impl_(arr, lower, NativePType::is_le, upper, NativePType::is_lt)
44-
}
45-
(StrictComparison::NonStrict, StrictComparison::NonStrict) => {
46-
between_impl_(arr, lower, NativePType::is_le, upper, NativePType::is_le)
47-
}
43+
(StrictComparison::Strict, StrictComparison::Strict) => between_impl_(
44+
arr,
45+
lower,
46+
NativePType::is_lt,
47+
upper,
48+
NativePType::is_lt,
49+
nullability,
50+
),
51+
(StrictComparison::Strict, StrictComparison::NonStrict) => between_impl_(
52+
arr,
53+
lower,
54+
NativePType::is_lt,
55+
upper,
56+
NativePType::is_le,
57+
nullability,
58+
),
59+
(StrictComparison::NonStrict, StrictComparison::Strict) => between_impl_(
60+
arr,
61+
lower,
62+
NativePType::is_le,
63+
upper,
64+
NativePType::is_lt,
65+
nullability,
66+
),
67+
(StrictComparison::NonStrict, StrictComparison::NonStrict) => between_impl_(
68+
arr,
69+
lower,
70+
NativePType::is_le,
71+
upper,
72+
NativePType::is_le,
73+
nullability,
74+
),
4875
}
4976
}
5077

@@ -54,6 +81,7 @@ fn between_impl_<T>(
5481
lower_fn: impl Fn(T, T) -> bool,
5582
upper: T,
5683
upper_fn: impl Fn(T, T) -> bool,
84+
nullability: Nullability,
5785
) -> ArrayRef
5886
where
5987
T: NativePType + Copy,
@@ -65,7 +93,7 @@ where
6593
let i = unsafe { *slice.get_unchecked(idx) };
6694
lower_fn(lower, i) & upper_fn(i, upper)
6795
}),
68-
arr.validity().clone(),
96+
arr.validity().clone().union_nullability(nullability),
6997
)
7098
.into_array()
7199
}

vortex-array/src/compute/between.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use vortex_dtype::{DType, Nullability};
22
use vortex_error::{VortexExpect, VortexResult};
3+
use vortex_scalar::Scalar;
34

5+
use crate::arrays::ConstantArray;
46
use crate::compute::{BinaryOperator, Operator, binary_boolean, compare};
57
use crate::{Array, ArrayRef, Canonical, Encoding, IntoArray};
68

@@ -91,6 +93,21 @@ pub fn between(
9193
debug_assert_eq!(arr.len(), lower.len());
9294
debug_assert_eq!(arr.len(), upper.len());
9395

96+
// A quick check to see if either array might is a null constant array.
97+
if lower.is_invalid(0)? || upper.is_invalid(0)? {
98+
if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
99+
if c_lower.is_null() || c_upper.is_null() {
100+
return Ok(ConstantArray::new(
101+
Scalar::null(arr.dtype().with_nullability(
102+
lower.dtype().nullability() | upper.dtype().nullability(),
103+
)),
104+
arr.len(),
105+
)
106+
.to_array());
107+
}
108+
}
109+
}
110+
94111
let result = between_impl(arr, lower, upper, options)?;
95112

96113
debug_assert_eq!(result.len(), arr.len());

vortex-array/src/validity.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ impl Validity {
7676
}
7777
}
7878

79+
/// The union nullability and validity.
80+
pub fn union_nullability(self, nullability: Nullability) -> Self {
81+
match nullability {
82+
Nullability::NonNullable => self,
83+
Nullability::Nullable => self.into_nullable(),
84+
}
85+
}
86+
7987
pub fn all_valid(&self) -> VortexResult<bool> {
8088
Ok(match self {
8189
Validity::NonNullable | Validity::AllValid => true,

0 commit comments

Comments
 (0)