Skip to content

Commit 83fc7c5

Browse files
committed
Push down cast
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 2d28ba3 commit 83fc7c5

File tree

5 files changed

+168
-7
lines changed

5 files changed

+168
-7
lines changed

vortex-array/src/arrays/chunked/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::IntoArray;
1818
use crate::ToCanonical;
1919
use crate::arrays::ChunkedArray;
2020
use crate::arrays::PrimitiveArray;
21+
use crate::arrays::chunked::vtable::rules::PARENT_RULES;
2122
use crate::serde::ArrayChildren;
2223
use crate::validity::Validity;
2324
use crate::vtable;
@@ -31,6 +32,7 @@ mod array;
3132
mod canonical;
3233
mod compute;
3334
mod operations;
35+
mod rules;
3436
mod validity;
3537
mod visitor;
3638

@@ -166,4 +168,12 @@ impl VTable for ChunkedVTable {
166168
_ => None,
167169
})
168170
}
171+
172+
fn reduce_parent(
173+
array: &Self::Array,
174+
parent: &ArrayRef,
175+
child_idx: usize,
176+
) -> VortexResult<Option<ArrayRef>> {
177+
PARENT_RULES.evaluate(array, parent, child_idx)
178+
}
169179
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use itertools::Itertools;
5+
use vortex_error::VortexResult;
6+
7+
use crate::Array;
8+
use crate::ArrayRef;
9+
use crate::IntoArray;
10+
use crate::arrays::AnyScalarFn;
11+
use crate::arrays::ChunkedArray;
12+
use crate::arrays::ChunkedVTable;
13+
use crate::arrays::ConstantArray;
14+
use crate::arrays::ConstantVTable;
15+
use crate::arrays::ScalarFnArray;
16+
use crate::optimizer::rules::ArrayParentReduceRule;
17+
use crate::optimizer::rules::ParentRuleSet;
18+
19+
pub(super) const PARENT_RULES: ParentRuleSet<ChunkedVTable> =
20+
ParentRuleSet::new(&[ParentRuleSet::lift(&ChunkedUnaryScalarFnPushDownRule)]);
21+
22+
/// Push down any unary scalar function through chunked arrays.
23+
#[derive(Debug)]
24+
struct ChunkedUnaryScalarFnPushDownRule;
25+
impl ArrayParentReduceRule<ChunkedVTable> for ChunkedUnaryScalarFnPushDownRule {
26+
type Parent = AnyScalarFn;
27+
28+
fn parent(&self) -> Self::Parent {
29+
AnyScalarFn
30+
}
31+
32+
fn reduce_parent(
33+
&self,
34+
array: &ChunkedArray,
35+
parent: &ScalarFnArray,
36+
_child_idx: usize,
37+
) -> VortexResult<Option<ArrayRef>> {
38+
if parent.children().len() != 1 {
39+
return Ok(None);
40+
}
41+
42+
let new_chunks: Vec<_> = array
43+
.chunks
44+
.iter()
45+
.map(|chunk| {
46+
ScalarFnArray::try_new(parent.scalar_fn().clone(), vec![chunk.clone()], chunk.len())
47+
.map(|a| a.into_array())
48+
})
49+
.try_collect()?;
50+
51+
Ok(Some(
52+
unsafe { ChunkedArray::new_unchecked(new_chunks, parent.dtype().clone()) }.into_array(),
53+
))
54+
}
55+
}
56+
57+
/// Push down non-unary scalar functions through chunked arrays where other siblings are constant.
58+
#[derive(Debug)]
59+
struct ChunkedConstantScalarFnPushDownRule;
60+
impl ArrayParentReduceRule<ChunkedVTable> for ChunkedConstantScalarFnPushDownRule {
61+
type Parent = AnyScalarFn;
62+
63+
fn parent(&self) -> Self::Parent {
64+
AnyScalarFn
65+
}
66+
67+
fn reduce_parent(
68+
&self,
69+
array: &ChunkedArray,
70+
parent: &ScalarFnArray,
71+
child_idx: usize,
72+
) -> VortexResult<Option<ArrayRef>> {
73+
for (idx, child) in parent.children().iter().enumerate() {
74+
if idx == child_idx {
75+
continue;
76+
}
77+
if !child.is::<ConstantVTable>() {
78+
return Ok(None);
79+
}
80+
}
81+
82+
let new_chunks: Vec<_> = array
83+
.chunks
84+
.iter()
85+
.map(|chunk| {
86+
let new_children: Vec<_> = parent
87+
.children()
88+
.iter()
89+
.enumerate()
90+
.map(|(idx, child)| {
91+
if idx == child_idx {
92+
chunk.clone()
93+
} else {
94+
ConstantArray::new(
95+
child.as_::<ConstantVTable>().scalar().clone(),
96+
chunk.len(),
97+
)
98+
.into_array()
99+
}
100+
})
101+
.collect();
102+
103+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, chunk.len())
104+
.map(|a| a.into_array())
105+
})
106+
.try_collect()?;
107+
108+
Ok(Some(
109+
unsafe { ChunkedArray::new_unchecked(new_chunks, parent.dtype().clone()) }.into_array(),
110+
))
111+
}
112+
}

