Skip to content

Commit 73ea6d7

Browse files
authored
feat: list_contains, list_lens (#3169)
1 parent feacdff commit 73ea6d7

File tree

5 files changed

+280
-6
lines changed

5 files changed

+280
-6
lines changed

vortex-array/src/compute/invert.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output}
88
use crate::encoding::Encoding;
99
use crate::{Array, ArrayRef, ToCanonical};
1010

11-
/// Logically invert a boolean array.
11+
/// Logically invert a boolean array, preserving its validity.
1212
pub fn invert(array: &dyn Array) -> VortexResult<ArrayRef> {
1313
INVERT_FN
1414
.invoke(&InvocationArgs {

vortex-array/src/compute/list.rs

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
//! List-related compute operations.
2+
3+
use arrow_buffer::BooleanBuffer;
4+
use arrow_buffer::bit_iterator::BitIndexIterator;
5+
use num_traits::AsPrimitive;
6+
use vortex_buffer::Buffer;
7+
use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
8+
use vortex_error::{VortexResult, vortex_bail};
9+
use vortex_mask::Mask;
10+
use vortex_scalar::Scalar;
11+
12+
use crate::arrays::{BoolArray, ConstantArray, ListArray};
13+
use crate::compute::{Operator, compare, invert};
14+
use crate::validity::Validity;
15+
use crate::variants::PrimitiveArrayTrait;
16+
use crate::{Array, ArrayRef, ArrayStatistics, IntoArray, ToCanonical};
17+
18+
/// Compute a `Bool`-typed array the same length as `array` where elements are `true` if the list
19+
/// item contains the `value`, or `false` otherwise.
20+
///
21+
/// If the ListArray is nullable, then the result will contain nulls matching the null mask
22+
/// of the original array.
23+
///
24+
/// ## Null scalar handling
25+
///
26+
/// When the search scalar is `NULL`, then the resulting array will be a `BoolArray` containing
27+
/// `true` if the list contains any nulls, and `false` if the list does not contain any nulls,
28+
/// or `NULL` for null lists.
29+
///
30+
/// ## Example
31+
///
32+
/// ```rust
33+
/// use vortex_array::{Array, IntoArray, ToCanonical};
34+
/// use vortex_array::arrays::{ListArray, VarBinArray};
35+
/// use vortex_array::compute::list_contains;
36+
/// use vortex_array::validity::Validity;
37+
/// use vortex_buffer::buffer;
38+
/// use vortex_dtype::DType;
39+
/// let elements = VarBinArray::from_vec(
40+
/// vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
41+
/// let offsets = buffer![0u32, 1, 3, 5].into_array();
42+
/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
43+
///
44+
/// let matches = list_contains(&list_array, "b".into()).unwrap();
45+
/// let to_vec: Vec<bool> = matches.to_bool().unwrap().boolean_buffer().iter().collect();
46+
/// assert_eq!(to_vec, vec![false, true, false]);
47+
/// ```
48+
pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef> {
49+
if value.is_null() {
50+
return list_contains_null(array);
51+
}
52+
53+
// Ensure that the array must be of List type.
54+
let Some(list_array) = array.as_any().downcast_ref::<ListArray>() else {
55+
vortex_bail!("array must be of List type")
56+
};
57+
58+
let elems = list_array.elements();
59+
let ends = list_array.offsets().to_primitive()?;
60+
61+
let rhs = ConstantArray::new(value, elems.len());
62+
let matching_elements = compare(elems, &rhs, Operator::Eq)?;
63+
let matches = matching_elements.to_bool()?;
64+
65+
// Fast path: all elements match or none match.
66+
if let Some(pred) = matches.as_constant() {
67+
return match pred.as_bool().value() {
68+
// TODO(aduffy): how do we handle null?
69+
None | Some(false) => Ok(ConstantArray::new::<bool>(false, matches.len()).into_array()),
70+
Some(true) => Ok(ConstantArray::new::<bool>(true, matches.len()).into_array()),
71+
};
72+
}
73+
74+
match_each_integer_ptype!(ends.ptype(), |$T| {
75+
Ok(reduce_with_ends(ends.as_slice::<$T>(), &matches.boolean_buffer(), list_array.validity().clone()))
76+
})
77+
}
78+
79+
/// Returns a `Bool` array with `true` for lists which contains NULL and `false` if not, or
80+
/// NULL if the list itself is null.
81+
pub fn list_contains_null(array: &dyn Array) -> VortexResult<ArrayRef> {
82+
// Ensure that the array must be of List type.
83+
let Some(list_array) = array.as_any().downcast_ref::<ListArray>() else {
84+
vortex_bail!("array must be of List type")
85+
};
86+
87+
let elems = list_array.elements();
88+
89+
// Check element validity. We need to intersect
90+
match elems.validity_mask()? {
91+
// No NULL elements
92+
Mask::AllTrue(_) => match list_array.validity() {
93+
Validity::NonNullable => {
94+
Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
95+
}
96+
Validity::AllValid => Ok(ConstantArray::new(
97+
Scalar::bool(true, Nullability::Nullable),
98+
list_array.len(),
99+
)
100+
.into_array()),
101+
Validity::AllInvalid => Ok(ConstantArray::new(
102+
Scalar::null(DType::Bool(Nullability::Nullable)),
103+
list_array.len(),
104+
)
105+
.into_array()),
106+
Validity::Array(list_mask) => {
107+
// Create a new bool array with false, and the provided nulls
108+
let buffer = BooleanBuffer::new_unset(list_array.len());
109+
Ok(BoolArray::new(buffer, Validity::Array(list_mask.clone())).into_array())
110+
}
111+
},
112+
// All null elements
113+
Mask::AllFalse(_) => Ok(ConstantArray::new::<bool>(true, list_array.len()).into_array()),
114+
Mask::Values(mask) => {
115+
let nulls = invert(&mask.into_array())?.to_bool()?;
116+
let ends = list_array.offsets().to_primitive()?;
117+
match_each_integer_ptype!(ends.ptype(), |$T| {
118+
Ok(reduce_with_ends(
119+
list_array.offsets().to_primitive()?.as_slice::<$T>(),
120+
&nulls.boolean_buffer(),
121+
list_array.validity().clone(),
122+
))
123+
})
124+
}
125+
}
126+
}
127+
128+
// Reduce each boolean values into a Mask that indicates which elements in the
129+
// ListArray contain the matching value.
130+
fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
131+
ends: &[T],
132+
matches: &BooleanBuffer,
133+
validity: Validity,
134+
) -> ArrayRef {
135+
let mask: BooleanBuffer = ends
136+
.windows(2)
137+
.map(|window| {
138+
let len = window[1].as_() - window[0].as_();
139+
let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
140+
set_bits.next().is_some()
141+
})
142+
.collect();
143+
144+
BoolArray::new(mask, validity).into_array()
145+
}
146+
147+
/// Returns a new array of `u64` representing the length of each list element.
148+
///
149+
/// ## Example
150+
///
151+
/// ```rust
152+
/// use vortex_array::arrays::{ListArray, VarBinArray};
153+
/// use vortex_array::{Array, IntoArray};
154+
/// use vortex_array::compute::{list_elem_len, scalar_at};
155+
/// use vortex_array::validity::Validity;
156+
/// use vortex_buffer::buffer;
157+
/// use vortex_dtype::DType;
158+
///
159+
/// let elements = VarBinArray::from_vec(
160+
/// vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
161+
/// let offsets = buffer![0u32, 1, 3, 5].into_array();
162+
/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
163+
///
164+
/// let lens = list_elem_len(&list_array).unwrap();
165+
/// assert_eq!(scalar_at(&lens, 0).unwrap(), 1u32.into());
166+
/// assert_eq!(scalar_at(&lens, 1).unwrap(), 2u32.into());
167+
/// assert_eq!(scalar_at(&lens, 2).unwrap(), 2u32.into());
168+
/// ```
169+
pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
170+
let Some(list_array) = array.as_any().downcast_ref::<ListArray>() else {
171+
vortex_bail!("array must be of List type")
172+
};
173+
174+
let offsets = list_array.offsets().to_primitive()?;
175+
let lens_array = match_each_integer_ptype!(offsets.ptype(), |$T| {
176+
element_lens(offsets.as_slice::<$T>()).into_array()
177+
});
178+
179+
Ok(lens_array)
180+
}
181+
182+
/// Compute the difference between each pair of adjacent offsets.
///
/// For `n` offsets this yields `n - 1` lengths; fewer than two offsets yield an empty buffer.
fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
    // Pair every offset with its successor: (values[i], values[i + 1]).
    values
        .iter()
        .zip(values.iter().skip(1))
        .map(|(&start, &end)| end - start)
        .collect()
}
188+
189+
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::array::IntoArray;
    use crate::arrays::{BoolArray, ListArray, VarBinArray};
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::{Array, ArrayRef};

    /// Build a non-nullable list array of non-nullable UTF-8 strings.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype)
            .unwrap()
            .into_array()
    }

    /// Build a non-nullable list array whose UTF-8 elements may themselves be null.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Offsets are a leading zero followed by the running total of the list lengths.
        let running_ends = values.iter().scan(0u64, |end, list| {
            *end += list.len() as u64;
            Some(*end)
        });
        let offsets: Vec<u64> = std::iter::once(0u64).chain(running_ends).collect();
        let offsets = Buffer::from_iter(offsets).into_array();

        // Flatten all lists into a single nullable string array.
        let flattened: Vec<Option<&str>> = values
            .iter()
            .flat_map(|list| list.iter().copied())
            .collect();
        let elements =
            VarBinArray::from_iter(flattened, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Build the expected `BoolArray`; `None` validity means a non-nullable array.
    fn bool_array(values: Vec<bool>, validity: Option<Vec<bool>>) -> BoolArray {
        let validity = validity.map_or(Validity::NonNullable, |bits| Validity::from_iter(bits));
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Case 1: list(utf8)
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Case 2: list(utf8?) with NULL search scalar
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        None,
        bool_array(vec![false, true, true], None)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        // The search scalar takes the element nullability of the list under test.
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let scalar = value.map_or_else(
            || Scalar::null(DType::Utf8(Nullability::Nullable)),
            |needle| Scalar::utf8(needle, element_nullability),
        );

        let result = list_contains(&list_array, scalar).expect("list_contains failed");
        let bool_result = result.to_bool().expect("to_bool failed");

        let actual_bits = bool_result.boolean_buffer().iter().collect_vec();
        let expected_bits = expected.boolean_buffer().iter().collect_vec();
        assert_eq!(actual_bits, expected_bits);
        assert_eq!(bool_result.validity(), expected.validity());
    }
}

