Skip to content

Commit 208d8bc

Browse files
authored
Test implementation of cast using ScalarFn (#5586)
* Defines a CastFn implementation of ScalarFn * Updates the optimizer rules to support matcher instances (vs marker structs) so we can match on runtime array IDs * Implement CastReduce rule that for now just removes the cast if the child array already has the same target DType. --------- Signed-off-by: Nicholas Gates <[email protected]>
1 parent 841c7cd commit 208d8bc

File tree

26 files changed

+727
-241
lines changed

26 files changed

+727
-241
lines changed

Cargo.lock

Lines changed: 11 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-array/src/arrays/bool/vtable/operator.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use crate::execution::BatchKernelRef;
1515
use crate::execution::BindCtx;
1616
use crate::execution::kernel;
1717
use crate::optimizer::rules::ArrayParentReduceRule;
18+
use crate::optimizer::rules::Exact;
1819
use crate::vtable::OperatorVTable;
1920
use crate::vtable::ValidityHelper;
2021

@@ -47,7 +48,15 @@ impl OperatorVTable<BoolVTable> for BoolVTable {
4748
#[derive(Default, Debug)]
4849
pub struct BoolMaskedValidityRule;
4950

50-
impl ArrayParentReduceRule<BoolVTable, MaskedVTable> for BoolMaskedValidityRule {
51+
impl ArrayParentReduceRule<Exact<BoolVTable>, Exact<MaskedVTable>> for BoolMaskedValidityRule {
52+
fn child(&self) -> Exact<BoolVTable> {
53+
Exact::from(&BoolVTable)
54+
}
55+
56+
fn parent(&self) -> Exact<MaskedVTable> {
57+
Exact::from(&MaskedVTable)
58+
}
59+
5160
fn reduce_parent(
5261
&self,
5362
array: &BoolArray,

vortex-array/src/arrays/decimal/vtable/operator.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use crate::execution::BatchKernelRef;
1717
use crate::execution::BindCtx;
1818
use crate::execution::kernel;
1919
use crate::optimizer::rules::ArrayParentReduceRule;
20+
use crate::optimizer::rules::Exact;
2021
use crate::vtable::OperatorVTable;
2122
use crate::vtable::ValidityHelper;
2223

@@ -54,7 +55,17 @@ impl OperatorVTable<DecimalVTable> for DecimalVTable {
5455
#[derive(Default, Debug)]
5556
pub struct DecimalMaskedValidityRule;
5657

57-
impl ArrayParentReduceRule<DecimalVTable, MaskedVTable> for DecimalMaskedValidityRule {
58+
impl ArrayParentReduceRule<Exact<DecimalVTable>, Exact<MaskedVTable>>
59+
for DecimalMaskedValidityRule
60+
{
61+
fn child(&self) -> Exact<DecimalVTable> {
62+
Exact::from(&DecimalVTable)
63+
}
64+
65+
fn parent(&self) -> Exact<MaskedVTable> {
66+
Exact::from(&MaskedVTable)
67+
}
68+
5869
fn reduce_parent(
5970
&self,
6071
array: &DecimalArray,

vortex-array/src/arrays/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub use listview::*;
5050
pub use masked::*;
5151
pub use null::*;
5252
pub use primitive::*;
53+
pub use scalar_fn::*;
5354
pub use struct_::*;
5455
pub use varbin::*;
5556
pub use varbinview::*;

vortex-array/src/arrays/primitive/vtable/operator.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use crate::execution::BatchKernelRef;
1717
use crate::execution::BindCtx;
1818
use crate::execution::kernel;
1919
use crate::optimizer::rules::ArrayParentReduceRule;
20+
use crate::optimizer::rules::Exact;
2021
use crate::vtable::OperatorVTable;
2122
use crate::vtable::ValidityHelper;
2223

@@ -52,7 +53,17 @@ impl OperatorVTable<PrimitiveVTable> for PrimitiveVTable {
5253
#[derive(Default, Debug)]
5354
pub struct PrimitiveMaskedValidityRule;
5455

55-
impl ArrayParentReduceRule<PrimitiveVTable, MaskedVTable> for PrimitiveMaskedValidityRule {
56+
impl ArrayParentReduceRule<Exact<PrimitiveVTable>, Exact<MaskedVTable>>
57+
for PrimitiveMaskedValidityRule
58+
{
59+
fn child(&self) -> Exact<PrimitiveVTable> {
60+
Exact::from(&PrimitiveVTable)
61+
}
62+
63+
fn parent(&self) -> Exact<MaskedVTable> {
64+
Exact::from(&MaskedVTable)
65+
}
66+
5667
fn reduce_parent(
5768
&self,
5869
array: &PrimitiveArray,

vortex-array/src/arrays/scalar_fn/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
mod array;
55
mod metadata;
66
mod vtable;
7+
8+
pub use vtable::*;

vortex-array/src/arrays/scalar_fn/vtable/mod.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,28 @@ mod operations;
77
mod validity;
88
mod visitor;
99

10+
use std::marker::PhantomData;
11+
use std::ops::Deref;
12+
1013
use itertools::Itertools;
1114
use vortex_buffer::BufferHandle;
1215
use vortex_dtype::DType;
1316
use vortex_error::VortexExpect;
1417
use vortex_error::VortexResult;
1518
use vortex_error::vortex_bail;
19+
use vortex_error::vortex_ensure;
1620
use vortex_vector::Vector;
1721

1822
use crate::Array;
23+
use crate::ArrayRef;
24+
use crate::IntoArray;
1925
use crate::arrays::scalar_fn::array::ScalarFnArray;
2026
use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
2127
use crate::execution::ExecutionCtx;
2228
use crate::expr::functions;
29+
use crate::expr::functions::scalar::ScalarFn;
30+
use crate::optimizer::rules::MatchKey;
31+
use crate::optimizer::rules::Matcher;
2332
use crate::serde::ArrayChildren;
2433
use crate::vtable;
2534
use crate::vtable::ArrayId;
@@ -127,3 +136,112 @@ impl VTable for ScalarFnVTable {
127136
.vortex_expect("Vector inputs should return vector outputs"))
128137
}
129138
}
139+
140+
/// Array factory functions for scalar functions.
141+
pub trait ScalarFnArrayExt: functions::VTable {
142+
fn try_new_array(
143+
&'static self,
144+
len: usize,
145+
options: Self::Options,
146+
children: impl Into<Vec<ArrayRef>>,
147+
) -> VortexResult<ArrayRef> {
148+
let scalar_fn = ScalarFn::new_static(self, options);
149+
150+
let children = children.into();
151+
vortex_ensure!(
152+
children.iter().all(|c| c.len() == len),
153+
"All child arrays must have the same length as the scalar function array"
154+
);
155+
156+
let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
157+
let dtype = scalar_fn.return_dtype(&child_dtypes)?;
158+
159+
let array_vtable: ArrayVTable = ScalarFnVTable {
160+
vtable: scalar_fn.vtable().clone(),
161+
}
162+
.into_vtable();
163+
164+
Ok(ScalarFnArray {
165+
vtable: array_vtable,
166+
scalar_fn,
167+
dtype,
168+
len,
169+
children,
170+
stats: Default::default(),
171+
}
172+
.into_array())
173+
}
174+
}
175+
impl<V: functions::VTable> ScalarFnArrayExt for V {}
176+
177+
/// A matcher that matches any scalar function expression.
178+
#[derive(Debug)]
179+
pub struct AnyScalarFn;
180+
impl Matcher for AnyScalarFn {
181+
type View<'a> = &'a ScalarFnArray;
182+
183+
fn key(&self) -> MatchKey {
184+
MatchKey::Any
185+
}
186+
187+
fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
188+
array.as_opt::<ScalarFnVTable>()
189+
}
190+
}
191+
192+
/// A matcher that matches a specific scalar function expression.
193+
#[derive(Debug)]
194+
pub struct ExactScalarFn<F: functions::VTable> {
195+
id: ArrayId,
196+
_phantom: PhantomData<F>,
197+
}
198+
199+
impl<F: functions::VTable> From<&'static F> for ExactScalarFn<F> {
200+
fn from(value: &'static F) -> Self {
201+
Self {
202+
id: value.id(),
203+
_phantom: PhantomData,
204+
}
205+
}
206+
}
207+
208+
impl<F: functions::VTable> Matcher for ExactScalarFn<F> {
209+
type View<'a> = ScalarFnArrayView<'a, F>;
210+
211+
fn key(&self) -> MatchKey {
212+
MatchKey::Array(self.id.clone())
213+
}
214+
215+
fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
216+
let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
217+
let scalar_fn_vtable = scalar_fn_array
218+
.scalar_fn
219+
.vtable()
220+
.as_any()
221+
.downcast_ref::<F>()?;
222+
let scalar_fn_options = scalar_fn_array
223+
.scalar_fn
224+
.options()
225+
.as_any()
226+
.downcast_ref::<F::Options>()?;
227+
Some(ScalarFnArrayView {
228+
array,
229+
vtable: scalar_fn_vtable,
230+
options: scalar_fn_options,
231+
})
232+
}
233+
}
234+
235+
pub struct ScalarFnArrayView<'a, F: functions::VTable> {
236+
array: &'a ArrayRef,
237+
pub vtable: &'a F,
238+
pub options: &'a F::Options,
239+
}
240+
241+
impl<F: functions::VTable> Deref for ScalarFnArrayView<'_, F> {
242+
type Target = ArrayRef;
243+
244+
fn deref(&self) -> &Self::Target {
245+
self.array
246+
}
247+
}

vortex-array/src/arrays/scalar_fn/vtable/visitor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ impl VisitorVTable<ScalarFnVTable> for ScalarFnVTable {
1111
fn visit_buffers(_array: &ScalarFnArray, _visitor: &mut dyn ArrayBufferVisitor) {}
1212

1313
fn visit_children(array: &ScalarFnArray, visitor: &mut dyn ArrayChildVisitor) {
14-
for (idx, child) in array.children().iter().enumerate() {
14+
for (idx, child) in array.children.iter().enumerate() {
1515
let name = array.scalar_fn.signature().arg_name(idx);
1616
visitor.visit_child(name.as_ref(), child.as_ref())
1717
}

0 commit comments

Comments
 (0)