Skip to content

Commit ab3bf26

Browse files
committed
u
Signed-off-by: Joe Isaacs <[email protected]>
1 parent a740b60 commit ab3bf26

File tree

11 files changed

+91
-138
lines changed

11 files changed

+91
-138
lines changed

vortex-array/src/array/operator.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,6 @@ pub trait ArrayOperator: 'static + Send + Sync {
2626
/// If the array's implementation returns an invalid vector (wrong length, wrong type, etc.).
2727
fn execute_batch(&self, selection: &Mask, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector>;
2828

29-
/// Optimize the array by running the optimization rules.
30-
fn reduce(&self) -> VortexResult<Option<ArrayRef>>;
31-
32-
/// Optimize the array by pushing down a parent array.
33-
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
34-
3529
/// Returns the array as a pipeline node, if supported.
3630
fn as_pipelined(&self) -> Option<&dyn PipelinedNode>;
3731

@@ -48,14 +42,6 @@ impl ArrayOperator for Arc<dyn Array> {
4842
self.as_ref().execute_batch(selection, ctx)
4943
}
5044

51-
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
52-
self.as_ref().reduce()
53-
}
54-
55-
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
56-
self.as_ref().reduce_parent(parent, child_idx)
57-
}
58-
5945
fn as_pipelined(&self) -> Option<&dyn PipelinedNode> {
6046
self.as_ref().as_pipelined()
6147
}
@@ -96,14 +82,6 @@ impl<V: VTable> ArrayOperator for ArrayAdapter<V> {
9682
Ok(vector)
9783
}
9884

99-
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
100-
<V::OperatorVTable as OperatorVTable<V>>::reduce(&self.0)
101-
}
102-
103-
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
104-
<V::OperatorVTable as OperatorVTable<V>>::reduce_parent(&self.0, parent, child_idx)
105-
}
106-
10785
fn as_pipelined(&self) -> Option<&dyn PipelinedNode> {
10886
<V::OperatorVTable as OperatorVTable<V>>::pipeline_node(&self.0)
10987
}

vortex-array/src/array/transform/optimizer.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ impl ArrayOptimizer {
7979
// Now try to apply parent rules to each optimized child in the context of this array
8080
// Use the optimized_children list directly instead of re-fetching from array.children()
8181
// let mut transformed_children = Vec::with_capacity(optimized_children.len());
82-
let rules_applied = false;
8382

8483
for (idx, child) in optimized_children.iter().enumerate() {
8584
let child_id = child.encoding_id();

vortex-array/src/arrays/decimal/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub use array::DecimalArray;
77
mod compute;
88

99
mod vtable;
10-
pub use vtable::{DecimalEncoding, DecimalVTable};
10+
pub use vtable::{DecimalEncoding, DecimalMaskedValidityRule, DecimalVTable};
1111

1212
mod utils;
1313
pub use utils::*;

vortex-array/src/arrays/decimal/vtable/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ use crate::{
1717
mod array;
1818
mod canonical;
1919
mod operations;
20-
mod operator;
20+
pub mod operator;
2121
mod validity;
2222
mod visitor;
2323

24+
pub use operator::DecimalMaskedValidityRule;
25+
2426
vtable!(Decimal);
2527

2628
// The type of the values can be determined by looking at the type info...right?
@@ -42,7 +44,7 @@ impl VTable for DecimalVTable {
4244
type VisitorVTable = Self;
4345
type ComputeVTable = NotSupported;
4446
type EncodeVTable = NotSupported;
45-
type OperatorVTable = NotSupported;
47+
type OperatorVTable = Self;
4648

4749
fn id(_encoding: &Self::Encoding) -> EncodingId {
4850
EncodingId::new_ref("vortex.decimal")

vortex-array/src/arrays/decimal/vtable/operator.rs

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use vortex_dtype::{PrecisionScale, match_each_decimal_value_type};
66
use vortex_error::VortexResult;
77
use vortex_vector::decimal::DVector;
88

9-
use crate::arrays::{DecimalArray, DecimalVTable, MaskedVTable};
9+
use crate::array::transform::{ArrayParentReduceRule, ArrayRuleContext};
10+
use crate::arrays::{DecimalArray, DecimalVTable, MaskedArray, MaskedVTable};
1011
use crate::execution::{BatchKernelRef, BindCtx, kernel};
1112
use crate::vtable::{OperatorVTable, ValidityHelper};
1213
use crate::{ArrayRef, IntoArray};
@@ -36,30 +37,37 @@ impl OperatorVTable<DecimalVTable> for DecimalVTable {
3637
}))
3738
})
3839
}
40+
}
41+
42+
/// Rule to push down validity masking from MaskedArray parent into DecimalArray child.
43+
///
44+
/// When a DecimalArray is wrapped by a MaskedArray, this rule merges the mask's validity
45+
/// with the DecimalArray's existing validity, eliminating the need for the MaskedArray wrapper.
46+
pub struct DecimalMaskedValidityRule;
3947

