Skip to content

Commit 7c0b68f

Browse files
committed
Move rules onto array vtable so we can optimize during construction without a session
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 96dce46 commit 7c0b68f

File tree

29 files changed

+344
-618
lines changed

29 files changed

+344
-618
lines changed

encodings/runend/src/array.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ use vortex_scalar::PValue;
4545
use crate::compress::runend_decode_bools;
4646
use crate::compress::runend_decode_primitive;
4747
use crate::compress::runend_encode;
48+
use crate::rules::RULES;
4849

4950
vtable!(RunEnd);
5051

@@ -132,6 +133,14 @@ impl VTable for RunEndVTable {
132133

133134
Ok(())
134135
}
136+
137+
fn reduce_parent(
138+
array: &Self::Array,
139+
parent: &ArrayRef,
140+
child_idx: usize,
141+
) -> VortexResult<Option<ArrayRef>> {
142+
RULES.evaluate(array, parent, child_idx)
143+
}
135144
}
136145

137146
#[derive(Clone, Debug)]

encodings/runend/src/lib.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,14 @@ pub mod _benchmarking {
2424
use vortex_array::ArrayBufferVisitor;
2525
use vortex_array::ArrayChildVisitor;
2626
use vortex_array::Canonical;
27-
use vortex_array::session::ArraySession;
2827
use vortex_array::session::ArraySessionExt;
2928
use vortex_array::vtable::ArrayVTableExt;
3029
use vortex_array::vtable::EncodeVTable;
3130
use vortex_array::vtable::VisitorVTable;
3231
use vortex_error::VortexResult;
33-
use vortex_session::SessionExt;
3432
use vortex_session::VortexSession;
3533

3634
use crate::compress::runend_encode;
37-
use crate::rules::RunEndScalarFnRule;
3835

3936
impl EncodeVTable<RunEndVTable> for RunEndVTable {
4037
fn encode(
@@ -69,10 +66,6 @@ impl VisitorVTable<RunEndVTable> for RunEndVTable {
6966
/// Initialize run-end encoding in the given session.
7067
pub fn initialize(session: &mut VortexSession) {
7168
session.arrays().register(RunEndVTable.as_vtable());
72-
session
73-
.get_mut::<ArraySession>()
74-
.optimizer_mut()
75-
.register_parent_rule(RunEndScalarFnRule);
7669
}
7770

7871
#[cfg(test)]

encodings/runend/src/rules.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,24 @@ use vortex_array::arrays::ConstantArray;
88
use vortex_array::arrays::ConstantVTable;
99
use vortex_array::arrays::ScalarFnArray;
1010
use vortex_array::optimizer::rules::ArrayParentReduceRule;
11-
use vortex_array::optimizer::rules::Exact;
11+
use vortex_array::optimizer::rules::ParentRuleSet;
1212
use vortex_dtype::DType;
1313
use vortex_error::VortexResult;
1414

1515
use crate::RunEndArray;
1616
use crate::RunEndVTable;
1717

18+
pub(super) const RULES: ParentRuleSet<RunEndVTable> =
19+
ParentRuleSet::new(&[ParentRuleSet::lift(&RunEndScalarFnRule)]);
20+
1821
/// A rule to push down scalar functions through run-end encoding into the values array.
1922
///
2023
/// This only works if all other children of the scalar function array are constants.
2124
#[derive(Debug)]
2225
pub(crate) struct RunEndScalarFnRule;
2326

24-
impl ArrayParentReduceRule<Exact<RunEndVTable>, AnyScalarFn> for RunEndScalarFnRule {
25-
fn child(&self) -> Exact<RunEndVTable> {
26-
Exact::from(&RunEndVTable)
27-
}
27+
impl ArrayParentReduceRule<RunEndVTable> for RunEndScalarFnRule {
28+
type Parent = AnyScalarFn;
2829

2930
fn parent(&self) -> AnyScalarFn {
3031
AnyScalarFn

vortex-array/src/array/mod.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ pub trait Array:
196196

197197
/// Invoke the batch execution function for the array to produce a canonical vector.
198198
fn bind_kernel(&self, ctx: &mut BindCtx) -> VortexResult<KernelRef>;
199+
200+
/// Reduce the array to a simpler representation, if possible.
201+
fn reduce(&self) -> VortexResult<Option<ArrayRef>>;
202+
203+
/// Attempt to reduce the parent of this array, where this array is the parent's child at `child_idx`.
204+
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
199205
}
200206

201207
impl Array for Arc<dyn Array> {
@@ -309,6 +315,14 @@ impl Array for Arc<dyn Array> {
309315
fn bind_kernel(&self, ctx: &mut BindCtx) -> VortexResult<KernelRef> {
310316
self.as_ref().bind_kernel(ctx)
311317
}
318+
319+
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
320+
self.as_ref().reduce()
321+
}
322+
323+
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
324+
self.as_ref().reduce_parent(parent, child_idx)
325+
}
312326
}
313327

314328
/// A reference counted pointer to a dynamic [`Array`] trait object.
@@ -658,6 +672,42 @@ impl<V: VTable> Array for ArrayAdapter<V> {
658672
Ok(kernel)
659673
}
660674
}
675+
676+
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
677+
let Some(reduced) = V::reduce(&self.0)? else {
678+
return Ok(None);
679+
};
680+
vortex_ensure!(reduced.len() == self.len(), "Reduced array length mismatch");
681+
vortex_ensure!(
682+
reduced.dtype() == self.dtype(),
683+
"Reduced array dtype mismatch"
684+
);
685+
Ok(Some(reduced))
686+
}
687+
688+
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
689+
#[cfg(debug_assertions)]
690+
vortex_ensure!(
691+
Arc::as_ptr(&parent.children()[child_idx]) == self,
692+
"Parent array's child at index {} does not match self",
693+
child_idx
694+
);
695+
696+
let Some(reduced) = V::reduce_parent(&self.0, parent, child_idx)? else {
697+
return Ok(None);
698+
};
699+
700+
vortex_ensure!(
701+
reduced.len() == parent.len(),
702+
"Reduced array length mismatch"
703+
);
704+
vortex_ensure!(
705+
reduced.dtype() == parent.dtype(),
706+
"Reduced array dtype mismatch"
707+
);
708+
709+
Ok(Some(reduced))
710+
}
661711
}
662712

663713
impl<V: VTable> ArrayHash for ArrayAdapter<V> {

vortex-array/src/arrays/bool/vtable/mod.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ use crate::vtable::ValidityVTableFromValidityHelper;
2727
mod array;
2828
mod canonical;
2929
mod operations;
30-
pub mod operator;
30+
pub mod rules;
3131
mod validity;
3232
mod visitor;
3333

34-
pub use operator::BoolMaskedValidityRule;
34+
pub use rules::BoolMaskedValidityRule;
3535

36+
use crate::arrays::bool::vtable::rules::RULES;
3637
use crate::kernel::KernelRef;
3738
use crate::kernel::ready;
3839
use crate::vtable::ArrayId;
@@ -112,12 +113,6 @@ impl VTable for BoolVTable {
112113
BoolArray::try_new(bits, validity)
113114
}
114115

115-
fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult<KernelRef> {
116-
Ok(ready(
117-
BoolVector::new(array.bit_buffer().clone(), array.validity_mask()).into(),
118-
))
119-
}
120-
121116
fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
122117
vortex_ensure!(
123118
children.len() <= 1,
@@ -133,6 +128,20 @@ impl VTable for BoolVTable {
133128

134129
Ok(())
135130
}
131+
132+
fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult<KernelRef> {
133+
Ok(ready(
134+
BoolVector::new(array.bit_buffer().clone(), array.validity_mask()).into(),
135+
))
136+
}
137+
138+
fn reduce_parent(
139+
array: &Self::Array,
140+
parent: &ArrayRef,
141+
child_idx: usize,
142+
) -> VortexResult<Option<ArrayRef>> {
143+
RULES.evaluate(array, parent, child_idx)
144+
}
136145
}
137146

138147
#[derive(Debug)]

vortex-array/src/arrays/bool/vtable/operator.rs renamed to vortex-array/src/arrays/bool/vtable/rules.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,21 @@ use crate::arrays::MaskedArray;
1111
use crate::arrays::MaskedVTable;
1212
use crate::optimizer::rules::ArrayParentReduceRule;
1313
use crate::optimizer::rules::Exact;
14+
use crate::optimizer::rules::ParentRuleSet;
1415
use crate::vtable::ValidityHelper;
1516

17+
pub(super) const RULES: ParentRuleSet<BoolVTable> =
18+
ParentRuleSet::new(&[ParentRuleSet::lift(&BoolMaskedValidityRule)]);
19+
1620
/// Rule to push down validity masking from MaskedArray parent into BoolArray child.
1721
///
1822
/// When a BoolArray is wrapped by a MaskedArray, this rule merges the mask's validity
1923
/// with the BoolArray's existing validity, eliminating the need for the MaskedArray wrapper.
2024
#[derive(Default, Debug)]
2125
pub struct BoolMaskedValidityRule;
2226

23-
impl ArrayParentReduceRule<Exact<BoolVTable>, Exact<MaskedVTable>> for BoolMaskedValidityRule {
24-
fn child(&self) -> Exact<BoolVTable> {
25-
Exact::from(&BoolVTable)
26-
}
27+
impl ArrayParentReduceRule<BoolVTable> for BoolMaskedValidityRule {
28+
type Parent = Exact<MaskedVTable>;
2729

2830
fn parent(&self) -> Exact<MaskedVTable> {
2931
Exact::from(&MaskedVTable)

vortex-array/src/arrays/chunked/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ use vortex_vector::VectorMut;
1616
use vortex_vector::VectorMutOps;
1717

1818
use crate::ArrayRef;
19+
use crate::Canonical;
1920
use crate::EmptyMetadata;
21+
use crate::IntoArray;
2022
use crate::ToCanonical;
2123
use crate::arrays::ChunkedArray;
2224
use crate::arrays::PrimitiveArray;
@@ -175,6 +177,14 @@ impl VTable for ChunkedVTable {
175177
dtype: array.dtype.clone(),
176178
}))
177179
}
180+
181+
fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
182+
Ok(match array.chunks.len() {
183+
0 => Some(Canonical::empty(array.dtype()).into_array()),
184+
1 => Some(array.chunks[0].clone()),
185+
_ => None,
186+
})
187+
}
178188
}
179189

