Skip to content

Commit a1d5349

Browse files
authored
Remove primitive compare impl (#1337)
And just delegate to Arrow compute instead. I'm curious if our implementation did anything different? I don't think so...
1 parent d3a28f4 commit a1d5349

File tree

3 files changed

+29
-179
lines changed

3 files changed

+29
-179
lines changed
Lines changed: 17 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -1,194 +1,34 @@
1-
use arrow_buffer::bit_util::ceil;
2-
use arrow_buffer::{BooleanBuffer, MutableBuffer};
3-
use vortex_dtype::{match_each_native_ptype, NativePType};
4-
use vortex_error::{vortex_err, VortexExpect, VortexResult};
5-
use vortex_scalar::PrimitiveScalar;
1+
use vortex_error::VortexResult;
62

73
use crate::array::primitive::PrimitiveArray;
8-
use crate::array::{BoolArray, ConstantArray};
9-
use crate::compute::{MaybeCompareFn, Operator};
10-
use crate::variants::PrimitiveArrayTrait;
11-
use crate::{ArrayDType, ArrayData, IntoArrayData};
4+
use crate::array::ConstantArray;
5+
use crate::compute::{arrow_compare, MaybeCompareFn, Operator};
6+
use crate::stats::{ArrayStatistics, Stat};
7+
use crate::ArrayData;
128

139
impl MaybeCompareFn for PrimitiveArray {
1410
fn maybe_compare(
1511
&self,
1612
other: &ArrayData,
1713
operator: Operator,
1814
) -> Option<VortexResult<ArrayData>> {
19-
if let Ok(const_array) = ConstantArray::try_from(other) {
20-
return Some(primitive_const_compare(self, const_array, operator));
15+
// If the RHS is constant, then delegate to Arrow since.
16+
// TODO(ngates): remove these dual checks once we make stats not a hashmap
17+
// https://github.com/spiraldb/vortex/issues/1309
18+
if ConstantArray::try_from(other).is_ok()
19+
|| other
20+
.statistics()
21+
.get_as::<bool>(Stat::IsConstant)
22+
.unwrap_or(false)
23+
{
24+
return Some(arrow_compare(self.as_ref(), other, operator));
2125
}
2226

27+
// If the RHS is primitive, then delegate to Arrow.
2328
if let Ok(primitive) = PrimitiveArray::try_from(other) {
24-
let match_mask = match_each_native_ptype!(self.ptype(), |$T| {
25-
apply_predicate(self.maybe_null_slice::<$T>(), primitive.maybe_null_slice::<$T>(), operator.to_fn::<$T>())
26-
});
27-
28-
let validity = self
29-
.validity()
30-
.and(primitive.validity())
31-
.map(|v| v.into_nullable());
32-
33-
return Some(
34-
validity
35-
.and_then(|v| BoolArray::try_new(match_mask, v))
36-
.map(|a| a.into_array()),
37-
);
29+
return Some(arrow_compare(self.as_ref(), primitive.as_ref(), operator));
3830
}
3931

4032
None
4133
}
4234
}
43-
44-
fn primitive_const_compare(
45-
this: &PrimitiveArray,
46-
other: ConstantArray,
47-
operator: Operator,
48-
) -> VortexResult<ArrayData> {
49-
let primitive_scalar = PrimitiveScalar::try_new(other.dtype(), other.scalar_value())
50-
.vortex_expect("Expected a primitive scalar");
51-
52-
let buffer = match_each_native_ptype!(this.ptype(), |$T| {
53-
let typed_value = primitive_scalar.typed_value::<$T>()
54-
.ok_or_else(|| vortex_err!("Type mismatch between array and constant"))?;
55-
primitive_value_compare::<$T>(this, typed_value, operator)
56-
});
57-
58-
Ok(BoolArray::try_new(buffer, this.validity().into_nullable())?.into_array())
59-
}
60-
61-
fn primitive_value_compare<T: NativePType>(
62-
this: &PrimitiveArray,
63-
value: T,
64-
op: Operator,
65-
) -> BooleanBuffer {
66-
let op_fn = op.to_fn::<T>();
67-
let slice = this.maybe_null_slice::<T>();
68-
69-
BooleanBuffer::collect_bool(this.len(), |idx| {
70-
op_fn(unsafe { *slice.get_unchecked(idx) }, value)
71-
})
72-
}
73-
74-
fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>(
75-
lhs: &[T],
76-
rhs: &[T],
77-
f: F,
78-
) -> BooleanBuffer {
79-
const BLOCK_SIZE: usize = u64::BITS as usize;
80-
81-
let len = lhs.len();
82-
let reminder = len % BLOCK_SIZE;
83-
let block_count = len / BLOCK_SIZE;
84-
85-
let mut buffer = MutableBuffer::new(ceil(len, BLOCK_SIZE) * 8);
86-
87-
for block in 0..block_count {
88-
let mut packed_block = 0_u64;
89-
for bit_idx in 0..BLOCK_SIZE {
90-
let idx = bit_idx + block * BLOCK_SIZE;
91-
let r = f(unsafe { *lhs.get_unchecked(idx) }, unsafe {
92-
*rhs.get_unchecked(idx)
93-
});
94-
packed_block |= (r as u64) << bit_idx;
95-
}
96-
97-
unsafe {
98-
buffer.push_unchecked(packed_block);
99-
}
100-
}
101-
102-
if reminder != 0 {
103-
let mut packed_block = 0_u64;
104-
for bit_idx in 0..reminder {
105-
let idx = bit_idx + block_count * BLOCK_SIZE;
106-
let r = f(lhs[idx], rhs[idx]);
107-
packed_block |= (r as u64) << bit_idx;
108-
}
109-
110-
unsafe {
111-
buffer.push_unchecked(packed_block);
112-
}
113-
}
114-
115-
BooleanBuffer::new(buffer.into(), 0, len)
116-
}
117-
118-
#[cfg(test)]
119-
#[allow(clippy::panic_in_result_fn)]
120-
mod test {
121-
use itertools::Itertools;
122-
123-
use super::*;
124-
use crate::compute::compare;
125-
use crate::IntoArrayVariant;
126-
127-
fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
128-
let filtered = indices_bits
129-
.boolean_buffer()
130-
.iter()
131-
.enumerate()
132-
.filter_map(|(idx, v)| {
133-
let valid_and_true = indices_bits.validity().is_valid(idx) & v;
134-
valid_and_true.then_some(idx as u64)
135-
})
136-
.collect_vec();
137-
filtered
138-
}
139-
140-
#[test]
141-
fn test_basic_comparisons() -> VortexResult<()> {
142-
let arr = PrimitiveArray::from_nullable_vec(vec![
143-
Some(1i32),
144-
Some(2),
145-
Some(3),
146-
Some(4),
147-
None,
148-
Some(5),
149-
Some(6),
150-
Some(7),
151-
Some(8),
152-
None,
153-
Some(9),
154-
None,
155-
])
156-
.into_array();
157-
158-
let matches = compare(&arr, &arr, Operator::Eq)?.into_bool()?;
159-
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);
160-
161-
let matches = compare(&arr, &arr, Operator::NotEq)?.into_bool()?;
162-
let empty: [u64; 0] = [];
163-
assert_eq!(to_int_indices(matches), empty);
164-
165-
let other = PrimitiveArray::from_nullable_vec(vec![
166-
Some(1i32),
167-
Some(2),
168-
Some(3),
169-
Some(4),
170-
None,
171-
Some(6),
172-
Some(7),
173-
Some(8),
174-
Some(9),
175-
None,
176-
Some(10),
177-
None,
178-
])
179-
.into_array();
180-
181-
let matches = compare(&arr, &other, Operator::Lte)?.into_bool()?;
182-
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);
183-
184-
let matches = compare(&arr, &other, Operator::Lt)?.into_bool()?;
185-
assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);
186-
187-
let matches = compare(&other, &arr, Operator::Gte)?.into_bool()?;
188-
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);
189-
190-
let matches = compare(&other, &arr, Operator::Gt)?.into_bool()?;
191-
assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);
192-
Ok(())
193-
}
194-
}

vortex-array/src/compute/compare.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,17 @@ pub fn compare(
127127
}
128128

129129
// Fallback to arrow on canonical types
130-
let lhs = Datum::try_from(left.clone())?;
131-
let rhs = Datum::try_from(right.clone())?;
130+
arrow_compare(left, right, operator)
131+
}
132+
133+
/// Implementation of `CompareFn` using the Arrow crate.
134+
pub(crate) fn arrow_compare(
135+
lhs: &ArrayData,
136+
rhs: &ArrayData,
137+
operator: Operator,
138+
) -> VortexResult<ArrayData> {
139+
let lhs = Datum::try_from(lhs.clone())?;
140+
let rhs = Datum::try_from(rhs.clone())?;
132141

133142
let array = match operator {
134143
Operator::Eq => cmp::eq(&lhs, &rhs)?,

vortex-array/src/compute/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//! from Arrow.
99
1010
pub use boolean::{and, and_kleene, or, or_kleene, AndFn, OrFn};
11+
pub(crate) use compare::arrow_compare;
1112
pub use compare::{compare, scalar_cmp, CompareFn, MaybeCompareFn, Operator};
1213
pub use filter::{filter, FilterFn};
1314
pub use search_sorted::*;

0 commit comments

Comments
 (0)