vortex-array/src/arrays/struct_/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
mod array;
55
pub use array::StructArray;
66
mod compute;
7-
mod rules;
87

98
mod vtable;
109
pub use vtable::StructVTable;

vortex-array/src/arrays/struct_/vtable/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::ArrayRef;
1717
use crate::EmptyMetadata;
1818
use crate::VectorExecutor;
1919
use crate::arrays::struct_::StructArray;
20-
use crate::arrays::struct_::rules::RULES;
20+
use crate::arrays::struct_::vtable::rules::PARENT_RULES;
2121
use crate::executor::ExecutionCtx;
2222
use crate::serde::ArrayChildren;
2323
use crate::validity::Validity;
@@ -30,6 +30,7 @@ use crate::vtable::ValidityVTableFromValidityHelper;
3030
mod array;
3131
mod canonical;
3232
mod operations;
33+
mod rules;
3334
mod validity;
3435
mod visitor;
3536

@@ -159,7 +160,7 @@ impl VTable for StructVTable {
159160
parent: &ArrayRef,
160161
child_idx: usize,
161162
) -> VortexResult<Option<ArrayRef>> {
162-
RULES.evaluate(array, parent, child_idx)
163+
PARENT_RULES.evaluate(array, parent, child_idx)
163164
}
164165
}
165166

vortex-array/src/arrays/struct_/rules.rs renamed to vortex-array/src/arrays/struct_/vtable/rules.rs

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33

44
use vortex_error::VortexResult;
55

6-
use crate::Array;
76
use crate::ArrayRef;
87
use crate::IntoArray;
98
use crate::arrays::ConstantArray;
109
use crate::arrays::ExactScalarFn;
11-
use crate::arrays::ScalarFnArrayExt;
1210
use crate::arrays::ScalarFnArrayView;
1311
use crate::arrays::StructArray;
1412
use crate::arrays::StructVTable;
13+
use crate::builtins::ArrayBuiltins;
14+
use crate::expr::Cast;
1515
use crate::expr::EmptyOptions;
1616
use crate::expr::GetItem;
1717
use crate::expr::Mask;
@@ -20,9 +20,48 @@ use crate::optimizer::rules::ParentRuleSet;
2020
use crate::validity::Validity;
2121
use crate::vtable::ValidityHelper;
2222

23-
pub(super) const RULES: ParentRuleSet<StructVTable> =
24-
ParentRuleSet::new(&[ParentRuleSet::lift(&StructGetItemRule)]);
23+
pub(super) const PARENT_RULES: ParentRuleSet<StructVTable> = ParentRuleSet::new(&[
24+
ParentRuleSet::lift(&StructCastPushDownRule),
25+
ParentRuleSet::lift(&StructGetItemRule),
26+
]);
2527

28+
/// Rule to push down cast into struct fields
29+
#[derive(Debug)]
30+
struct StructCastPushDownRule;
31+
impl ArrayParentReduceRule<StructVTable> for StructCastPushDownRule {
32+
type Parent = ExactScalarFn<Cast>;
33+
34+
fn parent(&self) -> Self::Parent {
35+
ExactScalarFn::from(&Cast)
36+
}
37+
38+
fn reduce_parent(
39+
&self,
40+
array: &StructArray,
41+
parent: ScalarFnArrayView<Cast>,
42+
_child_idx: usize,
43+
) -> VortexResult<Option<ArrayRef>> {
44+
let target_fields = parent.options.as_struct_fields();
45+
46+
let mut new_fields = Vec::with_capacity(target_fields.nfields());
47+
for (field_array, field_dtype) in array.fields.iter().zip(target_fields.fields()) {
48+
new_fields.push(field_array.cast(field_dtype)?)
49+
}
50+
51+
let new_struct = unsafe {
52+
StructArray::new_unchecked(
53+
new_fields,
54+
target_fields.clone(),
55+
array.len(),
56+
array.validity().clone(),
57+
)
58+
};
59+
60+
Ok(Some(new_struct.into_array()))
61+
}
62+
}
63+
64+
/// Rule to flatten get_item from struct by field name
2665
#[derive(Debug)]
2766
pub(crate) struct StructGetItemRule;
2867
impl ArrayParentReduceRule<StructVTable> for StructGetItemRule {

0 commit comments

Comments
 (0)