48+
impl ArrayParentReduceRule<DecimalVTable, MaskedVTable> for DecimalMaskedValidityRule {
4049
fn reduce_parent(
50+
&self,
4151
array: &DecimalArray,
42-
parent: &ArrayRef,
52+
parent: &MaskedArray,
4353
_child_idx: usize,
54+
_ctx: &ArrayRuleContext,
4455
) -> VortexResult<Option<ArrayRef>> {
45-
// Push-down masking of `validity` from the parent `MaskedArray`.
46-
if let Some(masked) = parent.as_opt::<MaskedVTable>() {
47-
let masked_array = match_each_decimal_value_type!(array.values_type(), |D| {
48-
// SAFETY: Since we are only flipping some bits in the validity, all invariants that
49-
// were upheld are still upheld.
50-
unsafe {
51-
DecimalArray::new_unchecked(
52-
array.buffer::<D>(),
53-
array.decimal_dtype(),
54-
array.validity().clone().and(masked.validity().clone()),
55-
)
56-
}
57-
.into_array()
58-
});
59-
60-
return Ok(Some(masked_array));
61-
}
56+
// Merge the parent's validity mask into the child's validity
57+
// TODO(joe): make this lazy
58+
let masked_array = match_each_decimal_value_type!(array.values_type(), |D| {
59+
// SAFETY: Since we are only flipping some bits in the validity, all invariants that
60+
// were upheld are still upheld.
61+
unsafe {
62+
DecimalArray::new_unchecked(
63+
array.buffer::<D>(),
64+
array.decimal_dtype(),
65+
array.validity().clone().and(parent.validity().clone()),
66+
)
67+
}
68+
.into_array()
69+
});
6270

63-
Ok(None)
71+
Ok(Some(masked_array))
6472
}
6573
}

vortex-array/src/arrays/primitive/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mod compute;
88
pub use compute::{IS_CONST_LANE_WIDTH, compute_is_constant};
99

1010
mod vtable;
11-
pub use vtable::{PrimitiveEncoding, PrimitiveVTable};
11+
pub use vtable::{PrimitiveEncoding, PrimitiveMaskedValidityRule, PrimitiveVTable};
1212

1313
mod native_value;
1414
pub use native_value::NativeValue;

vortex-array/src/arrays/primitive/vtable/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ use crate::{EmptyMetadata, EncodingId, EncodingRef, vtable};
1414
mod array;
1515
mod canonical;
1616
mod operations;
17-
mod operator;
17+
pub mod operator;
1818
mod validity;
1919
mod visitor;
2020

21+
pub use operator::PrimitiveMaskedValidityRule;
22+
2123
vtable!(Primitive);
2224

2325
impl VTable for PrimitiveVTable {

vortex-array/src/arrays/primitive/vtable/operator.rs

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ use vortex_dtype::match_each_native_ptype;
77
use vortex_error::VortexResult;
88
use vortex_vector::primitive::PVector;
99

10-
use crate::arrays::{MaskedVTable, PrimitiveArray, PrimitiveVTable};
10+
use crate::array::transform::{ArrayParentReduceRule, ArrayRuleContext};
11+
use crate::arrays::{MaskedArray, MaskedVTable, PrimitiveArray, PrimitiveVTable};
1112
use crate::execution::{BatchKernelRef, BindCtx, kernel};
1213
use crate::vtable::{OperatorVTable, ValidityHelper};
1314
use crate::{ArrayRef, IntoArray};
@@ -35,29 +36,36 @@ impl OperatorVTable<PrimitiveVTable> for PrimitiveVTable {
3536
}))
3637
})
3738
}
39+
}
40+
41+
/// Rule to push down validity masking from MaskedArray parent into PrimitiveArray child.
42+
///
43+
/// When a PrimitiveArray is wrapped by a MaskedArray, this rule merges the mask's validity
44+
/// with the PrimitiveArray's existing validity, eliminating the need for the MaskedArray wrapper.
45+
pub struct PrimitiveMaskedValidityRule;
3846

