|
1 | | -use arrow_buffer::bit_util::ceil; |
2 | | -use arrow_buffer::{BooleanBuffer, MutableBuffer}; |
3 | | -use vortex_dtype::{match_each_native_ptype, NativePType}; |
4 | | -use vortex_error::{vortex_err, VortexExpect, VortexResult}; |
5 | | -use vortex_scalar::PrimitiveScalar; |
| 1 | +use vortex_error::VortexResult; |
6 | 2 |
|
7 | 3 | use crate::array::primitive::PrimitiveArray; |
8 | | -use crate::array::{BoolArray, ConstantArray}; |
9 | | -use crate::compute::{MaybeCompareFn, Operator}; |
10 | | -use crate::variants::PrimitiveArrayTrait; |
11 | | -use crate::{ArrayDType, ArrayData, IntoArrayData}; |
| 4 | +use crate::array::ConstantArray; |
| 5 | +use crate::compute::{arrow_compare, MaybeCompareFn, Operator}; |
| 6 | +use crate::stats::{ArrayStatistics, Stat}; |
| 7 | +use crate::ArrayData; |
12 | 8 |
|
13 | 9 | impl MaybeCompareFn for PrimitiveArray { |
14 | 10 | fn maybe_compare( |
15 | 11 | &self, |
16 | 12 | other: &ArrayData, |
17 | 13 | operator: Operator, |
18 | 14 | ) -> Option<VortexResult<ArrayData>> { |
19 | | - if let Ok(const_array) = ConstantArray::try_from(other) { |
20 | | - return Some(primitive_const_compare(self, const_array, operator)); |
| 15 | + // If the RHS is constant, then delegate to Arrow since. |
| 16 | + // TODO(ngates): remove these dual checks once we make stats not a hashmap |
| 17 | + // https://github.com/spiraldb/vortex/issues/1309 |
| 18 | + if ConstantArray::try_from(other).is_ok() |
| 19 | + || other |
| 20 | + .statistics() |
| 21 | + .get_as::<bool>(Stat::IsConstant) |
| 22 | + .unwrap_or(false) |
| 23 | + { |
| 24 | + return Some(arrow_compare(self.as_ref(), other, operator)); |
21 | 25 | } |
22 | 26 |
|
| 27 | + // If the RHS is primitive, then delegate to Arrow. |
23 | 28 | if let Ok(primitive) = PrimitiveArray::try_from(other) { |
24 | | - let match_mask = match_each_native_ptype!(self.ptype(), |$T| { |
25 | | - apply_predicate(self.maybe_null_slice::<$T>(), primitive.maybe_null_slice::<$T>(), operator.to_fn::<$T>()) |
26 | | - }); |
27 | | - |
28 | | - let validity = self |
29 | | - .validity() |
30 | | - .and(primitive.validity()) |
31 | | - .map(|v| v.into_nullable()); |
32 | | - |
33 | | - return Some( |
34 | | - validity |
35 | | - .and_then(|v| BoolArray::try_new(match_mask, v)) |
36 | | - .map(|a| a.into_array()), |
37 | | - ); |
| 29 | + return Some(arrow_compare(self.as_ref(), primitive.as_ref(), operator)); |
38 | 30 | } |
39 | 31 |
|
40 | 32 | None |
41 | 33 | } |
42 | 34 | } |
43 | | - |
44 | | -fn primitive_const_compare( |
45 | | - this: &PrimitiveArray, |
46 | | - other: ConstantArray, |
47 | | - operator: Operator, |
48 | | -) -> VortexResult<ArrayData> { |
49 | | - let primitive_scalar = PrimitiveScalar::try_new(other.dtype(), other.scalar_value()) |
50 | | - .vortex_expect("Expected a primitive scalar"); |
51 | | - |
52 | | - let buffer = match_each_native_ptype!(this.ptype(), |$T| { |
53 | | - let typed_value = primitive_scalar.typed_value::<$T>() |
54 | | - .ok_or_else(|| vortex_err!("Type mismatch between array and constant"))?; |
55 | | - primitive_value_compare::<$T>(this, typed_value, operator) |
56 | | - }); |
57 | | - |
58 | | - Ok(BoolArray::try_new(buffer, this.validity().into_nullable())?.into_array()) |
59 | | -} |
60 | | - |
61 | | -fn primitive_value_compare<T: NativePType>( |
62 | | - this: &PrimitiveArray, |
63 | | - value: T, |
64 | | - op: Operator, |
65 | | -) -> BooleanBuffer { |
66 | | - let op_fn = op.to_fn::<T>(); |
67 | | - let slice = this.maybe_null_slice::<T>(); |
68 | | - |
69 | | - BooleanBuffer::collect_bool(this.len(), |idx| { |
70 | | - op_fn(unsafe { *slice.get_unchecked(idx) }, value) |
71 | | - }) |
72 | | -} |
73 | | - |
74 | | -fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>( |
75 | | - lhs: &[T], |
76 | | - rhs: &[T], |
77 | | - f: F, |
78 | | -) -> BooleanBuffer { |
79 | | - const BLOCK_SIZE: usize = u64::BITS as usize; |
80 | | - |
81 | | - let len = lhs.len(); |
82 | | - let reminder = len % BLOCK_SIZE; |
83 | | - let block_count = len / BLOCK_SIZE; |
84 | | - |
85 | | - let mut buffer = MutableBuffer::new(ceil(len, BLOCK_SIZE) * 8); |
86 | | - |
87 | | - for block in 0..block_count { |
88 | | - let mut packed_block = 0_u64; |
89 | | - for bit_idx in 0..BLOCK_SIZE { |
90 | | - let idx = bit_idx + block * BLOCK_SIZE; |
91 | | - let r = f(unsafe { *lhs.get_unchecked(idx) }, unsafe { |
92 | | - *rhs.get_unchecked(idx) |
93 | | - }); |
94 | | - packed_block |= (r as u64) << bit_idx; |
95 | | - } |
96 | | - |
97 | | - unsafe { |
98 | | - buffer.push_unchecked(packed_block); |
99 | | - } |
100 | | - } |
101 | | - |
102 | | - if reminder != 0 { |
103 | | - let mut packed_block = 0_u64; |
104 | | - for bit_idx in 0..reminder { |
105 | | - let idx = bit_idx + block_count * BLOCK_SIZE; |
106 | | - let r = f(lhs[idx], rhs[idx]); |
107 | | - packed_block |= (r as u64) << bit_idx; |
108 | | - } |
109 | | - |
110 | | - unsafe { |
111 | | - buffer.push_unchecked(packed_block); |
112 | | - } |
113 | | - } |
114 | | - |
115 | | - BooleanBuffer::new(buffer.into(), 0, len) |
116 | | -} |
117 | | - |
118 | | -#[cfg(test)] |
119 | | -#[allow(clippy::panic_in_result_fn)] |
120 | | -mod test { |
121 | | - use itertools::Itertools; |
122 | | - |
123 | | - use super::*; |
124 | | - use crate::compute::compare; |
125 | | - use crate::IntoArrayVariant; |
126 | | - |
127 | | - fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> { |
128 | | - let filtered = indices_bits |
129 | | - .boolean_buffer() |
130 | | - .iter() |
131 | | - .enumerate() |
132 | | - .filter_map(|(idx, v)| { |
133 | | - let valid_and_true = indices_bits.validity().is_valid(idx) & v; |
134 | | - valid_and_true.then_some(idx as u64) |
135 | | - }) |
136 | | - .collect_vec(); |
137 | | - filtered |
138 | | - } |
139 | | - |
140 | | - #[test] |
141 | | - fn test_basic_comparisons() -> VortexResult<()> { |
142 | | - let arr = PrimitiveArray::from_nullable_vec(vec![ |
143 | | - Some(1i32), |
144 | | - Some(2), |
145 | | - Some(3), |
146 | | - Some(4), |
147 | | - None, |
148 | | - Some(5), |
149 | | - Some(6), |
150 | | - Some(7), |
151 | | - Some(8), |
152 | | - None, |
153 | | - Some(9), |
154 | | - None, |
155 | | - ]) |
156 | | - .into_array(); |
157 | | - |
158 | | - let matches = compare(&arr, &arr, Operator::Eq)?.into_bool()?; |
159 | | - assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]); |
160 | | - |
161 | | - let matches = compare(&arr, &arr, Operator::NotEq)?.into_bool()?; |
162 | | - let empty: [u64; 0] = []; |
163 | | - assert_eq!(to_int_indices(matches), empty); |
164 | | - |
165 | | - let other = PrimitiveArray::from_nullable_vec(vec![ |
166 | | - Some(1i32), |
167 | | - Some(2), |
168 | | - Some(3), |
169 | | - Some(4), |
170 | | - None, |
171 | | - Some(6), |
172 | | - Some(7), |
173 | | - Some(8), |
174 | | - Some(9), |
175 | | - None, |
176 | | - Some(10), |
177 | | - None, |
178 | | - ]) |
179 | | - .into_array(); |
180 | | - |
181 | | - let matches = compare(&arr, &other, Operator::Lte)?.into_bool()?; |
182 | | - assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]); |
183 | | - |
184 | | - let matches = compare(&arr, &other, Operator::Lt)?.into_bool()?; |
185 | | - assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]); |
186 | | - |
187 | | - let matches = compare(&other, &arr, Operator::Gte)?.into_bool()?; |
188 | | - assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]); |
189 | | - |
190 | | - let matches = compare(&other, &arr, Operator::Gt)?.into_bool()?; |
191 | | - assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]); |
192 | | - Ok(()) |
193 | | - } |
194 | | -} |
0 commit comments