|
| 1 | +//! List-related compute operations. |
| 2 | +
|
| 3 | +use arrow_buffer::BooleanBuffer; |
| 4 | +use arrow_buffer::bit_iterator::BitIndexIterator; |
| 5 | +use num_traits::AsPrimitive; |
| 6 | +use vortex_buffer::Buffer; |
| 7 | +use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype}; |
| 8 | +use vortex_error::{VortexResult, vortex_bail}; |
| 9 | +use vortex_mask::Mask; |
| 10 | +use vortex_scalar::Scalar; |
| 11 | + |
| 12 | +use crate::arrays::{BoolArray, ConstantArray, ListArray}; |
| 13 | +use crate::compute::{Operator, compare, invert}; |
| 14 | +use crate::validity::Validity; |
| 15 | +use crate::variants::PrimitiveArrayTrait; |
| 16 | +use crate::{Array, ArrayRef, ArrayStatistics, IntoArray, ToCanonical}; |
| 17 | + |
| 18 | +/// Compute a `Bool`-typed array the same length as `array` where elements are `true` if the list |
| 19 | +/// item contains the `value`, or `false` otherwise. |
| 20 | +/// |
| 21 | +/// If the ListArray is nullable, then the result will contain nulls matching the null mask |
| 22 | +/// of the original array. |
| 23 | +/// |
| 24 | +/// ## Null scalar handling |
| 25 | +/// |
| 26 | +/// When the search scalar is `NULL`, then the resulting array will be a `BoolArray` containing |
| 27 | +/// `true` if the list contains any nulls, and `false` if the list does not contain any nulls, |
| 28 | +/// or `NULL` for null lists. |
| 29 | +/// |
| 30 | +/// ## Example |
| 31 | +/// |
| 32 | +/// ```rust |
| 33 | +/// use vortex_array::{Array, IntoArray, ToCanonical}; |
| 34 | +/// use vortex_array::arrays::{ListArray, VarBinArray}; |
| 35 | +/// use vortex_array::compute::list_contains; |
| 36 | +/// use vortex_array::validity::Validity; |
| 37 | +/// use vortex_buffer::buffer; |
| 38 | +/// use vortex_dtype::DType; |
| 39 | +/// let elements = VarBinArray::from_vec( |
| 40 | +/// vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array(); |
| 41 | +/// let offsets = buffer![0u32, 1, 3, 5].into_array(); |
| 42 | +/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap(); |
| 43 | +/// |
| 44 | +/// let matches = list_contains(&list_array, "b".into()).unwrap(); |
| 45 | +/// let to_vec: Vec<bool> = matches.to_bool().unwrap().boolean_buffer().iter().collect(); |
| 46 | +/// assert_eq!(to_vec, vec![false, true, false]); |
| 47 | +/// ``` |
| 48 | +pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef> { |
| 49 | + if value.is_null() { |
| 50 | + return list_contains_null(array); |
| 51 | + } |
| 52 | + |
| 53 | + // Ensure that the array must be of List type. |
| 54 | + let Some(list_array) = array.as_any().downcast_ref::<ListArray>() else { |
| 55 | + vortex_bail!("array must be of List type") |
| 56 | + }; |
| 57 | + |
| 58 | + let elems = list_array.elements(); |
| 59 | + let ends = list_array.offsets().to_primitive()?; |
| 60 | + |
| 61 | + let rhs = ConstantArray::new(value, elems.len()); |
| 62 | + let matching_elements = compare(elems, &rhs, Operator::Eq)?; |
| 63 | + let matches = matching_elements.to_bool()?; |
| 64 | + |
| 65 | + // Fast path: all elements match or none match. |
| 66 | + if let Some(pred) = matches.as_constant() { |
| 67 | + return match pred.as_bool().value() { |
| 68 | + // TODO(aduffy): how do we handle null? |
| 69 | + None | Some(false) => Ok(ConstantArray::new::<bool>(false, matches.len()).into_array()), |
| 70 | + Some(true) => Ok(ConstantArray::new::<bool>(true, matches.len()).into_array()), |
| 71 | + }; |
| 72 | + } |
| 73 | + |
| 74 | + match_each_integer_ptype!(ends.ptype(), |$T| { |
| 75 | + Ok(reduce_with_ends(ends.as_slice::<$T>(), &matches.boolean_buffer(), list_array.validity().clone())) |
| 76 | + }) |
| 77 | +} |
| 78 | + |
| 79 | +/// Returns a `Bool` array with `true` for lists which contains NULL and `false` if not, or |
| 80 | +/// NULL if the list itself is null. |
| 81 | +pub fn list_contains_null(array: &dyn Array) -> VortexResult<ArrayRef> { |
| 82 | + // Ensure that the array must be of List type. |
| 83 | + let Some(list_array) = array.as_any().downcast_ref::<ListArray>() else { |
| 84 | + vortex_bail!("array must be of List type") |
| 85 | + }; |
| 86 | + |
| 87 | + let elems = list_array.elements(); |
| 88 | + |
| 89 | + // Check element validity. We need to intersect |
| 90 | + match elems.validity_mask()? { |
| 91 | + // No NULL elements |
| 92 | + Mask::AllTrue(_) => match list_array.validity() { |
| 93 | + Validity::NonNullable => { |
| 94 | + Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array()) |
| 95 | + } |
| 96 | + Validity::AllValid => Ok(ConstantArray::new( |
| 97 | + Scalar::bool(true, Nullability::Nullable), |
| 98 | + list_array.len(), |
| 99 | + ) |
| 100 | + .into_array()), |
| 101 | + Validity::AllInvalid => Ok(ConstantArray::new( |
| 102 | + Scalar::null(DType::Bool(Nullability::Nullable)), |
| 103 | + list_array.len(), |
| 104 | + ) |
| 105 | + .into_array()), |
| 106 | + Validity::Array(list_mask) => { |
| 107 | + // Create a new bool array with false, and the provided nulls |
| 108 | + let buffer = BooleanBuffer::new_unset(list_array.len()); |
| 109 | + Ok(BoolArray::new(buffer, Validity::Array(list_mask.clone())).into_array()) |
| 110 | + } |
| 111 | + }, |
| 112 | + // All null elements |
| 113 | + Mask::AllFalse(_) => Ok(ConstantArray::new::<bool>(true, list_array.len()).into_array()), |
| 114 | + Mask::Values(mask) => { |
| 115 | + let nulls = invert(&mask.into_array())?.to_bool()?; |
| 116 | + let ends = list_array.offsets().to_primitive()?; |
| 117 | + match_each_integer_ptype!(ends.ptype(), |$T| { |
| 118 | + Ok(reduce_with_ends( |
| 119 | + list_array.offsets().to_primitive()?.as_slice::<$T>(), |
| 120 | + &nulls.boolean_buffer(), |
| 121 | + list_array.validity().clone(), |
| 122 | + )) |
| 123 | + }) |
| 124 | + } |
| 125 | + } |
| 126 | +} |
| 127 | + |
| 128 | +// Reduce each boolean values into a Mask that indicates which elements in the |
| 129 | +// ListArray contain the matching value. |
| 130 | +fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>( |
| 131 | + ends: &[T], |
| 132 | + matches: &BooleanBuffer, |
| 133 | + validity: Validity, |
| 134 | +) -> ArrayRef { |
| 135 | + let mask: BooleanBuffer = ends |
| 136 | + .windows(2) |
| 137 | + .map(|window| { |
| 138 | + let len = window[1].as_() - window[0].as_(); |
| 139 | + let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len); |
| 140 | + set_bits.next().is_some() |
| 141 | + }) |
| 142 | + .collect(); |
| 143 | + |
| 144 | + BoolArray::new(mask, validity).into_array() |
| 145 | +} |
| 146 | + |
| 147 | +/// Returns a new array of `u64` representing the length of each list element. |
| 148 | +/// |
| 149 | +/// ## Example |
| 150 | +/// |
| 151 | +/// ```rust |
| 152 | +/// use vortex_array::arrays::{ListArray, VarBinArray}; |
| 153 | +/// use vortex_array::{Array, IntoArray}; |
| 154 | +/// use vortex_array::compute::{list_elem_len, scalar_at}; |
| 155 | +/// use vortex_array::validity::Validity; |
| 156 | +/// use vortex_buffer::buffer; |
| 157 | +/// use vortex_dtype::DType; |
| 158 | +/// |
| 159 | +/// let elements = VarBinArray::from_vec( |
| 160 | +/// vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array(); |
| 161 | +/// let offsets = buffer![0u32, 1, 3, 5].into_array(); |
| 162 | +/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap(); |
| 163 | +/// |
| 164 | +/// let lens = list_elem_len(&list_array).unwrap(); |
| 165 | +/// assert_eq!(scalar_at(&lens, 0).unwrap(), 1u32.into()); |
| 166 | +/// assert_eq!(scalar_at(&lens, 1).unwrap(), 2u32.into()); |
| 167 | +/// assert_eq!(scalar_at(&lens, 2).unwrap(), 2u32.into()); |
| 168 | +/// ``` |
| 169 | +pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> { |
| 170 | + let Some(list_array) = array.as_any().downcast_ref::<ListArray>() else { |
| 171 | + vortex_bail!("array must be of List type") |
| 172 | + }; |
| 173 | + |
| 174 | + let offsets = list_array.offsets().to_primitive()?; |
| 175 | + let lens_array = match_each_integer_ptype!(offsets.ptype(), |$T| { |
| 176 | + element_lens(offsets.as_slice::<$T>()).into_array() |
| 177 | + }); |
| 178 | + |
| 179 | + Ok(lens_array) |
| 180 | +} |
| 181 | + |
| 182 | +fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> { |
| 183 | + values |
| 184 | + .windows(2) |
| 185 | + .map(|window| window[1] - window[0]) |
| 186 | + .collect() |
| 187 | +} |
| 188 | + |
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::array::IntoArray;
    use crate::arrays::{BoolArray, ListArray, VarBinArray};
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::{Array, ArrayRef};

    /// Builds a list array from nested string vectors using a non-nullable UTF-8 dtype.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let utf8 = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, utf8)
            .unwrap()
            .into_array()
    }

    /// Builds a non-nullable list array whose UTF-8 elements may themselves be null.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Offsets start at zero and accumulate the length of every list.
        let mut running_end = 0u64;
        let mut list_offsets = Vec::with_capacity(values.len() + 1);
        list_offsets.push(running_end);
        for list in &values {
            running_end += list.len() as u64;
            list_offsets.push(running_end);
        }

        let flattened = values.into_iter().flatten().collect_vec();
        let elements =
            VarBinArray::from_iter(flattened, DType::Utf8(Nullability::Nullable)).into_array();
        let offsets = Buffer::from_iter(list_offsets).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds the expected `BoolArray`; `None` validity means the array is non-nullable.
    fn bool_array(values: Vec<bool>, validity: Option<Vec<bool>>) -> BoolArray {
        let validity = if let Some(bits) = validity {
            Validity::from_iter(bits)
        } else {
            Validity::NonNullable
        };

        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Case 1: list(utf8)
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Case 2: list(utf8?) with NULL search scalar
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        None,
        bool_array(vec![false, true, true], None)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        // A `None` needle becomes a typed NULL scalar; otherwise the needle takes the
        // nullability of the list's element type.
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let scalar = if let Some(needle) = value {
            Scalar::utf8(needle, element_nullability)
        } else {
            Scalar::null(DType::Utf8(Nullability::Nullable))
        };

        let result = list_contains(&list_array, scalar).expect("list_contains failed");
        let bool_result = result.to_bool().expect("to_bool failed");

        let actual_bits = bool_result.boolean_buffer().iter().collect_vec();
        let expected_bits = expected.boolean_buffer().iter().collect_vec();
        assert_eq!(actual_bits, expected_bits);
        assert_eq!(bool_result.validity(), expected.validity());
    }
}
0 commit comments