Skip to content

Commit c008e35

Browse files
committed
CastFn
Signed-off-by: Nicholas Gates <[email protected]>
1 parent aae9639 commit c008e35

File tree

9 files changed

+154
-12
lines changed

9 files changed

+154
-12
lines changed

vortex-array/src/arrays/scalar_fn/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ mod array;
55
mod metadata;
66
mod vtable;
77

8-
pub use vtable::ScalarFnArrayExt;
8+
pub use vtable::*;

vortex-array/src/arrays/scalar_fn/vtable/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ mod operations;
77
mod validity;
88
mod visitor;
99

10+
use std::marker::PhantomData;
11+
use std::ops::Deref;
12+
1013
use itertools::Itertools;
1114
use vortex_buffer::BufferHandle;
1215
use vortex_dtype::DType;
@@ -24,6 +27,8 @@ use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
2427
use crate::execution::ExecutionCtx;
2528
use crate::expr::functions;
2629
use crate::expr::functions::scalar::ScalarFn;
30+
use crate::optimizer::rules::MatchKey;
31+
use crate::optimizer::rules::Matcher;
2732
use crate::serde::ArrayChildren;
2833
use crate::vtable;
2934
use crate::vtable::ArrayId;
@@ -168,3 +173,75 @@ pub trait ScalarFnArrayExt: functions::VTable {
168173
}
169174
}
170175
impl<V: functions::VTable> ScalarFnArrayExt for V {}
176+
177+
/// A matcher that matches any scalar function expression.
178+
#[derive(Debug)]
179+
pub struct AnyScalarFn;
180+
impl Matcher for AnyScalarFn {
181+
type View<'a> = &'a ScalarFnArray;
182+
183+
fn key(&self) -> MatchKey {
184+
MatchKey::Any
185+
}
186+
187+
fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
188+
array.as_opt::<ScalarFnVTable>()
189+
}
190+
}
191+
192+
/// A matcher that matches a specific scalar function expression.
193+
#[derive(Debug)]
194+
pub struct ExactScalarFn<F: functions::VTable> {
195+
id: ArrayId,
196+
_phantom: PhantomData<F>,
197+
}
198+
199+
impl<F: functions::VTable> From<&'static F> for ExactScalarFn<F> {
200+
fn from(value: &'static F) -> Self {
201+
Self {
202+
id: value.id(),
203+
_phantom: PhantomData,
204+
}
205+
}
206+
}
207+
208+
impl<F: functions::VTable> Matcher for ExactScalarFn<F> {
209+
type View<'a> = ScalarFnArrayView<'a, F>;
210+
211+
fn key(&self) -> MatchKey {
212+
MatchKey::Array(self.id.clone())
213+
}
214+
215+
fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
216+
let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
217+
let scalar_fn_vtable = scalar_fn_array
218+
.scalar_fn
219+
.vtable()
220+
.as_any()
221+
.downcast_ref::<F>()?;
222+
let scalar_fn_options = scalar_fn_array
223+
.scalar_fn
224+
.options()
225+
.as_any()
226+
.downcast_ref::<F::Options>()?;
227+
Some(ScalarFnArrayView {
228+
array,
229+
vtable: scalar_fn_vtable,
230+
options: scalar_fn_options,
231+
})
232+
}
233+
}
234+
235+
pub struct ScalarFnArrayView<'a, F: functions::VTable> {
236+
array: &'a ArrayRef,
237+
pub vtable: &'a F,
238+
pub options: &'a F::Options,
239+
}
240+
241+
impl<F: functions::VTable> Deref for ScalarFnArrayView<'_, F> {
242+
type Target = ArrayRef;
243+
244+
fn deref(&self) -> &Self::Target {
245+
self.array
246+
}
247+
}

