|
1 | | -use vortex_error::VortexResult; |
| 1 | +use std::ops::AddAssign; |
| 2 | + |
| 3 | +use arrow_buffer::BooleanBufferBuilder; |
| 4 | +use num_traits::AsPrimitive; |
| 5 | +use vortex_buffer::BufferMut; |
| 6 | +use vortex_dtype::{NativePType, Nullability, match_each_integer_ptype}; |
| 7 | +use vortex_error::{VortexExpect, VortexResult}; |
2 | 8 | use vortex_mask::Mask; |
| 9 | +use vortex_scalar::Scalar; |
3 | 10 |
|
4 | | -use crate::arrays::ListVTable; |
5 | | -use crate::compute::{FilterKernel, FilterKernelAdapter, arrow_filter_fn}; |
6 | | -use crate::{ArrayRef, register_kernel}; |
| 11 | +use crate::arrays::{ConstantArray, ListArray, ListVTable, PrimitiveArray}; |
| 12 | +use crate::compute::{FilterKernel, FilterKernelAdapter, arrow_filter_fn, filter}; |
| 13 | +use crate::validity::Validity; |
| 14 | +use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel}; |
7 | 15 |
|
// Filter kernel for list arrays. The strategy is chosen from the array's validity mask:
// all-valid lists use `filter_all_valid`, all-null lists become a null constant array, and
// arrays with mixed validity fall back to `arrow_filter_fn`.
8 | 16 | impl FilterKernel for ListVTable { |
9 | 17 | fn filter(&self, array: &Self::Array, mask: &Mask) -> VortexResult<ArrayRef> { |
// The row below is marked `-` in this diff: the previous body delegated everything to
// `arrow_filter_fn`.
10 | | - arrow_filter_fn(array.as_ref(), mask) |
// Convert the offsets to a primitive array so they can be read as a typed slice below.
| 18 | + let offsets = array.offsets.to_primitive()?; |
| 19 | + |
| 20 | + match array.validity_mask()? { |
| 21 | + Mask::AllTrue(_) => { |
// Dispatch on the integer type of the offsets and filter without any validity handling.
| 22 | + match_each_integer_ptype!(offsets.ptype(), |$I| { |
| 23 | + filter_all_valid::<$I>( |
| 24 | + offsets.as_slice::<$I>(), |
| 25 | + array.elements().as_ref(), |
| 26 | + mask, |
| 27 | + array.dtype().nullability(), |
| 28 | + ) |
| 29 | + }) |
| 30 | + } |
| 31 | + Mask::AllFalse(_) => { |
| 32 | + // The validity mask is all-false (every list is null), so the result is a null constant with one entry per selected row. |
| 33 | + Ok( |
| 34 | + ConstantArray::new(Scalar::null(array.dtype().clone()), mask.true_count()) |
| 35 | + .into_array(), |
| 36 | + ) |
| 37 | + } |
| 38 | + Mask::Values(_) => { |
| 39 | + // TODO(ngates): implement null filtering; until then fall back to `arrow_filter_fn`. |
| 40 | + arrow_filter_fn(array.as_ref(), mask) |
| 41 | + } |
| 42 | + } |
| 43 | + } |
| 44 | +} |
| 45 | + |
/// Filters a list array in which every list is valid.
///
/// Builds two things in one pass over the selected row ranges of `mask`: a fresh, zero-based
/// offsets buffer for the kept lists, and a boolean mask over `elements` that selects exactly
/// the elements belonging to those lists. The elements are then filtered with that mask and
/// combined with the new offsets into a `ListArray`.
///
/// * `offsets` - list offsets; list `i` covers `offsets[i]..offsets[i + 1]` of `elements`.
/// * `elements` - the flattened child elements of all lists.
/// * `mask` - the row selection. It must expose `values()`, i.e. be neither all-true nor
///   all-false; otherwise the `vortex_expect` below fails (presumably by panicking - confirm).
/// * `nullability` - converted with `Validity::from` into the validity of the result.
///
/// Errors from `filter` and `ListArray::try_new` are propagated to the caller.
| 46 | +fn filter_all_valid<I: NativePType + AsPrimitive<usize> + AddAssign>( |
| 47 | + offsets: &[I], |
| 48 | + elements: &dyn Array, |
| 49 | + mask: &Mask, |
| 50 | + nullability: Nullability, |
| 51 | +) -> VortexResult<ArrayRef> { |
| 52 | + // We compute a new set of offsets, as well as a mask for filtering the elements. |
// One offset per kept list plus the leading zero.
| 53 | + let mut new_offsets = BufferMut::<I>::with_capacity(mask.true_count() + 1); |
| 54 | + new_offsets.push(I::zero()); |
// Running end offset of the most recently emitted list.
| 55 | + let mut new_offset: I = I::zero(); |
| 56 | + |
| 57 | + let mut mask_builder = BooleanBufferBuilder::new(elements.len()); |
// Each `(start, end)` pair is used as the half-open row range `start..end` of kept lists.
| 58 | + for &(start, end) in mask |
| 59 | + .values() |
| 60 | + .vortex_expect("all true and all false are handled by filter entry point") |
| 61 | + .slices() |
| 62 | + { |
// The kept rows are contiguous, so their elements form the single range
// `offsets[start]..offsets[end]`.
| 63 | + let elem_start: usize = offsets[start].as_(); |
| 64 | + let elem_end: usize = offsets[end].as_(); |
| 65 | + let elem_len = elem_end - elem_start; |
// Mark the elements skipped since the previous run as dropped, then keep this run's elements.
| 66 | + mask_builder.append_n(elem_start - mask_builder.len(), false); |
| 67 | + mask_builder.append_n(elem_len, true); |
| 68 | + |
| 69 | + // Add each of the new offsets into the result |
| 70 | + for i in start..end { |
// Length of list `i` in the offsets' own integer type (shadows the `usize` run length above).
| 71 | + let elem_len = offsets[i + 1] - offsets[i]; |
| 72 | + new_offset += elem_len; |
| 73 | + new_offsets.push(new_offset); |
| 74 | + } |
11 | 75 | }
// Pad with `false` so the element mask covers all of `elements`, including any tail after the last kept run.
| 76 | + mask_builder.append_n(elements.len() - mask_builder.len(), false); |
| 77 | + |
| 78 | + let new_elements = filter(elements, &Mask::from_buffer(mask_builder.finish()))?; |
| 79 | + |
// The offsets array itself is constructed with `Validity::NonNullable`.
| 80 | + let new_offsets = PrimitiveArray::new(new_offsets, Validity::NonNullable).into_array(); |
| 81 | + |
| 82 | + Ok(ListArray::try_new(new_elements, new_offsets, Validity::from(nullability))?.into_array()) |
12 | 83 | }
13 | 84 |
|
// Register the `ListVTable` filter kernel, wrapped in `FilterKernelAdapter` and lifted, via `register_kernel!`.
14 | 85 | register_kernel!(FilterKernelAdapter(ListVTable).lift());
0 commit comments