180190
#[derive(Debug)]

vortex-array/src/arrays/decimal/vtable/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@ use crate::vtable::ValidityVTableFromValidityHelper;
3232
mod array;
3333
mod canonical;
3434
mod operations;
35-
pub mod operator;
35+
pub mod rules;
3636
mod validity;
3737
mod visitor;
3838

39-
pub use operator::DecimalMaskedValidityRule;
39+
pub use rules::DecimalMaskedValidityRule;
4040

41+
use crate::arrays::decimal::vtable::rules::RULES;
4142
use crate::kernel::KernelRef;
4243
use crate::kernel::kernel;
4344
use crate::vtable::ArrayId;
@@ -181,6 +182,14 @@ impl VTable for DecimalVTable {
181182
})
182183
})
183184
}
185+
186+
fn reduce_parent(
187+
array: &Self::Array,
188+
parent: &ArrayRef,
189+
child_idx: usize,
190+
) -> VortexResult<Option<ArrayRef>> {
191+
RULES.evaluate(array, parent, child_idx)
192+
}
184193
}
185194

186195
#[derive(Debug)]

vortex-array/src/arrays/decimal/vtable/operator.rs renamed to vortex-array/src/arrays/decimal/vtable/rules.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@ use crate::arrays::MaskedArray;
1212
use crate::arrays::MaskedVTable;
1313
use crate::optimizer::rules::ArrayParentReduceRule;
1414
use crate::optimizer::rules::Exact;
15+
use crate::optimizer::rules::ParentRuleSet;
1516
use crate::vtable::ValidityHelper;
1617