vortex-array/src/compute/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub use is_constant::*;
2121
pub use is_sorted::*;
2222
use itertools::Itertools;
2323
pub use like::{LikeFn, LikeOptions, like};
24+
pub use list::*;
2425
pub use mask::*;
2526
pub use min_max::{MinMaxFn, MinMaxResult, min_max};
2627
pub use nan_count::*;
@@ -57,6 +58,7 @@ mod invert;
5758
mod is_constant;
5859
mod is_sorted;
5960
mod like;
61+
mod list;
6062
mod mask;
6163
mod min_max;
6264
mod nan_count;

vortex-array/src/compute/search_sorted.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ pub fn search_sorted_many<T: Into<Scalar> + Clone>(
307307
.try_collect()
308308
}
309309

310-
// Native functions for each of the values, cast up to u64 or down to something lower.
310+
/// Search for many `usize` values in a sorted array.
311311
pub fn search_sorted_usize_many(
312312
array: &dyn Array,
313313
targets: &[usize],

vortex-array/src/variants.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use vortex_error::{VortexExpect, VortexResult, vortex_err, vortex_panic};
1111
use crate::compute::sum;
1212
use crate::{Array, ArrayRef};
1313

14-
pub trait NullArrayTrait {}
14+
pub trait NullArrayTrait: Array {}
1515

1616
pub trait BoolArrayTrait: Array {}
1717

@@ -40,9 +40,9 @@ pub trait PrimitiveArrayTrait: Array {
4040
}
4141
}
4242

43-
pub trait Utf8ArrayTrait {}
43+
pub trait Utf8ArrayTrait: Array {}
4444

45-
pub trait BinaryArrayTrait {}
45+
pub trait BinaryArrayTrait: Array {}
4646

4747
pub trait DecimalArrayTrait: Array {}
4848

@@ -90,7 +90,7 @@ impl dyn StructArrayTrait + '_ {
9090
}
9191
}
9292

93-
pub trait ListArrayTrait {}
93+
pub trait ListArrayTrait: Array {}
9494

9595
pub trait ExtensionArrayTrait: Array {
9696
/// Returns the extension logical [`DType`].

0 commit comments

Comments
 (0)