Skip to content

Commit d17f02f

Browse files
delta003a10y
andauthored
Adds list_contains expression (#3410)
Co-authored-by: Andrew Duffy <[email protected]>
1 parent 91624db commit d17f02f

File tree

5 files changed

+377
-29
lines changed

5 files changed

+377
-29
lines changed

vortex-array/src/compute/list.rs

Lines changed: 129 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ use crate::{Array, ArrayRef, IntoArray, ToCanonical};
4646
/// assert_eq!(to_vec, vec![false, true, false]);
4747
/// ```
4848
pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef> {
49-
let DType::List(elem_dtype, _nullability) = array.dtype() else {
49+
let DType::List(elem_dtype, nullability) = array.dtype() else {
5050
vortex_bail!("Array must be of List type");
5151
};
5252
if &**elem_dtype != value.dtype() {
@@ -68,20 +68,45 @@ pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef>
6868
}
6969

7070
let elems = list_array.elements();
71-
let ends = list_array.offsets().to_primitive()?;
71+
if elems.is_empty() {
72+
// Must return false when a list is empty (but valid), or null when the list itself is null.
73+
return list_false_or_null(&list_array);
74+
}
7275

7376
let rhs = ConstantArray::new(value, elems.len());
7477
let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
7578
let matches = matching_elements.to_bool()?;
7679

7780
// Fast path: no elements match.
7881
if let Some(pred) = matches.as_constant() {
79-
if matches!(pred.as_bool().value(), None | Some(false)) {
80-
// TODO(aduffy): how do we handle null?
81-
return Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array());
82-
}
82+
return match pred.as_bool().value() {
83+
// All comparisons are invalid (result in `null`), and search is not null because
84+
// we already checked for null above.
85+
None => {
86+
assert!(
87+
!rhs.scalar().is_null(),
88+
"Search value must not be null here"
89+
);
90+
// False, unless the list itself is null in which case we return null.
91+
list_false_or_null(&list_array)
92+
}
93+
// No elements match, and all comparisons are valid (result in `false`).
94+
Some(false) => {
95+
// False, but match the nullability to the input list array.
96+
Ok(
97+
ConstantArray::new(Scalar::bool(false, *nullability), list_array.len())
98+
.into_array(),
99+
)
100+
}
101+
// All elements match, and all comparisons are valid (result in `true`).
102+
Some(true) => {
103+
// True, unless the list itself is empty or NULL.
104+
list_is_not_empty(&list_array)
105+
}
106+
};
83107
}
84108

109+
let ends = list_array.offsets().to_primitive()?;
85110
match_each_integer_ptype!(ends.ptype(), |T| {
86111
Ok(reduce_with_ends(
87112
ends.as_slice::<T>(),
@@ -99,28 +124,15 @@ fn list_contains_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
99124
// Check element validity. We need to intersect
100125
match elems.validity_mask()? {
101126
// No NULL elements
102-
Mask::AllTrue(_) => match list_array.validity() {
103-
Validity::NonNullable => {
104-
Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
105-
}
106-
Validity::AllValid => Ok(ConstantArray::new(
107-
Scalar::bool(true, Nullability::Nullable),
108-
list_array.len(),
109-
)
110-
.into_array()),
111-
Validity::AllInvalid => Ok(ConstantArray::new(
112-
Scalar::null(DType::Bool(Nullability::Nullable)),
113-
list_array.len(),
114-
)
115-
.into_array()),
116-
Validity::Array(list_mask) => {
117-
// Create a new bool array with false, and the provided nulls
118-
let buffer = BooleanBuffer::new_unset(list_array.len());
119-
Ok(BoolArray::new(buffer, Validity::Array(list_mask.clone())).into_array())
120-
}
121-
},
122-
// All null elements
123-
Mask::AllFalse(_) => Ok(ConstantArray::new::<bool>(true, list_array.len()).into_array()),
127+
Mask::AllTrue(_) => {
128+
// False, unless the list itself is NULL.
129+
list_false_or_null(list_array)
130+
}
131+
// All NULL elements.
132+
Mask::AllFalse(_) => {
133+
// True, unless the list itself is empty or NULL.
134+
list_is_not_empty(list_array)
135+
}
124136
Mask::Values(mask) => {
125137
let nulls = invert(&mask.into_array())?.to_bool()?;
126138
let ends = list_array.offsets().to_primitive()?;
@@ -135,6 +147,58 @@ fn list_contains_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
135147
}
136148
}
137149

150+
/// Returns a `Bool` array with `false` for lists that are valid,
151+
/// or `NULL` if the list itself is null.
152+
fn list_false_or_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
153+
match list_array.validity() {
154+
Validity::NonNullable => {
155+
// All false.
156+
Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
157+
}
158+
Validity::AllValid => {
159+
// All false, but nullable.
160+
Ok(
161+
ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
162+
.into_array(),
163+
)
164+
}
165+
Validity::AllInvalid => {
166+
// All nulls, must be nullable result.
167+
Ok(ConstantArray::new(
168+
Scalar::null(DType::Bool(Nullability::Nullable)),
169+
list_array.len(),
170+
)
171+
.into_array())
172+
}
173+
Validity::Array(validity_array) => {
174+
// Create a new bool array with false, and the provided nulls
175+
let buffer = BooleanBuffer::new_unset(list_array.len());
176+
Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
177+
}
178+
}
179+
}
180+
181+
/// Returns a `Bool` array with `true` for lists which are NOT empty, or `false` if they are empty,
182+
/// or `NULL` if the list itself is null.
183+
fn list_is_not_empty(list_array: &ListArray) -> VortexResult<ArrayRef> {
184+
// Short-circuit for all invalid.
185+
if matches!(list_array.validity(), Validity::AllInvalid) {
186+
return Ok(ConstantArray::new(
187+
Scalar::null(DType::Bool(Nullability::Nullable)),
188+
list_array.len(),
189+
)
190+
.into_array());
191+
}
192+
193+
let offsets = list_array.offsets().to_primitive()?;
194+
let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
195+
element_is_not_empty(offsets.as_slice::<T>())
196+
});
197+
198+
// Copy over the validity mask from the input.
199+
Ok(BoolArray::new(buffer, list_array.validity().clone()).into_array())
200+
}
201+
138202
// Reduce each boolean values into a Mask that indicates which elements in the
139203
// ListArray contain the matching value.
140204
fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
@@ -203,6 +267,10 @@ fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
203267
.collect()
204268
}
205269

270+
fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
271+
BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
272+
}
273+
206274
#[cfg(test)]
207275
mod tests {
208276
use std::sync::Arc;
@@ -285,6 +353,18 @@ mod tests {
285353
Some("a"),
286354
bool_array(vec![false, false, false], None)
287355
)]
356+
// Case 6: list(utf8?) with empty + NULL elements and NULL search
357+
#[case(
358+
null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
359+
None,
360+
bool_array(vec![false, true, true], None)
361+
)]
362+
// Case 7: list(utf8?) with empty + NULL elements and search scalar
363+
#[case(
364+
null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
365+
Some("a"),
366+
bool_array(vec![false, false, false], None)
367+
)]
288368
fn test_contains_nullable(
289369
#[case] list_array: ArrayRef,
290370
#[case] value: Option<&str>,
@@ -328,4 +408,25 @@ mod tests {
328408
vec![true, true],
329409
);
330410
}
411+
412+
#[test]
413+
fn test_all_nulls() {
414+
let list_array = ConstantArray::new(
415+
Scalar::null(DType::List(
416+
Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
417+
Nullability::Nullable,
418+
)),
419+
5,
420+
)
421+
.into_array();
422+
423+
let contains = list_contains(&list_array, 2i32.into()).unwrap();
424+
assert!(contains.is::<ConstantVTable>(), "Expected constant result");
425+
426+
assert_eq!(contains.len(), 5);
427+
assert_eq!(
428+
contains.to_bool().unwrap().validity(),
429+
&Validity::AllInvalid
430+
);
431+
}
331432
}

vortex-expr/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ mod get_item;
1414
mod identity;
1515
mod is_null;
1616
mod like;
17+
mod list_contains;
1718
mod literal;
1819
mod merge;
1920
mod not;
@@ -33,6 +34,7 @@ pub use get_item::*;
3334
pub use identity::*;
3435
pub use is_null::*;
3536
pub use like::*;
37+
pub use list_contains::*;
3638
pub use literal::*;
3739
pub use merge::*;
3840
pub use not::*;

0 commit comments

Comments
 (0)