Skip to content

Commit ac6073b

Browse files
authored
Execute ScalarFn arrays (#5611)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent decb562 commit ac6073b

File tree

4 files changed

+58
-5
lines changed

4 files changed

+58
-5
lines changed

vortex-array/src/arrays/scalar_fn/array.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_dtype::DType;
5+
use vortex_error::VortexResult;
6+
use vortex_error::vortex_ensure;
57

8+
use crate::Array;
69
use crate::ArrayRef;
10+
use crate::arrays::ScalarFnVTable;
711
use crate::expr::functions::scalar::ScalarFn;
812
use crate::stats::ArrayStats;
913
use crate::vtable::ArrayVTable;
14+
use crate::vtable::ArrayVTableExt;
1015

1116
#[derive(Clone, Debug)]
1217
pub struct ScalarFnArray {
@@ -18,3 +23,25 @@ pub struct ScalarFnArray {
1823
pub(super) children: Vec<ArrayRef>,
1924
pub(super) stats: ArrayStats,
2025
}
26+
27+
impl ScalarFnArray {
28+
/// Create a new ScalarFnArray from a scalar function and its children.
29+
pub fn try_new(scalar_fn: ScalarFn, children: Vec<ArrayRef>, len: usize) -> VortexResult<Self> {
30+
let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
31+
let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
32+
33+
vortex_ensure!(
34+
children.iter().all(|c| c.len() == len),
35+
"ScalarFnArray must have children equal to the array length"
36+
);
37+
38+
Ok(Self {
39+
vtable: ScalarFnVTable::new(scalar_fn.vtable().clone()).into_vtable(),
40+
scalar_fn,
41+
dtype,
42+
len,
43+
children,
44+
stats: Default::default(),
45+
})
46+
}
47+
}

vortex-array/src/arrays/scalar_fn/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ mod array;
55
mod metadata;
66
mod vtable;
77

8+
pub use array::*;
89
pub use vtable::*;

vortex-array/src/arrays/scalar_fn/vtable/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ pub struct ScalarFnVTable {
5353
vtable: functions::ScalarFnVTable,
5454
}
5555

56+
impl ScalarFnVTable {
57+
pub fn new(vtable: functions::ScalarFnVTable) -> Self {
58+
Self { vtable }
59+
}
60+
}
61+
5662
impl VTable for ScalarFnVTable {
5763
type Array = ScalarFnArray;
5864
type Metadata = ScalarFnMetadata;

vortex-array/src/expr/exprs/scalar_fn.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@ use std::sync::Arc;
99
use itertools::Itertools;
1010
use vortex_dtype::DType;
1111
use vortex_error::VortexResult;
12-
use vortex_error::vortex_bail;
1312
use vortex_error::vortex_ensure;
1413
use vortex_session::SessionVar;
14+
use vortex_vector::Datum;
15+
use vortex_vector::ScalarOps;
1516
use vortex_vector::Vector;
17+
use vortex_vector::VectorMutOps;
1618

1719
use crate::ArrayRef;
20+
use crate::IntoArray;
21+
use crate::arrays::ScalarFnArray;
1822
use crate::expr::ChildName;
1923
use crate::expr::ExecutionArgs;
2024
use crate::expr::ExprId;
@@ -92,12 +96,27 @@ impl VTable for ScalarFnExpr {
9296
expr.data().return_dtype(&arg_dtypes)
9397
}
9498

95-
fn evaluate(&self, _expr: &ExpressionView<Self>, _scope: &ArrayRef) -> VortexResult<ArrayRef> {
96-
vortex_bail!("Scalar function evaluation not yet implemented")
99+
fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
100+
let children: Vec<_> = expr
101+
.children()
102+
.iter()
103+
.map(|child| child.evaluate(scope))
104+
.try_collect()?;
105+
Ok(ScalarFnArray::try_new(expr.data().clone(), children, scope.len())?.into_array())
97106
}
98107

99-
fn execute(&self, _data: &Self::Instance, _args: ExecutionArgs) -> VortexResult<Vector> {
100-
vortex_bail!("Scalar function execution not yet implemented")
108+
fn execute(&self, func: &ScalarFn, args: ExecutionArgs) -> VortexResult<Vector> {
109+
let expr_args = functions::ExecutionArgs::new(
110+
args.row_count,
111+
args.return_dtype,
112+
args.dtypes,
113+
args.vectors.into_iter().map(Datum::Vector).collect(),
114+
);
115+
let result = func.execute(&expr_args)?;
116+
Ok(match result {
117+
Datum::Scalar(s) => s.repeat(args.row_count).freeze(),
118+
Datum::Vector(v) => v,
119+
})
101120
}
102121

103122
fn stat_falsification(

0 commit comments

Comments
 (0)