47+
impl ArrayParentReduceRule<PrimitiveVTable, MaskedVTable> for PrimitiveMaskedValidityRule {
3948
fn reduce_parent(
49+
&self,
4050
array: &PrimitiveArray,
41-
parent: &ArrayRef,
51+
parent: &MaskedArray,
4252
_child_idx: usize,
53+
_ctx: &ArrayRuleContext,
4354
) -> VortexResult<Option<ArrayRef>> {
44-
// Push-down masking of `validity` from the parent `MaskedArray`.
45-
if let Some(masked) = parent.as_opt::<MaskedVTable>() {
46-
let masked_array = match_each_native_ptype!(array.ptype(), |T| {
47-
// SAFETY: Since we are only flipping some bits in the validity, all invariants that
48-
// were upheld are still upheld.
49-
unsafe {
50-
PrimitiveArray::new_unchecked(
51-
Buffer::<T>::from_byte_buffer(array.byte_buffer().clone()),
52-
array.validity().clone().and(masked.validity().clone()),
53-
)
54-
}
55-
.into_array()
56-
});
57-
58-
return Ok(Some(masked_array));
59-
}
60-
61-
Ok(None)
55+
// Merge the parent's validity mask into the child's validity
56+
// TODO(joe): make this lazy
57+
let masked_array = match_each_native_ptype!(array.ptype(), |T| {
58+
// SAFETY: Since we are only flipping some bits in the validity, all invariants that
59+
// were upheld are still upheld.
60+
unsafe {
61+
PrimitiveArray::new_unchecked(
62+
Buffer::<T>::from_byte_buffer(array.byte_buffer().clone()),
63+
array.validity().clone().and(parent.validity().clone()),
64+
)
65+
}
66+
.into_array()
67+
});
68+
69+
Ok(Some(masked_array))
6270
}
6371
}

vortex-array/src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,18 @@ impl Default for ArraySession {
179179
arrays::BoolMaskedValidityRule,
180180
);
181181

182+
session.register_parent_rule::<arrays::PrimitiveVTable, arrays::MaskedVTable, _>(
183+
&PrimitiveEncoding,
184+
&MaskedEncoding,
185+
arrays::PrimitiveMaskedValidityRule,
186+
);
187+
188+
session.register_parent_rule::<arrays::DecimalVTable, arrays::MaskedVTable, _>(
189+
&DecimalEncoding,
190+
&MaskedEncoding,
191+
arrays::DecimalMaskedValidityRule,
192+
);
193+
182194
session.register_parent_rule::<arrays::StructVTable, arrays::ExprVTable, _>(
183195
&StructEncoding,
184196
&arrays::ExprEncoding,

vortex-array/src/optimizer.rs

Lines changed: 14 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,10 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use std::sync::Arc;
5-
64
use vortex_error::VortexResult;
75

6+
use crate::ArrayRef;
87
use crate::vtable::VTable;
9-
use crate::{Array, ArrayRef};
10-
11-
impl dyn Array + '_ {
12-
/// Optimize this array by applying optimization rules recursively to its children in a single
13-
/// bottom-up pass.
14-
pub fn optimize(&self) -> VortexResult<ArrayRef> {
15-
let slf = self.to_array();
16-
let children = self.children();
17-
18-
let mut new_children = Vec::with_capacity(children.len());
19-
let mut children_modified = false;
20-
for (idx, child) in children.iter().enumerate() {
21-
let child = child.optimize()?;
22-
23-
// Check if the child can reduce us (its parent), and if so bail early.
24-
if let Some(reduced) = child.reduce_parent(&slf, idx)? {
25-
return Ok(reduced);
26-
}
27-
28-
if !Arc::ptr_eq(&child, &children[idx]) {
29-
children_modified = true;
30-
}
31-
new_children.push(child);
32-
}
33-
34-
if children_modified {
35-
return self.with_children(&new_children);
36-
}
37-
38-
Ok(slf)
39-
}
40-
}
418

429
/// An optimizer rule that tries to reduce/replace a parent array where the implementer is a
4310
/// child array in the `CHILD_IDX` position of the parent array.
@@ -63,9 +30,11 @@ mod tests {
6330
use vortex_dtype::PTypeDowncast;
6431
use vortex_vector::VectorOps;
6532

66-
use crate::IntoArray;
6733
use crate::arrays::{BoolArray, MaskedArray, PrimitiveArray};
34+
use crate::expr::session::ExprSession;
35+
use crate::expr::transform::ExprOptimizer;
6836
use crate::validity::Validity;
37+
use crate::{ArraySession, IntoArray};
6938

7039
#[test]
7140
fn test_masked_pushdown() {
@@ -78,8 +47,16 @@ mod tests {
7847
)
7948
.unwrap();
8049

81-
let result = masked.optimize().unwrap();
82-
assert_eq!(masked.dtype(), result.dtype());
50+
let masked_dtype = masked.dtype().clone();
51+
52+
// Use the new ArrayOptimizer via ArraySession
53+
let array_session = ArraySession::default();
54+
let expr_session = ExprSession::default();
55+
let expr_optimizer = ExprOptimizer::new(&expr_session);
56+
let optimizer = array_session.optimizer(expr_optimizer);
57+
58+
let result = optimizer.optimize_array(masked.into_array()).unwrap();
59+
assert_eq!(&masked_dtype, result.dtype());
8360
assert!(result.dtype().is_nullable());
8461

8562
let vector = result.execute().unwrap().into_primitive().into_u32();

0 commit comments

Comments
 (0)