Skip to content

Commit c93f06e

Browse files
authored
Execute parent (#5736)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent a9b9d2a commit c93f06e

File tree

33 files changed

+539
-138
lines changed

33 files changed

+539
-138
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/alp/src/alp/array.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use vortex_array::ExecutionCtx;
1616
use vortex_array::Precision;
1717
use vortex_array::ProstMetadata;
1818
use vortex_array::SerializeMetadata;
19+
use vortex_array::VectorExecutor;
1920
use vortex_array::patches::Patches;
2021
use vortex_array::patches::PatchesMetadata;
2122
use vortex_array::serde::ArrayChildren;

encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
mod compute;
5+
mod rules;
56

67
use std::hash::Hash;
78
use std::ops::Range;
@@ -48,6 +49,8 @@ use vortex_error::vortex_ensure;
4849
use vortex_scalar::DecimalValue;
4950
use vortex_scalar::Scalar;
5051

52+
use crate::decimal_byte_parts::rules::PARENT_RULES;
53+
5154
vtable!(DecimalByteParts);
5255

5356
#[derive(Clone, prost::Message)]
@@ -127,6 +130,14 @@ impl VTable for DecimalBytePartsVTable {
127130
array.msp = children.into_iter().next().vortex_expect("checked");
128131
Ok(())
129132
}
133+
134+
fn reduce_parent(
135+
array: &Self::Array,
136+
parent: &ArrayRef,
137+
child_idx: usize,
138+
) -> VortexResult<Option<ArrayRef>> {
139+
PARENT_RULES.evaluate(array, parent, child_idx)
140+
}
130141
}
131142

132143
/// This array encodes decimals as between 1-4 columns of primitive typed children.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_array::ArrayRef;
5+
use vortex_array::IntoArray;
6+
use vortex_array::arrays::FilterArray;
7+
use vortex_array::arrays::FilterVTable;
8+
use vortex_array::matchers::Exact;
9+
use vortex_array::optimizer::rules::ArrayParentReduceRule;
10+
use vortex_array::optimizer::rules::ParentRuleSet;
11+
use vortex_error::VortexResult;
12+
13+
use crate::DecimalBytePartsArray;
14+
use crate::DecimalBytePartsVTable;
15+
16+
pub(super) const PARENT_RULES: ParentRuleSet<DecimalBytePartsVTable> =
17+
ParentRuleSet::new(&[ParentRuleSet::lift(&DecimalBytePartsFilterPushDownRule)]);
18+
19+
#[derive(Debug)]
20+
struct DecimalBytePartsFilterPushDownRule;
21+
22+
impl ArrayParentReduceRule<DecimalBytePartsVTable> for DecimalBytePartsFilterPushDownRule {
23+
type Parent = Exact<FilterVTable>;
24+
25+
fn parent(&self) -> Self::Parent {
26+
Exact::from(&FilterVTable)
27+
}
28+
29+
fn reduce_parent(
30+
&self,
31+
child: &DecimalBytePartsArray,
32+
parent: &FilterArray,
33+
_child_idx: usize,
34+
) -> VortexResult<Option<ArrayRef>> {
35+
// TODO(ngates): we should benchmark whether to push-down filters with "lower parts".
36+
// For now, we only push down if there are no lower parts.
37+
if !child._lower_parts.is_empty() {
38+
return Ok(None);
39+
}
40+
41+
let new_msp =
42+
FilterArray::new(child.msp.clone(), parent.filter_mask().clone()).into_array();
43+
let new_child =
44+
DecimalBytePartsArray::try_new(new_msp, *child.decimal_dtype())?.into_array();
45+
Ok(Some(new_child))
46+
}
47+
}

encodings/fastlanes/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ prost = { workspace = true }
2626
rand = { workspace = true, optional = true }
2727
vortex-array = { workspace = true }
2828
vortex-buffer = { workspace = true }
29+
vortex-compute = { workspace = true }
2930
vortex-dtype = { workspace = true }
3031
vortex-error = { workspace = true }
3132
vortex-mask = { workspace = true }

encodings/fastlanes/src/delta/vtable/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use prost::Message;
99
use vortex_array::ArrayRef;
1010
use vortex_array::ExecutionCtx;
1111
use vortex_array::ProstMetadata;
12+
use vortex_array::VectorExecutor;
1213
use vortex_array::serde::ArrayChildren;
1314
use vortex_array::vtable;
1415
use vortex_array::vtable::ArrayId;

encodings/fastlanes/src/for/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ use vortex_scalar::Scalar;
2424
use vortex_scalar::ScalarValue;
2525