vortex-array/src/arrays/scalar_fn/vtable/visitor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ impl VisitorVTable<ScalarFnVTable> for ScalarFnVTable {
1111
fn visit_buffers(_array: &ScalarFnArray, _visitor: &mut dyn ArrayBufferVisitor) {}
1212

1313
fn visit_children(array: &ScalarFnArray, visitor: &mut dyn ArrayChildVisitor) {
14-
for (idx, child) in array.children().iter().enumerate() {
14+
for (idx, child) in array.children.iter().enumerate() {
1515
let name = array.scalar_fn.signature().arg_name(idx);
1616
visitor.visit_child(name.as_ref(), child.as_ref())
1717
}

vortex-array/src/expr/functions/scalar.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ pub struct ScalarFnOptions<'a> {
196196
pub(crate) options: &'a dyn Any,
197197
}
198198

199-
impl ScalarFnOptions<'_> {
199+
impl<'a> ScalarFnOptions<'a> {
200200
/// Get the options as a `dyn Any`.
201-
pub fn as_any(&self) -> &dyn Any {
201+
pub fn as_any(&self) -> &'a dyn Any {
202202
self.options
203203
}
204204
}

vortex-array/src/expr/functions/vtable.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ pub enum NullHandling {
167167

168168
/// An object-safe vtable for scalar functions that dispatches to the non-object-safe vtable.
169169
pub(crate) trait DynScalarFnVTable: 'static + Send + Sync {
170+
fn as_any(&self) -> &dyn Any;
171+
170172
fn id(&self) -> FunctionId;
171173

172174
fn options_serialize(&self, options: &dyn Any) -> VortexResult<Option<Vec<u8>>>;
@@ -202,6 +204,10 @@ pub(crate) trait DynScalarFnVTable: 'static + Send + Sync {
202204
#[repr(transparent)]
203205
pub struct ScalarFnVTableAdapter<V>(V);
204206
impl<V: VTable> DynScalarFnVTable for ScalarFnVTableAdapter<V> {
207+
fn as_any(&self) -> &dyn Any {
208+
&self.0
209+
}
210+
205211
fn id(&self) -> FunctionId {
206212
V::id(&self.0)
207213
}
@@ -308,6 +314,10 @@ impl ScalarFnVTable {
308314
self.0.id()
309315
}
310316

317+
pub fn as_any(&self) -> &dyn Any {
318+
self.0.deref().as_any()
319+
}
320+
311321
pub fn deserialize(&self, bytes: &[u8]) -> VortexResult<ScalarFn> {
312322
let options = self.0.options_deserialize(bytes)?;
313323
// SAFETY: options were created by this vtable.
Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,59 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
// TODO(ngates): can't yet write the rule matching on scalar function vtable
4+
use vortex_error::VortexResult;
5+
6+
use crate::Array;
7+
use crate::ArrayRef;
8+
use crate::ArrayVisitor;
9+
use crate::arrays::ExactScalarFn;
10+
use crate::arrays::ScalarFnArrayView;
11+
use crate::optimizer::rules::ArrayReduceRule;
12+
use crate::scalar_fns::cast::CastFn;
13+
514
#[derive(Debug)]
6-
#[allow(dead_code)]
7-
pub struct CastReduce;
15+
pub(crate) struct CastArrayReduce;
16+
17+
impl ArrayReduceRule<ExactScalarFn<CastFn>> for CastArrayReduce {
18+
fn matcher(&self) -> ExactScalarFn<CastFn> {
19+
ExactScalarFn::from(&CastFn)
20+
}
21+
22+
fn reduce(&self, array: ScalarFnArrayView<'_, CastFn>) -> VortexResult<Option<ArrayRef>> {
23+
let target_dtype = array.options;
24+
25+
// If the array is already of the target dtype, then return the input node as-is.
26+
if array.dtype() == target_dtype {
27+
return Ok(Some(array.children()[0].clone()));
28+
}
29+
30+
Ok(None)
31+
}
32+
}
33+
34+
#[cfg(test)]
35+
mod test {
36+
use vortex_error::VortexResult;
37+
38+
use super::CastArrayReduce;
39+
use crate::Array;
40+
use crate::array::IntoArray;
41+
use crate::arrays::ConstantArray;
42+
use crate::arrays::ConstantVTable;
43+
use crate::optimizer::ArrayOptimizer;
44+
use crate::scalar_fns::BuiltinScalarFns;
45+
46+
#[test]
47+
fn test_same_dtype() -> VortexResult<()> {
48+
let mut optimizer = ArrayOptimizer::default();
49+
optimizer.register_reduce_rule(CastArrayReduce);
50+
51+
let array = ConstantArray::new(true, 10).into_array();
52+
let cast_same_dtype = array.cast(array.dtype().clone())?;
53+
54+
let optimized = optimizer.optimize_array(cast_same_dtype)?;
55+
assert!(optimized.is::<ConstantVTable>());
56+
57+
Ok(())
58+
}
59+
}

vortex-array/src/scalar_fns/cast/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
mod array;
4+
pub(crate) mod array;
55

66
use prost::Message;
77
use vortex_dtype::DType;

vortex-array/src/scalar_fns/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@ use crate::ArrayRef;
1717
use crate::arrays::ScalarFnArrayExt;
1818
use crate::expr::Expression;
1919
use crate::expr::ScalarFnExprExt;
20-
use crate::scalar_fns::cast::CastFn;
2120

22-
mod cast;
21+
pub mod cast;
2322

2423
/// A collection of built-in scalar functions that can be applied to expressions or arrays.
2524
pub trait BuiltinScalarFns: Sized {
@@ -29,12 +28,12 @@ pub trait BuiltinScalarFns: Sized {
2928

3029
impl BuiltinScalarFns for Expression {
3130
fn cast(&self, dtype: DType) -> VortexResult<Expression> {
32-
CastFn.try_new_expr(dtype, [self.clone()])
31+
cast::CastFn.try_new_expr(dtype, [self.clone()])
3332
}
3433
}
3534

3635
impl BuiltinScalarFns for ArrayRef {
3736
fn cast(&self, dtype: DType) -> VortexResult<Self> {
38-
CastFn.try_new_array(self.len(), dtype, [self.clone()])
37+
cast::CastFn.try_new_array(self.len(), dtype, [self.clone()])
3938
}
4039
}

vortex-array/src/session/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::arrays::StructVTable;
2323
use crate::arrays::VarBinVTable;
2424
use crate::arrays::VarBinViewVTable;
2525
use crate::optimizer::ArrayOptimizer;
26+
use crate::scalar_fns::cast::array::CastArrayReduce;
2627
use crate::vtable::ArrayVTable;
2728
use crate::vtable::ArrayVTableExt;
2829

@@ -97,6 +98,9 @@ impl Default for ArraySession {
9798
optimizer.register_parent_rule(PrimitiveMaskedValidityRule);
9899
optimizer.register_parent_rule(DecimalMaskedValidityRule);
99100

101+
// Scalar function rules
102+
optimizer.register_reduce_rule(CastArrayReduce);
103+
100104
session
101105
}
102106
}

0 commit comments

Comments
 (0)