Skip to content

Commit d90bbda

Browse files
authored
Move optimizer rules onto Array vtable (#5712)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 63931f0 commit d90bbda

File tree

40 files changed

+938
-787
lines changed

40 files changed

+938
-787
lines changed

encodings/runend/src/array.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ use vortex_scalar::PValue;
4545
use crate::compress::runend_decode_bools;
4646
use crate::compress::runend_decode_primitive;
4747
use crate::compress::runend_encode;
48+
use crate::rules::RULES;
4849

4950
vtable!(RunEnd);
5051

@@ -132,6 +133,14 @@ impl VTable for RunEndVTable {
132133

133134
Ok(())
134135
}
136+
137+
fn reduce_parent(
138+
array: &Self::Array,
139+
parent: &ArrayRef,
140+
child_idx: usize,
141+
) -> VortexResult<Option<ArrayRef>> {
142+
RULES.evaluate(array, parent, child_idx)
143+
}
135144
}
136145

137146
#[derive(Clone, Debug)]

encodings/runend/src/lib.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,14 @@ pub mod _benchmarking {
2424
use vortex_array::ArrayBufferVisitor;
2525
use vortex_array::ArrayChildVisitor;
2626
use vortex_array::Canonical;
27-
use vortex_array::session::ArraySession;
2827
use vortex_array::session::ArraySessionExt;
2928
use vortex_array::vtable::ArrayVTableExt;
3029
use vortex_array::vtable::EncodeVTable;
3130
use vortex_array::vtable::VisitorVTable;
3231
use vortex_error::VortexResult;
33-
use vortex_session::SessionExt;
3432
use vortex_session::VortexSession;
3533

3634
use crate::compress::runend_encode;
37-
use crate::rules::RunEndScalarFnRule;
3835

3936
impl EncodeVTable<RunEndVTable> for RunEndVTable {
4037
fn encode(
@@ -69,10 +66,6 @@ impl VisitorVTable<RunEndVTable> for RunEndVTable {
6966
/// Initialize run-end encoding in the given session.
7067
pub fn initialize(session: &mut VortexSession) {
7168
session.arrays().register(RunEndVTable.as_vtable());
72-
session
73-
.get_mut::<ArraySession>()
74-
.optimizer_mut()
75-
.register_parent_rule(RunEndScalarFnRule);
7669
}
7770

7871
#[cfg(test)]

encodings/runend/src/rules.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,24 @@ use vortex_array::arrays::ConstantArray;
88
use vortex_array::arrays::ConstantVTable;
99
use vortex_array::arrays::ScalarFnArray;
1010
use vortex_array::optimizer::rules::ArrayParentReduceRule;
11-
use vortex_array::optimizer::rules::Exact;
11+
use vortex_array::optimizer::rules::ParentRuleSet;
1212
use vortex_dtype::DType;
1313
use vortex_error::VortexResult;
1414

1515
use crate::RunEndArray;
1616
use crate::RunEndVTable;
1717

18+
pub(super) const RULES: ParentRuleSet<RunEndVTable> =
19+
ParentRuleSet::new(&[ParentRuleSet::lift(&RunEndScalarFnRule)]);
20+
1821
/// A rule to push down scalar functions through run-end encoding into the values array.
1922
///
2023
/// This only works if all other children of the scalar function array are constants.
2124
#[derive(Debug)]
2225
pub(crate) struct RunEndScalarFnRule;
2326

24-
impl ArrayParentReduceRule<Exact<RunEndVTable>, AnyScalarFn> for RunEndScalarFnRule {
25-
fn child(&self) -> Exact<RunEndVTable> {
26-
Exact::from(&RunEndVTable)
27-
}
27+
impl ArrayParentReduceRule<RunEndVTable> for RunEndScalarFnRule {
28+
type Parent = AnyScalarFn;
2829

2930
fn parent(&self) -> AnyScalarFn {
3031
AnyScalarFn

vortex-array/src/array/mod.rs

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ use crate::hash;
5454
use crate::kernel::BindCtx;
5555
use crate::kernel::KernelRef;
5656
use crate::kernel::ValidateKernel;
57+
use crate::optimizer::ArrayOptimizer;
5758
use crate::stats::StatsSetRef;
5859
use crate::vtable::ArrayId;
5960
use crate::vtable::ArrayVTable;
@@ -196,6 +197,12 @@ pub trait Array:
196197

197198
/// Invoke the batch execution function for the array to produce a canonical vector.
198199
fn bind_kernel(&self, ctx: &mut BindCtx) -> VortexResult<KernelRef>;
200+
201+
/// Reduce the array to a more simple representation, if possible.
202+
fn reduce(&self) -> VortexResult<Option<ArrayRef>>;
203+
204+
/// Attempt to perform a reduction of the parent of this array.
205+
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
199206
}
200207

201208
impl Array for Arc<dyn Array> {
@@ -309,6 +316,14 @@ impl Array for Arc<dyn Array> {
309316
fn bind_kernel(&self, ctx: &mut BindCtx) -> VortexResult<KernelRef> {
310317
self.as_ref().bind_kernel(ctx)
311318
}
319+
320+
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
321+
self.as_ref().reduce()
322+
}
323+
324+
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
325+
self.as_ref().reduce_parent(parent, child_idx)
326+
}
312327
}
313328

314329
/// A reference counted pointer to a dynamic [`Array`] trait object.
@@ -512,11 +527,15 @@ impl<V: VTable> Array for ArrayAdapter<V> {
512527

513528
fn filter(&self, mask: Mask) -> VortexResult<ArrayRef> {
514529
vortex_ensure!(self.len() == mask.len(), "Filter mask length mismatch");
515-
Ok(FilterArray::new(self.to_array(), mask).into_array())
530+
FilterArray::new(self.to_array(), mask)
531+
.into_array()
532+
.optimize()
516533
}
517534

518535
fn take(&self, indices: ArrayRef) -> VortexResult<ArrayRef> {
519-
Ok(DictArray::try_new(indices, self.to_array())?.into_array())
536+
DictArray::try_new(indices, self.to_array())?
537+
.into_array()
538+
.optimize()
520539
}
521540

522541
fn scalar_at(&self, index: usize) -> Scalar {
@@ -658,6 +677,42 @@ impl<V: VTable> Array for ArrayAdapter<V> {
658677
Ok(kernel)
659678
}
660679
}
680+
681+
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
682+
let Some(reduced) = V::reduce(&self.0)? else {
683+
return Ok(None);
684+
};
685+
vortex_ensure!(reduced.len() == self.len(), "Reduced array length mismatch");
686+
vortex_ensure!(
687+
reduced.dtype() == self.dtype(),
688+
"Reduced array dtype mismatch"
689+
);
690+
Ok(Some(reduced))
691+
}
692+
693+
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
694+
#[cfg(debug_assertions)]
695+
vortex_ensure!(
696+
Arc::as_ptr(&parent.children()[child_idx]) == self,
697+
"Parent array's child at index {} does not match self",
698+
child_idx
699+
);
700+
701+
let Some(reduced) = V::reduce_parent(&self.0, parent, child_idx)? else {
702+
return Ok(None);
703+
};
704+
705+
vortex_ensure!(
706+
reduced.len() == parent.len(),
707+
"Reduced array length mismatch"
708+
);
709+
vortex_ensure!(
710+
reduced.dtype() == parent.dtype(),
711+
"Reduced array dtype mismatch"
712+
);
713+
714+
Ok(Some(reduced))
715+
}
661716
}
662717

663718
impl<V: VTable> ArrayHash for ArrayAdapter<V> {

vortex-array/src/arrays/bool/vtable/mod.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ use crate::vtable::ValidityVTableFromValidityHelper;
2727
mod array;
2828
mod canonical;
2929
mod operations;
30-
pub mod operator;
30+
pub mod rules;
3131
mod validity;
3232
mod visitor;
3333

34-
pub use operator::BoolMaskedValidityRule;
34+
pub use rules::BoolMaskedValidityRule;
3535

36+
use crate::arrays::bool::vtable::rules::RULES;
3637
use crate::kernel::KernelRef;
3738
use crate::kernel::ready;
3839
use crate::vtable::ArrayId;
@@ -112,12 +113,6 @@ impl VTable for BoolVTable {
112113
BoolArray::try_new(bits, validity)
113114
}
114115

115-
fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult<KernelRef> {
116-
Ok(ready(
117-
BoolVector::new(array.bit_buffer().clone(), array.validity_mask()).into(),
118-
))
119-
}
120-
121116
fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
122117
vortex_ensure!(
123118
children.len() <= 1,
@@ -133,6 +128,20 @@ impl VTable for BoolVTable {
133128

134129
Ok(())
135130
}
131+
132+
fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult<KernelRef> {
133+
Ok(ready(
134+
BoolVector::new(array.bit_buffer().clone(), array.validity_mask()).into(),
135+
))
136+
}
137+
138+
fn reduce_parent(
139+
array: &Self::Array,
140+
parent: &ArrayRef,
141+
child_idx: usize,
142+
) -> VortexResult<Option<ArrayRef>> {
143+
RULES.evaluate(array, parent, child_idx)
144+
}
136145
}
137146

138147
#[derive(Debug)]

vortex-array/src/arrays/bool/vtable/operator.rs renamed to vortex-array/src/arrays/bool/vtable/rules.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,21 @@ use crate::arrays::MaskedArray;
1111
use crate::arrays::MaskedVTable;
1212
use crate::optimizer::rules::ArrayParentReduceRule;
1313
use crate::optimizer::rules::Exact;
14+
use crate::optimizer::rules::ParentRuleSet;
1415
use crate::vtable::ValidityHelper;
1516

17+
pub(super) const RULES: ParentRuleSet<BoolVTable> =
18+
ParentRuleSet::new(&[ParentRuleSet::lift(&BoolMaskedValidityRule)]);
19+
1620
/// Rule to push down validity masking from MaskedArray parent into BoolArray child.
1721
///
1822
/// When a BoolArray is wrapped by a MaskedArray, this rule merges the mask's validity
1923
/// with the BoolArray's existing validity, eliminating the need for the MaskedArray wrapper.
2024
#[derive(Default, Debug)]
2125
pub struct BoolMaskedValidityRule;
2226

23-
impl ArrayParentReduceRule<Exact<BoolVTable>, Exact<MaskedVTable>> for BoolMaskedValidityRule {
24-
fn child(&self) -> Exact<BoolVTable> {
25-
Exact::from(&BoolVTable)
26-
}
27+
impl ArrayParentReduceRule<BoolVTable> for BoolMaskedValidityRule {
28+
type Parent = Exact<MaskedVTable>;
2729

2830
fn parent(&self) -> Exact<MaskedVTable> {
2931
Exact::from(&MaskedVTable)

vortex-array/src/arrays/chunked/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ use vortex_vector::VectorMut;
1616
use vortex_vector::VectorMutOps;
1717

1818
use crate::ArrayRef;
19+
use crate::Canonical;
1920
use crate::EmptyMetadata;
21+
use crate::IntoArray;
2022
use crate::ToCanonical;
2123
use crate::arrays::ChunkedArray;
2224
use crate::arrays::PrimitiveArray;
@@ -175,6 +177,14 @@ impl VTable for ChunkedVTable {
175177
dtype: array.dtype.clone(),
176178
}))
177179
}
180+
181+
fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
182+
Ok(match array.chunks.len() {
183+
0 => Some(Canonical::empty(array.dtype()).into_array()),
184+
1 => Some(array.chunks[0].clone()),
185+
_ => None,
186+
})
187+
}
178188
}
179189

180190
#[derive(Debug)]

vortex-array/src/arrays/decimal/vtable/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@ use crate::vtable::ValidityVTableFromValidityHelper;
3232
mod array;
3333
mod canonical;
3434
mod operations;
35-
pub mod operator;
35+
pub mod rules;
3636
mod validity;
3737
mod visitor;
3838

39-
pub use operator::DecimalMaskedValidityRule;
39+
pub use rules::DecimalMaskedValidityRule;
4040

41+
use crate::arrays::decimal::vtable::rules::RULES;
4142
use crate::kernel::KernelRef;
4243
use crate::kernel::kernel;
4344
use crate::vtable::ArrayId;
@@ -181,6 +182,14 @@ impl VTable for DecimalVTable {
181182
})
182183
})
183184
}
185+
186+
fn reduce_parent(
187+
array: &Self::Array,
188+
parent: &ArrayRef,
189+
child_idx: usize,
190+
) -> VortexResult<Option<ArrayRef>> {
191+
RULES.evaluate(array, parent, child_idx)
192+
}
184193
}
185194