18+
pub(super) static RULES: ParentRuleSet<DecimalVTable> =
19+
ParentRuleSet::new(&[ParentRuleSet::lift(&DecimalMaskedValidityRule)]);
20+
1721
/// Rule to push down validity masking from MaskedArray parent into DecimalArray child.
1822
///
1923
/// When a DecimalArray is wrapped by a MaskedArray, this rule merges the mask's validity
2024
/// with the DecimalArray's existing validity, eliminating the need for the MaskedArray wrapper.
2125
#[derive(Default, Debug)]
2226
pub struct DecimalMaskedValidityRule;
2327

24-
impl ArrayParentReduceRule<Exact<DecimalVTable>, Exact<MaskedVTable>>
25-
for DecimalMaskedValidityRule
26-
{
27-
fn child(&self) -> Exact<DecimalVTable> {
28-
Exact::from(&DecimalVTable)
29-
}
28+
impl ArrayParentReduceRule<DecimalVTable> for DecimalMaskedValidityRule {
29+
type Parent = Exact<MaskedVTable>;
3030

3131
fn parent(&self) -> Exact<MaskedVTable> {
3232
Exact::from(&MaskedVTable)

vortex-array/src/arrays/primitive/vtable/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ use crate::vtable::ValidityVTableFromValidityHelper;
2828
mod array;
2929
mod canonical;
3030
mod operations;
31-
pub mod operator;
31+
pub mod rules;
3232
mod validity;
3333
mod visitor;
3434

35-
pub use operator::PrimitiveMaskedValidityRule;
35+
pub use rules::PrimitiveMaskedValidityRule;
3636

37+
use crate::arrays::primitive::vtable::rules::RULES;
3738
use crate::kernel::KernelRef;
3839
use crate::kernel::ready;
3940
use crate::vtable::ArrayId;
@@ -141,6 +142,14 @@ impl VTable for PrimitiveVTable {
141142

142143
Ok(())
143144
}
145+
146+
fn reduce_parent(
147+
array: &Self::Array,
148+
parent: &ArrayRef,
149+
child_idx: usize,
150+
) -> VortexResult<Option<ArrayRef>> {
151+
RULES.evaluate(array, parent, child_idx)
152+
}
144153
}
145154

146155
#[derive(Debug)]

0 commit comments

Comments
 (0)