Skip to content

Commit d12f356

Browse files
authored
Push down cast (#5803)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 2d28ba3 commit d12f356

File tree

5 files changed

+177
-6
lines changed

5 files changed

+177
-6
lines changed

vortex-array/src/arrays/chunked/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::IntoArray;
1818
use crate::ToCanonical;
1919
use crate::arrays::ChunkedArray;
2020
use crate::arrays::PrimitiveArray;
21+
use crate::arrays::chunked::vtable::rules::PARENT_RULES;
2122
use crate::serde::ArrayChildren;
2223
use crate::validity::Validity;
2324
use crate::vtable;
@@ -31,6 +32,7 @@ mod array;
3132
mod canonical;
3233
mod compute;
3334
mod operations;
35+
mod rules;
3436
mod validity;
3537
mod visitor;
3638

@@ -166,4 +168,12 @@ impl VTable for ChunkedVTable {
166168
_ => None,
167169
})
168170
}
171+
172+
/// Attempts to rewrite `parent` given that `array` (a chunked array) is its child at
/// position `child_idx`.
///
/// All of the work is delegated to the `PARENT_RULES` rule set imported from the sibling
/// `rules` module; this method returns whatever `PARENT_RULES.evaluate` returns.
// NOTE(review): presumably `Ok(None)` means "no rule applied" and `Ok(Some(_))` carries the
// replacement for `parent` — confirm against `ParentRuleSet::evaluate`.
fn reduce_parent(
173+
array: &Self::Array,
174+
parent: &ArrayRef,
175+
child_idx: usize,
176+
) -> VortexResult<Option<ArrayRef>> {
177+
PARENT_RULES.evaluate(array, parent, child_idx)
178+
}
169179
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use itertools::Itertools;
5+
use vortex_error::VortexResult;
6+
7+
use crate::Array;
8+
use crate::ArrayRef;
9+
use crate::IntoArray;
10+
use crate::arrays::AnyScalarFn;
11+
use crate::arrays::ChunkedArray;
12+
use crate::arrays::ChunkedVTable;
13+
use crate::arrays::ConstantArray;
14+
use crate::arrays::ConstantVTable;
15+
use crate::arrays::ScalarFnArray;
16+
use crate::optimizer::ArrayOptimizer;
17+
use crate::optimizer::rules::ArrayParentReduceRule;
18+
use crate::optimizer::rules::ParentRuleSet;
19+
20+
/// Parent-reduction rules registered for [`ChunkedVTable`].
///
/// The set contains the two push-down rules defined in this module: one for unary scalar
/// functions and one for scalar functions whose other children are all constant.
// NOTE(review): the rules are listed unary-first; whether `ParentRuleSet` tries them in
// declaration order is not visible here — confirm before relying on precedence.
pub(super) const PARENT_RULES: ParentRuleSet<ChunkedVTable> = ParentRuleSet::new(&[
21+
ParentRuleSet::lift(&ChunkedUnaryScalarFnPushDownRule),
22+
ParentRuleSet::lift(&ChunkedConstantScalarFnPushDownRule),
23+
]);
24+
25+
/// Push down any unary scalar function through chunked arrays.
26+
#[derive(Debug)]
27+
struct ChunkedUnaryScalarFnPushDownRule;
28+
impl ArrayParentReduceRule<ChunkedVTable> for ChunkedUnaryScalarFnPushDownRule {
29+
type Parent = AnyScalarFn;
30+
31+
fn parent(&self) -> Self::Parent {
32+
AnyScalarFn
33+
}
34+
35+
fn reduce_parent(
36+
&self,
37+
array: &ChunkedArray,
38+
parent: &ScalarFnArray,
39+
_child_idx: usize,
40+
) -> VortexResult<Option<ArrayRef>> {
41+
if parent.children().len() != 1 {
42+
return Ok(None);
43+
}
44+
45+
let new_chunks: Vec<_> = array
46+
.chunks
47+
.iter()
48+
.map(|chunk| {
49+
ScalarFnArray::try_new(
50+
parent.scalar_fn().clone(),
51+
vec![chunk.clone()],
52+
chunk.len(),
53+
)?
54+
.into_array()
55+
.optimize()
56+
})
57+
.try_collect()?;
58+
59+
Ok(Some(
60+
unsafe { ChunkedArray::new_unchecked(new_chunks, parent.dtype().clone()) }.into_array(),
61+
))
62+
}
63+
}
64+
65+
/// Push down non-unary scalar functions through chunked arrays where other siblings are constant.
66+
#[derive(Debug)]
67+
struct ChunkedConstantScalarFnPushDownRule;
68+
impl ArrayParentReduceRule<ChunkedVTable> for ChunkedConstantScalarFnPushDownRule {
69+
type Parent = AnyScalarFn;
70+
71+
fn parent(&self) -> Self::Parent {
72+
AnyScalarFn
73+
}
74+
75+
fn reduce_parent(
76+
&self,
77+
array: &ChunkedArray,
78+
parent: &ScalarFnArray,
79+
child_idx: usize,
80+
) -> VortexResult<Option<ArrayRef>> {
81+
for (idx, child) in parent.children().iter().enumerate() {
82+
if idx == child_idx {
83+
continue;
84+
}
85+
if !child.is::<ConstantVTable>() {
86+
return Ok(None);
87+
}
88+
}
89+
90+
let new_chunks: Vec<_> = array
91+
.chunks
92+
.iter()
93+
.map(|chunk| {
94+
let new_children: Vec<_> = parent
95+
.children()
96+
.iter()
97+
.enumerate()
98+
.map(|(idx, child)| {
99+
if idx == child_idx {
100+
chunk.clone()
101+
} else {
102+
ConstantArray::new(
103+
child.as_::<ConstantVTable>().scalar().clone(),
104+
chunk.len(),
105+
)
106+
.into_array()
107+
}
108+
})
109+
.collect();
110+
111+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, chunk.len())?
112+
.into_array()
113+
.optimize()
114+
})
115+
.try_collect()?;
116+
117+
Ok(Some(
118+
unsafe { ChunkedArray::new_unchecked(new_chunks, parent.dtype().clone()) }.into_array(),
119+
))
120+
}
121+
}