186195
#[derive(Debug)]

vortex-array/src/arrays/decimal/vtable/operator.rs renamed to vortex-array/src/arrays/decimal/vtable/rules.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@ use crate::arrays::MaskedArray;
1212
use crate::arrays::MaskedVTable;
1313
use crate::optimizer::rules::ArrayParentReduceRule;
1414
use crate::optimizer::rules::Exact;
15+
use crate::optimizer::rules::ParentRuleSet;
1516
use crate::vtable::ValidityHelper;
1617

18+
pub(super) static RULES: ParentRuleSet<DecimalVTable> =
19+
ParentRuleSet::new(&[ParentRuleSet::lift(&DecimalMaskedValidityRule)]);
20+
1721
/// Rule to push down validity masking from MaskedArray parent into DecimalArray child.
1822
///
1923
/// When a DecimalArray is wrapped by a MaskedArray, this rule merges the mask's validity
2024
/// with the DecimalArray's existing validity, eliminating the need for the MaskedArray wrapper.
2125
#[derive(Default, Debug)]
2226
pub struct DecimalMaskedValidityRule;
2327

24-
impl ArrayParentReduceRule<Exact<DecimalVTable>, Exact<MaskedVTable>>
25-
for DecimalMaskedValidityRule
26-
{
27-
fn child(&self) -> Exact<DecimalVTable> {
28-
Exact::from(&DecimalVTable)
29-
}
28+
impl ArrayParentReduceRule<DecimalVTable> for DecimalMaskedValidityRule {
29+
type Parent = Exact<MaskedVTable>;
3030

3131
fn parent(&self) -> Exact<MaskedVTable> {
3232
Exact::from(&MaskedVTable)

vortex-array/src/arrays/primitive/vtable/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ use crate::vtable::ValidityVTableFromValidityHelper;
2828
mod array;
2929
mod canonical;
3030
mod operations;
31-
pub mod operator;
31+
pub mod rules;
3232
mod validity;
3333
mod visitor;
3434

35-
pub use operator::PrimitiveMaskedValidityRule;
35+
pub use rules::PrimitiveMaskedValidityRule;
3636

37+
use crate::arrays::primitive::vtable::rules::RULES;
3738
use crate::kernel::KernelRef;
3839
use crate::kernel::ready;
3940
use crate::vtable::ArrayId;
@@ -141,6 +142,14 @@ impl VTable for PrimitiveVTable {
141142

142143
Ok(())
143144
}
145+
146+
fn reduce_parent(
147+
array: &Self::Array,
148+
parent: &ArrayRef,
149+
child_idx: usize,
150+
) -> VortexResult<Option<ArrayRef>> {
151+
RULES.evaluate(array, parent, child_idx)
152+
}
144153
}
145154

146155
#[derive(Debug)]

0 commit comments

Comments
 (0)