2626
use crate::FoRArray;
27+
use crate::r#for::vtable::rules::PARENT_RULES;
2728

2829
mod array;
2930
mod canonical;
3031
mod encode;
3132
mod operations;
33+
mod rules;
3234
mod validity;
3335
mod visitor;
3436

@@ -104,6 +106,14 @@ impl VTable for FoRVTable {
104106

105107
FoRArray::try_new(encoded, reference)
106108
}
109+
110+
fn reduce_parent(
111+
array: &Self::Array,
112+
parent: &ArrayRef,
113+
child_idx: usize,
114+
) -> VortexResult<Option<ArrayRef>> {
115+
PARENT_RULES.evaluate(array, parent, child_idx)
116+
}
107117
}
108118

109119
#[derive(Debug)]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_array::ArrayRef;
5+
use vortex_array::IntoArray;
6+
use vortex_array::arrays::FilterArray;
7+
use vortex_array::arrays::FilterVTable;
8+
use vortex_array::matchers::Exact;
9+
use vortex_array::optimizer::rules::ArrayParentReduceRule;
10+
use vortex_array::optimizer::rules::ParentRuleSet;
11+
use vortex_error::VortexResult;
12+
13+
use crate::FoRArray;
14+
use crate::FoRVTable;
15+
16+
pub(super) const PARENT_RULES: ParentRuleSet<FoRVTable> =
17+
ParentRuleSet::new(&[ParentRuleSet::lift(&FoRFilterPushDownRule)]);
18+
19+
#[derive(Debug)]
20+
struct FoRFilterPushDownRule;
21+
22+
impl ArrayParentReduceRule<FoRVTable> for FoRFilterPushDownRule {
23+
type Parent = Exact<FilterVTable>;
24+
25+
fn parent(&self) -> Self::Parent {
26+
Exact::from(&FilterVTable)
27+
}
28+
29+
fn reduce_parent(
30+
&self,
31+
child: &FoRArray,
32+
parent: &FilterArray,
33+
_child_idx: usize,
34+
) -> VortexResult<Option<ArrayRef>> {
35+
let new_array = unsafe {
36+
FoRArray::new_unchecked(
37+
FilterArray::new(child.encoded().clone(), parent.filter_mask().clone())
38+
.into_array(),
39+
child.reference.clone(),
40+
)
41+
};
42+
Ok(Some(new_array.into_array()))
43+
}
44+
}

encodings/sequence/src/array.rs

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
use std::hash::Hash;
55
use std::ops::Range;
66

7-
use num_traits::One;
87
use num_traits::cast::FromPrimitive;
98
use vortex_array::ArrayBufferVisitor;
109
use vortex_array::ArrayChildVisitor;
@@ -15,6 +14,7 @@ use vortex_array::ExecutionCtx;
1514
use vortex_array::Precision;
1615
use vortex_array::ProstMetadata;
1716
use vortex_array::SerializeMetadata;
17+
use vortex_array::arrays::FilterVTable;
1818
use vortex_array::arrays::PrimitiveArray;
1919
use vortex_array::serde::ArrayChildren;
2020
use vortex_array::stats::ArrayStats;
@@ -45,11 +45,14 @@ use vortex_error::VortexResult;
4545
use vortex_error::vortex_bail;
4646
use vortex_error::vortex_ensure;
4747
use vortex_error::vortex_err;
48+
use vortex_mask::AllOr;
4849
use vortex_mask::Mask;
4950
use vortex_scalar::PValue;
5051
use vortex_scalar::Scalar;
5152
use vortex_scalar::ScalarValue;
5253
use vortex_vector::Vector;
54+
use vortex_vector::VectorMut;
55+
use vortex_vector::VectorMutOps;
5356
use vortex_vector::primitive::PVector;
5457