vortex-array/src/arrays/struct_/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
mod array;
55
pub use array::StructArray;
66
mod compute;
7-
mod rules;
87

98
mod vtable;
109
pub use vtable::StructVTable;

vortex-array/src/arrays/struct_/vtable/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::ArrayRef;
1717
use crate::EmptyMetadata;
1818
use crate::VectorExecutor;
1919
use crate::arrays::struct_::StructArray;
20-
use crate::arrays::struct_::rules::RULES;
20+
use crate::arrays::struct_::vtable::rules::PARENT_RULES;
2121
use crate::executor::ExecutionCtx;
2222
use crate::serde::ArrayChildren;
2323
use crate::validity::Validity;
@@ -30,6 +30,7 @@ use crate::vtable::ValidityVTableFromValidityHelper;
3030
mod array;
3131
mod canonical;
3232
mod operations;
33+
mod rules;
3334
mod validity;
3435
mod visitor;
3536

@@ -159,7 +160,7 @@ impl VTable for StructVTable {
159160
parent: &ArrayRef,
160161
child_idx: usize,
161162
) -> VortexResult<Option<ArrayRef>> {
162-
RULES.evaluate(array, parent, child_idx)
163+
PARENT_RULES.evaluate(array, parent, child_idx)
163164
}
164165
}
165166

vortex-array/src/arrays/struct_/rules.rs renamed to vortex-array/src/arrays/struct_/vtable/rules.rs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
use vortex_error::VortexResult;
55

6-
use crate::Array;
76
use crate::ArrayRef;
87
use crate::IntoArray;
98
use crate::arrays::ConstantArray;
@@ -12,6 +11,8 @@ use crate::arrays::ScalarFnArrayExt;
1211
use crate::arrays::ScalarFnArrayView;
1312
use crate::arrays::StructArray;
1413
use crate::arrays::StructVTable;
14+
use crate::builtins::ArrayBuiltins;
15+
use crate::expr::Cast;
1516
use crate::expr::EmptyOptions;
1617
use crate::expr::GetItem;
1718
use crate::expr::Mask;
@@ -20,9 +21,48 @@ use crate::optimizer::rules::ParentRuleSet;
2021
use crate::validity::Validity;
2122
use crate::vtable::ValidityHelper;
2223

23-
pub(super) const RULES: ParentRuleSet<StructVTable> =
24-
ParentRuleSet::new(&[ParentRuleSet::lift(&StructGetItemRule)]);
24+
/// Parent-reduction rules registered for [`StructVTable`].
///
/// Contains the cast push-down rule and the `get_item` rule defined in this module.
// NOTE(review): the cast rule is listed first; whether `ParentRuleSet` tries rules in
// declaration order is not visible here — confirm before relying on precedence.
pub(super) const PARENT_RULES: ParentRuleSet<StructVTable> = ParentRuleSet::new(&[
25+
ParentRuleSet::lift(&StructCastPushDownRule),
26+
ParentRuleSet::lift(&StructGetItemRule),
27+
]);
2528

29+
/// Rule to push down cast into struct fields
30+
#[derive(Debug)]
31+
struct StructCastPushDownRule;
32+
impl ArrayParentReduceRule<StructVTable> for StructCastPushDownRule {
33+
type Parent = ExactScalarFn<Cast>;
34+
35+
fn parent(&self) -> Self::Parent {
36+
ExactScalarFn::from(&Cast)
37+
}
38+
39+
fn reduce_parent(
40+
&self,
41+
array: &StructArray,
42+
parent: ScalarFnArrayView<Cast>,
43+
_child_idx: usize,
44+
) -> VortexResult<Option<ArrayRef>> {
45+
let target_fields = parent.options.as_struct_fields();
46+
47+
let mut new_fields = Vec::with_capacity(target_fields.nfields());
48+
for (field_array, field_dtype) in array.fields.iter().zip(target_fields.fields()) {
49+
new_fields.push(field_array.cast(field_dtype)?)
50+
}
51+
52+
let new_struct = unsafe {
53+
StructArray::new_unchecked(
54+
new_fields,
55+
target_fields.clone(),
56+
array.len(),
57+
array.validity().clone(),
58+
)
59+
};
60+
61+
Ok(Some(new_struct.into_array()))
62+
}
63+
}
64+
65+
/// Rule to flatten get_item from struct by field name
2666
#[derive(Debug)]
2767
pub(crate) struct StructGetItemRule;
2868
impl ArrayParentReduceRule<StructVTable> for StructGetItemRule {

0 commit comments

Comments
 (0)