Skip to content

Commit f3f663c

Browse files
authored
Port BetweenFn to BetweenKernel (#3146)
1 parent 965891e commit f3f663c

File tree

12 files changed

+477
-378
lines changed

12 files changed

+477
-378
lines changed
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
use std::fmt::Debug;
2+
3+
use vortex_array::arrays::ConstantArray;
4+
use vortex_array::compute::{
5+
BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison, between,
6+
};
7+
use vortex_array::variants::PrimitiveArrayTrait;
8+
use vortex_array::{Array, ArrayRef, register_kernel};
9+
use vortex_dtype::{NativePType, Nullability};
10+
use vortex_error::VortexResult;
11+
use vortex_scalar::{Scalar, ScalarType};
12+
13+
use crate::{ALPArray, ALPEncoding, ALPFloat, match_each_alp_float_ptype};
14+
15+
impl BetweenKernel for ALPEncoding {
16+
fn between(
17+
&self,
18+
array: &ALPArray,
19+
lower: &dyn Array,
20+
upper: &dyn Array,
21+
options: &BetweenOptions,
22+
) -> VortexResult<Option<ArrayRef>> {
23+
let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
24+
return Ok(None);
25+
};
26+
27+
if array.patches().is_some() {
28+
return Ok(None);
29+
}
30+
31+
let nullability =
32+
array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
33+
34+
match_each_alp_float_ptype!(array.ptype(), |$F| {
35+
between_impl::<$F>(array, $F::try_from(lower)?, $F::try_from(upper)?, nullability, options)
36+
})
37+
.map(Some)
38+
}
39+
}
40+
41+
register_kernel!(BetweenKernelAdapter(ALPEncoding).lift());
42+
43+
fn between_impl<T: NativePType + ALPFloat>(
44+
array: &ALPArray,
45+
lower: T,
46+
upper: T,
47+
nullability: Nullability,
48+
options: &BetweenOptions,
49+
) -> VortexResult<ArrayRef>
50+
where
51+
Scalar: From<T::ALPInt>,
52+
<T as ALPFloat>::ALPInt: ScalarType + Debug,
53+
{
54+
let exponents = array.exponents();
55+
56+
// There are always compared
57+
// the below bound is `value {< | <=} x`, either value encodes into the ALPInt domain
58+
// in which case we can leave the comparison unchanged `enc(value) {< | <=} x` or it doesn't
59+
// and we encode into value below enc_below(value) < value < x, in which case the comparison
60+
// becomes enc(value) < x. See `alp_scalar_compare` for more details.
61+
// note that if the value doesn't encode than value != x, so must use strict comparison.
62+
let (lower_enc, lower_strict) = T::encode_single(lower, exponents)
63+
.map(|x| (x, options.lower_strict))
64+
.unwrap_or_else(|| (T::encode_below(lower, exponents), StrictComparison::Strict));
65+
66+
// the upper value `x { < | <= } value` similarly encodes or `x < value < enc_above(value())`
67+
let (upper_enc, upper_strict) = T::encode_single(upper, exponents)
68+
.map(|x| (x, options.upper_strict))
69+
.unwrap_or_else(|| (T::encode_above(upper, exponents), StrictComparison::Strict));
70+
71+
let options = BetweenOptions {
72+
lower_strict,
73+
upper_strict,
74+
};
75+
76+
between(
77+
array.encoded(),
78+
&ConstantArray::new(Scalar::primitive(lower_enc, nullability), array.len()),
79+
&ConstantArray::new(Scalar::primitive(upper_enc, nullability), array.len()),
80+
&options,
81+
)
82+
}
83+
84+
#[cfg(test)]
85+
mod tests {
86+
use itertools::Itertools;
87+
use vortex_array::ToCanonical;
88+
use vortex_array::arrays::PrimitiveArray;
89+
use vortex_array::compute::{BetweenOptions, StrictComparison};
90+
use vortex_dtype::Nullability;
91+
92+
use crate::alp::compute::between::between_impl;
93+
use crate::{ALPArray, alp_encode};
94+
95+
fn between_test(arr: &ALPArray, lower: f32, upper: f32, options: &BetweenOptions) -> bool {
96+
let res = between_impl(arr, lower, upper, Nullability::Nullable, options)
97+
.unwrap()
98+
.to_bool()
99+
.unwrap()
100+
.boolean_buffer()
101+
.iter()
102+
.collect_vec();
103+
assert_eq!(res.len(), 1);
104+
105+
res[0]
106+
}
107+
108+
#[test]
109+
fn comparison_range() {
110+
let value = 0.0605_f32;
111+
let array = PrimitiveArray::from_iter([value; 1]);
112+
let encoded = alp_encode(&array, None).unwrap();
113+
assert!(encoded.patches().is_none());
114+
assert_eq!(
115+
encoded.encoded().to_primitive().unwrap().as_slice::<i32>(),
116+
vec![605; 1]
117+
);
118+
119+
assert!(between_test(
120+
&encoded,
121+
0.0605_f32,
122+
0.0605,
123+
&BetweenOptions {
124+
lower_strict: StrictComparison::NonStrict,
125+
upper_strict: StrictComparison::NonStrict,
126+
},
127+
));
128+
129+
assert!(!between_test(
130+
&encoded,
131+
0.0605_f32,
132+
0.0605,
133+
&BetweenOptions {
134+
lower_strict: StrictComparison::Strict,
135+
upper_strict: StrictComparison::NonStrict,
136+
},
137+
));
138+
139+
assert!(!between_test(
140+
&encoded,
141+
0.0605_f32,
142+
0.0605,
143+
&BetweenOptions {
144+
lower_strict: StrictComparison::NonStrict,
145+
upper_strict: StrictComparison::Strict,
146+
},
147+
));
148+
149+
assert!(between_test(
150+
&encoded,
151+
0.060499_f32,
152+
0.06051,
153+
&BetweenOptions {
154+
lower_strict: StrictComparison::NonStrict,
155+
upper_strict: StrictComparison::NonStrict,
156+
},
157+
));
158+
159+
assert!(between_test(
160+
&encoded,
161+
0.06_f32,
162+
0.06051,
163+
&BetweenOptions {
164+
lower_strict: StrictComparison::NonStrict,
165+
upper_strict: StrictComparison::Strict,
166+
},
167+
))
168+
}
169+
}
Lines changed: 3 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,17 @@
1+
mod between;
12
mod compare;
23
mod filter;
34