5558
vtable!(Sequence);
@@ -285,20 +288,48 @@ impl VTable for SequenceVTable {
285288
let base = array.base().cast::<P>();
286289
let multiplier = array.multiplier().cast::<P>();
287290

288-
let values = if multiplier == <P>::one() {
289-
BufferMut::from_iter(
290-
(0..array.len()).map(|i| base + <P>::from_usize(i).vortex_expect("must fit")),
291-
)
292-
} else {
293-
BufferMut::from_iter(
294-
(0..array.len())
295-
.map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
296-
)
297-
};
298-
299-
PVector::<P>::new(values.freeze(), Mask::new_true(array.len())).into()
291+
execute_iter(base, multiplier, 0..array.len(), array.len()).into()
300292
}))
301293
}
294+
295+
fn execute_parent(
296+
array: &Self::Array,
297+
parent: &ArrayRef,
298+
_child_idx: usize,
299+
_ctx: &mut ExecutionCtx,
300+
) -> VortexResult<Option<Vector>> {
301+
// Special-case filtered execution.
302+
let Some(filter) = parent.as_opt::<FilterVTable>() else {
303+
return Ok(None);
304+
};
305+
306+
match filter.filter_mask().indices() {
307+
AllOr::All => Ok(None),
308+
AllOr::None => Ok(Some(VectorMut::with_capacity(array.dtype(), 0).freeze())),
309+
AllOr::Some(indices) => Ok(Some(match_each_native_ptype!(array.ptype(), |P| {
310+
let base = array.base().cast::<P>();
311+
let multiplier = array.multiplier().cast::<P>();
312+
execute_iter(base, multiplier, indices.iter().copied(), indices.len()).into()
313+
}))),
314+
}
315+
}
316+
}
317+
318+
fn execute_iter<P: NativePType, I: Iterator<Item = usize>>(
319+
base: P,
320+
multiplier: P,
321+
iter: I,
322+
len: usize,
323+
) -> PVector<P> {
324+
let values = if multiplier == <P>::one() {
325+
BufferMut::from_iter(iter.map(|i| base + <P>::from_usize(i).vortex_expect("must fit")))
326+
} else {
327+
BufferMut::from_iter(
328+
iter.map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
329+
)
330+
};
331+
332+
PVector::<P>::new(values.freeze(), Mask::new_true(len))
302333
}
303334

304335
impl BaseArrayVTable<SequenceVTable> for SequenceVTable {

vortex-array/src/array/mod.rs

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ use vortex_error::vortex_err;
2323
use vortex_error::vortex_panic;
2424
use vortex_mask::Mask;
2525
use vortex_scalar::Scalar;
26-
use vortex_vector::Vector;
27-
use vortex_vector::VectorOps;
28-
use vortex_vector::vector_matches_dtype;
2926

3027
use crate::ArrayEq;
3128
use crate::ArrayHash;
@@ -52,7 +49,6 @@ use crate::compute::InvocationArgs;
5249
use crate::compute::IsConstantOpts;
5350
use crate::compute::Output;
5451
use crate::compute::is_constant_opts;
55-
use crate::executor::ExecutionCtx;
5652
use crate::expr::stats::Precision;
5753
use crate::expr::stats::Stat;
5854
use crate::expr::stats::StatsProviderExt;
@@ -200,11 +196,6 @@ pub trait Array:
200196
/// call.
201197
fn invoke(&self, compute_fn: &ComputeFn, args: &InvocationArgs)
202198
-> VortexResult<Option<Output>>;
203-
204-
/// Recursively execute an array using batch CPU execution.
205-
///
206-
/// To invoke the top-level execution, see [`crate::executor::VectorExecutor`].
207-
fn execute(&self, ctx: &mut ExecutionCtx) -> VortexResult<Vector>;
208199
}
209200

210201
impl Array for Arc<dyn Array> {
@@ -318,10 +309,6 @@ impl Array for Arc<dyn Array> {
318309
) -> VortexResult<Option<Output>> {
319310
self.as_ref().invoke(compute_fn, args)
320311
}
321-
322-
fn execute(&self, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
323-
self.as_ref().execute(ctx)
324-
}
325312
}
326313

327314
/// A reference counted pointer to a dynamic [`Array`] trait object.
@@ -679,25 +666,6 @@ impl<V: VTable> Array for ArrayAdapter<V> {
679666
) -> VortexResult<Option<Output>> {
680667
<V::ComputeVTable as ComputeVTable<V>>::invoke(&self.0, compute_fn, args)
681668
}
682-
683-
fn execute(&self, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
684-
let result = V::execute(&self.0, ctx)?;
685-
686-
if cfg!(debug_assertions) {
687-
vortex_ensure!(
688-
result.len() == self.len(),
689-
"Result length mismatch for {}",
690-
self.encoding_id()
691-
);
692-
vortex_ensure!(
693-
vector_matches_dtype(&result, self.dtype()),
694-
"Executed vector dtype mismatch for {}",
695-
self.encoding_id()
696-
);
697-
}
698-
699-
Ok(result)
700-
}
701669
}
702670

703671
impl<V: VTable> ArrayHash for ArrayAdapter<V> {

0 commit comments

Comments
 (0)