4-
use std::fmt::Debug;
5-
6-
use vortex_array::arrays::ConstantArray;
7-
use vortex_array::compute::{
8-
BetweenFn, BetweenOptions, ScalarAtFn, SliceFn, StrictComparison, TakeFn, between, scalar_at,
9-
slice, take,
10-
};
5+
use vortex_array::compute::{ScalarAtFn, SliceFn, TakeFn, scalar_at, slice, take};
116
use vortex_array::variants::PrimitiveArrayTrait;
127
use vortex_array::vtable::ComputeVTable;
138
use vortex_array::{Array, ArrayRef};
14-
use vortex_dtype::{NativePType, Nullability};
159
use vortex_error::VortexResult;
16-
use vortex_scalar::{Scalar, ScalarType};
10+
use vortex_scalar::Scalar;
1711

1812
use crate::{ALPArray, ALPEncoding, ALPFloat, match_each_alp_float_ptype};
1913

2014
impl ComputeVTable for ALPEncoding {
21-
fn between_fn(&self) -> Option<&dyn BetweenFn<&dyn Array>> {
22-
Some(self)
23-
}
24-
2515
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
2616
Some(self)
2717
}
@@ -93,157 +83,3 @@ impl SliceFn<&ALPArray> for ALPEncoding {
9383
.into_array())
9484
}
9585
}
96-
97-
impl BetweenFn<&ALPArray> for ALPEncoding {
98-
fn between(
99-
&self,
100-
array: &ALPArray,
101-
lower: &dyn Array,
102-
upper: &dyn Array,
103-
options: &BetweenOptions,
104-
) -> VortexResult<Option<ArrayRef>> {
105-
let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
106-
return Ok(None);
107-
};
108-
109-
if array.patches().is_some() {
110-
return Ok(None);
111-
}
112-
113-
let nullability =
114-
array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
115-
116-
match_each_alp_float_ptype!(array.ptype(), |$F| {
117-
between_impl::<$F>(array, $F::try_from(lower)?, $F::try_from(upper)?, nullability, options)
118-
})
119-
.map(Some)
120-
}
121-
}
122-
123-
fn between_impl<T: NativePType + ALPFloat>(
124-
array: &ALPArray,
125-
lower: T,
126-
upper: T,
127-
nullability: Nullability,
128-
options: &BetweenOptions,
129-
) -> VortexResult<ArrayRef>
130-
where
131-
Scalar: From<T::ALPInt>,
132-
<T as ALPFloat>::ALPInt: ScalarType + Debug,
133-
{
134-
let exponents = array.exponents();
135-
136-
// There are always compared
137-
// the below bound is `value {< | <=} x`, either value encodes into the ALPInt domain
138-
// in which case we can leave the comparison unchanged `enc(value) {< | <=} x` or it doesn't
139-
// and we encode into value below enc_below(value) < value < x, in which case the comparison
140-
// becomes enc(value) < x. See `alp_scalar_compare` for more details.
141-
// note that if the value doesn't encode than value != x, so must use strict comparison.
142-
let (lower_enc, lower_strict) = T::encode_single(lower, exponents)
143-
.map(|x| (x, options.lower_strict))
144-
.unwrap_or_else(|| (T::encode_below(lower, exponents), StrictComparison::Strict));
145-
146-
// the upper value `x { < | <= } value` similarly encodes or `x < value < enc_above(value())`
147-
let (upper_enc, upper_strict) = T::encode_single(upper, exponents)
148-
.map(|x| (x, options.upper_strict))
149-
.unwrap_or_else(|| (T::encode_above(upper, exponents), StrictComparison::Strict));
150-
151-
let options = BetweenOptions {
152-
lower_strict,
153-
upper_strict,
154-
};
155-
156-
between(
157-
array.encoded(),
158-
&ConstantArray::new(Scalar::primitive(lower_enc, nullability), array.len()),
159-
&ConstantArray::new(Scalar::primitive(upper_enc, nullability), array.len()),
160-
&options,
161-
)
162-
}
163-
164-
#[cfg(test)]
165-
mod tests {
166-
use itertools::Itertools;
167-
use vortex_array::ToCanonical;
168-
use vortex_array::arrays::PrimitiveArray;
169-
use vortex_array::compute::{BetweenOptions, StrictComparison};
170-
use vortex_dtype::Nullability;
171-
172-
use crate::alp::compute::between_impl;
173-
use crate::{ALPArray, alp_encode};
174-
175-
fn between_test(arr: &ALPArray, lower: f32, upper: f32, options: &BetweenOptions) -> bool {
176-
let res = between_impl(arr, lower, upper, Nullability::Nullable, options)
177-
.unwrap()
178-
.to_bool()
179-
.unwrap()
180-
.boolean_buffer()
181-
.iter()
182-
.collect_vec();
183-
assert_eq!(res.len(), 1);
184-
185-
res[0]
186-
}
187-
188-
#[test]
189-
fn comparison_range() {
190-
let value = 0.0605_f32;
191-
let array = PrimitiveArray::from_iter([value; 1]);
192-
let encoded = alp_encode(&array, None).unwrap();
193-
assert!(encoded.patches().is_none());
194-
assert_eq!(
195-
encoded.encoded().to_primitive().unwrap().as_slice::<i32>(),
196-
vec![605; 1]
197-
);
198-
199-
assert!(between_test(
200-
&encoded,
201-
0.0605_f32,
202-
0.0605,
203-
&BetweenOptions {
204-
lower_strict: StrictComparison::NonStrict,
205-
upper_strict: StrictComparison::NonStrict,
206-
},
207-
));
208-
209-
assert!(!between_test(
210-
&encoded,
211-
0.0605_f32,
212-
0.0605,
213-
&BetweenOptions {
214-
lower_strict: StrictComparison::Strict,
215-
upper_strict: StrictComparison::NonStrict,
216-
},
217-
));
218-
219-
assert!(!between_test(
220-
&encoded,
221-
0.0605_f32,
222-
0.0605,
223-
&BetweenOptions {
224-
lower_strict: StrictComparison::NonStrict,
225-
upper_strict: StrictComparison::Strict,
226-
},
227-
));
228-
229-
assert!(between_test(
230-
&encoded,
231-
0.060499_f32,
232-
0.06051,
233-
&BetweenOptions {
234-
lower_strict: StrictComparison::NonStrict,
235-
upper_strict: StrictComparison::NonStrict,
236-
},
237-
));
238-
239-
assert!(between_test(
240-
&encoded,
241-
0.06_f32,
242-
0.06051,
243-
&BetweenOptions {
244-
lower_strict: StrictComparison::NonStrict,
245-
upper_strict: StrictComparison::Strict,
246-
},
247-
))
248-
}
249-
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use vortex_array::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, between};
2+
use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
3+
use vortex_error::VortexResult;
4+
5+
use crate::{BitPackedArray, BitPackedEncoding};
6+
7+
impl BetweenKernel for BitPackedEncoding {
8+
fn between(
9+
&self,
10+
array: &BitPackedArray,
11+
lower: &dyn Array,
12+
upper: &dyn Array,
13+
options: &BetweenOptions,
14+
) -> VortexResult<Option<ArrayRef>> {
15+
if !lower.is_constant() || !upper.is_constant() {
16+
return Ok(None);
17+
};
18+
19+
between(
20+
&array.clone().to_canonical()?.into_array(),
21+
lower,
22+
upper,
23+
options,
24+
)
25+
.map(Some)
26+
}
27+
}
28+
29+
register_kernel!(BetweenKernelAdapter(BitPackedEncoding).lift());

0 commit comments

Comments
 (0)