Skip to content

Commit 8f9dbc5

Browse files
authored
Combine ScalarFnVTable and ExprVTable (#5616)
We proved the idea out with ScalarFns, but realised we can minimize the diff for end-users if we actually just re-purpose an expression to be a scalar function. It pretty much was one already anyway.

This PR is a step on the way to having Expression -> Array be a constant-time operation, with actual compute deferred until we execute the array tree.

# Breaking Changes

* Expression VTable changes
* `Instance` renamed to `Options` to make the purpose more obvious
* Added an `Arity` function to replace `validate`, since the only possible validation in the absence of types is arity based anyway.
* Added an optional `execute` function that will replace `evaluate` in the future.
* Removed `fmt_data` in favor of requiring a Display bound on the associated Options type.
* `return_dtype` takes a `ReturnDTypeCtx` to avoid eagerly computing child dtypes.
* Removed ExpressionView in favor of returning the options from `Expression::as_` and `Expression::as_opt`.

Other changes that I don't expect will impact users:

* Moved `is_null_sensitive` and `child_name` from `Expression` to `ExpressionSignature`
* Moved `Expression::serialize_metadata()` to `Expression::options().serialize()`
* Removed ExprOptimizer, in favor of `VTable::simplify`. Any more complex optimizations can be implemented in future over the Array tree prior to execution.

---------

Signed-off-by: Nicholas Gates <[email protected]>
1 parent 1d907be commit 8f9dbc5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+2135
-3931
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-array/src/arrays/scalar_fn/array.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use vortex_error::vortex_ensure;
88
use crate::Array;
99
use crate::ArrayRef;
1010
use crate::arrays::ScalarFnVTable;
11-
use crate::expr::functions::scalar::ScalarFn;
11+
use crate::expr::ScalarFn;
1212
use crate::stats::ArrayStats;
1313
use crate::vtable::ArrayVTable;
1414
use crate::vtable::ArrayVTableExt;
@@ -17,7 +17,7 @@ use crate::vtable::ArrayVTableExt;
1717
pub struct ScalarFnArray {
1818
// NOTE(ngates): we should fix vtables so we don't have to hold this
1919
pub(super) vtable: ArrayVTable,
20-
pub(super) scalar_fn: ScalarFn,
20+
pub(super) bound: ScalarFn,
2121
pub(super) dtype: DType,
2222
pub(super) len: usize,
2323
pub(super) children: Vec<ArrayRef>,
@@ -26,18 +26,18 @@ pub struct ScalarFnArray {
2626

2727
impl ScalarFnArray {
2828
/// Create a new ScalarFnArray from a scalar function and its children.
29-
pub fn try_new(scalar_fn: ScalarFn, children: Vec<ArrayRef>, len: usize) -> VortexResult<Self> {
29+
pub fn try_new(bound: ScalarFn, children: Vec<ArrayRef>, len: usize) -> VortexResult<Self> {
3030
let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
31-
let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
31+
let dtype = bound.return_dtype(&arg_dtypes)?;
3232

3333
vortex_ensure!(
3434
children.iter().all(|c| c.len() == len),
3535
"ScalarFnArray must have children equal to the array length"
3636
);
3737

3838
Ok(Self {
39-
vtable: ScalarFnVTable::new(scalar_fn.vtable().clone()).into_vtable(),
40-
scalar_fn,
39+
vtable: ScalarFnVTable::new(bound.vtable().clone()).into_vtable(),
40+
bound,
4141
dtype,
4242
len,
4343
children,

vortex-array/src/arrays/scalar_fn/metadata.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
use vortex_dtype::DType;
55

6-
use crate::expr::functions::scalar::ScalarFn;
6+
use crate::expr::ScalarFn;
77

88
#[derive(Clone, Debug)]
99
pub struct ScalarFnMetadata {
10-
pub(super) scalar_fn: ScalarFn,
10+
pub(super) bound: ScalarFn,
1111
pub(super) child_dtypes: Vec<DType>,
1212
}

vortex-array/src/arrays/scalar_fn/vtable/array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl BaseArrayVTable<ScalarFnVTable> for ScalarFnVTable {
3030
fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
3131
array.len.hash(state);
3232
array.dtype.hash(state);
33-
array.scalar_fn.hash(state);
33+
array.bound.hash(state);
3434
for child in &array.children {
3535
child.array_hash(state, precision);
3636
}
@@ -43,7 +43,7 @@ impl BaseArrayVTable<ScalarFnVTable> for ScalarFnVTable {
4343
if array.dtype != other.dtype {
4444
return false;
4545
}
46-
if array.scalar_fn != other.scalar_fn {
46+
if array.bound != other.bound {
4747
return false;
4848
}
4949
for (child, other_child) in array.children.iter().zip(other.children.iter()) {

vortex-array/src/arrays/scalar_fn/vtable/canonical.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::Canonical;
1010
use crate::arrays::scalar_fn::array::ScalarFnArray;
1111
use crate::arrays::scalar_fn::vtable::SCALAR_FN_SESSION;
1212
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
13-
use crate::expr::functions::ExecutionArgs;
13+
use crate::expr::ExecutionArgs;
1414
use crate::vectors::VectorIntoArray;
1515
use crate::vtable::CanonicalVTable;
1616

@@ -28,11 +28,16 @@ impl CanonicalVTable<ScalarFnVTable> for ScalarFnVTable {
2828
"Failed to execute child array during canonicalization of ScalarFnArray",
2929
);
3030

31-
let ctx = ExecutionArgs::new(array.len, array.dtype.clone(), child_dtypes, child_datums);
31+
let ctx = ExecutionArgs {
32+
datums: child_datums,
33+
dtypes: child_dtypes,
34+
row_count: array.len,
35+
return_dtype: array.dtype.clone(),
36+
};
3237

3338
let result_vector = array
34-
.scalar_fn
35-
.execute(&ctx)
39+
.bound
40+
.execute(ctx)
3641
.vortex_expect("Canonicalize should be fallible")
3742
.into_vector()
3843
.vortex_expect("Canonicalize should return a vector");

vortex-array/src/arrays/scalar_fn/vtable/mod.rs

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use vortex_error::VortexResult;
1919
use vortex_error::vortex_bail;
2020
use vortex_error::vortex_ensure;
2121
use vortex_session::VortexSession;
22+
use vortex_vector::Datum;
2223
use vortex_vector::Vector;
2324

2425
use crate::Array;
@@ -27,8 +28,10 @@ use crate::IntoArray;
2728
use crate::arrays::scalar_fn::array::ScalarFnArray;
2829
use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
2930
use crate::execution::ExecutionCtx;
30-
use crate::expr::functions;
31-
use crate::expr::functions::scalar::ScalarFn;
31+
use crate::expr;
32+
use crate::expr::ExecutionArgs;
33+
use crate::expr::ExprVTable;
34+
use crate::expr::ScalarFn;
3235
use crate::optimizer::rules::MatchKey;
3336
use crate::optimizer::rules::Matcher;
3437
use crate::serde::ArrayChildren;
@@ -50,11 +53,11 @@ vtable!(ScalarFn);
5053

5154
#[derive(Clone, Debug)]
5255
pub struct ScalarFnVTable {
53-
vtable: functions::ScalarFnVTable,
56+
vtable: ExprVTable,
5457
}
5558

5659
impl ScalarFnVTable {
57-
pub fn new(vtable: functions::ScalarFnVTable) -> Self {
60+
pub fn new(vtable: ExprVTable) -> Self {
5861
Self { vtable }
5962
}
6063
}
@@ -81,7 +84,7 @@ impl VTable for ScalarFnVTable {
8184
fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
8285
let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
8386
Ok(ScalarFnMetadata {
84-
scalar_fn: array.scalar_fn.clone(),
87+
bound: array.bound.clone(),
8588
child_dtypes,
8689
})
8790
}
@@ -114,15 +117,15 @@ impl VTable for ScalarFnVTable {
114117
{
115118
let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
116119
vortex_error::vortex_ensure!(
117-
&metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
120+
&metadata.bound.return_dtype(&child_dtypes)? == dtype,
118121
"Return dtype mismatch when building ScalarFnArray"
119122
);
120123
}
121124

122125
Ok(ScalarFnArray {
123126
// This requires a new Arc, but we plan to remove this later anyway.
124127
vtable: self.to_vtable(),
125-
scalar_fn: metadata.scalar_fn.clone(),
128+
bound: metadata.bound.clone(),
126129
dtype: dtype.clone(),
127130
len,
128131
children,
@@ -135,31 +138,31 @@ impl VTable for ScalarFnVTable {
135138
let input_datums = array
136139
.children()
137140
.iter()
138-
.map(|child| child.batch_execute(ctx))
141+
.map(|child| child.batch_execute(ctx).map(Datum::Vector))
139142
.try_collect()?;
140-
let ctx = functions::ExecutionArgs::new(
141-
array.len(),
142-
array.dtype.clone(),
143-
input_dtypes,
144-
input_datums,
145-
);
143+
let ctx = ExecutionArgs {
144+
datums: input_datums,
145+
dtypes: input_dtypes,
146+
row_count: array.len,
147+
return_dtype: array.dtype.clone(),
148+
};
146149
Ok(array
147-
.scalar_fn
148-
.execute(&ctx)?
150+
.bound
151+
.execute(ctx)?
149152
.into_vector()
150153
.vortex_expect("Vector inputs should return vector outputs"))
151154
}
152155
}
153156

154157
/// Array factory functions for scalar functions.
155-
pub trait ScalarFnArrayExt: functions::VTable {
158+
pub trait ScalarFnArrayExt: expr::VTable {
156159
fn try_new_array(
157160
&'static self,
158161
len: usize,
159162
options: Self::Options,
160163
children: impl Into<Vec<ArrayRef>>,
161164
) -> VortexResult<ArrayRef> {
162-
let scalar_fn = ScalarFn::new_static(self, options);
165+
let bound = ScalarFn::new_static(self, options);
163166

164167
let children = children.into();
165168
vortex_ensure!(
@@ -168,16 +171,16 @@ pub trait ScalarFnArrayExt: functions::VTable {
168171
);
169172

170173
let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
171-
let dtype = scalar_fn.return_dtype(&child_dtypes)?;
174+
let dtype = bound.return_dtype(&child_dtypes)?;
172175

173176
let array_vtable: ArrayVTable = ScalarFnVTable {
174-
vtable: scalar_fn.vtable().clone(),
177+
vtable: bound.vtable().clone(),
175178
}
176179
.into_vtable();
177180

178181
Ok(ScalarFnArray {
179182
vtable: array_vtable,
180-
scalar_fn,
183+
bound,
181184
dtype,
182185
len,
183186
children,
@@ -186,7 +189,7 @@ pub trait ScalarFnArrayExt: functions::VTable {
186189
.into_array())
187190
}
188191
}
189-
impl<V: functions::VTable> ScalarFnArrayExt for V {}
192+
impl<V: expr::VTable> ScalarFnArrayExt for V {}
190193

191194
/// A matcher that matches any scalar function expression.
192195
#[derive(Debug)]
@@ -205,12 +208,12 @@ impl Matcher for AnyScalarFn {
205208

206209
/// A matcher that matches a specific scalar function expression.
207210
#[derive(Debug)]
208-
pub struct ExactScalarFn<F: functions::VTable> {
211+
pub struct ExactScalarFn<F: expr::VTable> {
209212
id: ArrayId,
210213
_phantom: PhantomData<F>,
211214
}
212215

213-
impl<F: functions::VTable> From<&'static F> for ExactScalarFn<F> {
216+
impl<F: expr::VTable> From<&'static F> for ExactScalarFn<F> {
214217
fn from(value: &'static F) -> Self {
215218
Self {
216219
id: value.id(),
@@ -219,7 +222,7 @@ impl<F: functions::VTable> From<&'static F> for ExactScalarFn<F> {
219222
}
220223
}
221224

222-
impl<F: functions::VTable> Matcher for ExactScalarFn<F> {
225+
impl<F: expr::VTable> Matcher for ExactScalarFn<F> {
223226
type View<'a> = ScalarFnArrayView<'a, F>;
224227

225228
fn key(&self) -> MatchKey {
@@ -229,12 +232,12 @@ impl<F: functions::VTable> Matcher for ExactScalarFn<F> {
229232
fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
230233
let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
231234
let scalar_fn_vtable = scalar_fn_array
232-
.scalar_fn
235+
.bound
233236
.vtable()
234237
.as_any()
235238
.downcast_ref::<F>()?;
236239
let scalar_fn_options = scalar_fn_array
237-
.scalar_fn
240+
.bound
238241
.options()
239242
.as_any()
240243
.downcast_ref::<F::Options>()?;
@@ -246,13 +249,13 @@ impl<F: functions::VTable> Matcher for ExactScalarFn<F> {
246249
}
247250
}
248251

249-
pub struct ScalarFnArrayView<'a, F: functions::VTable> {
252+
pub struct ScalarFnArrayView<'a, F: expr::VTable> {
250253
array: &'a ArrayRef,
251254
pub vtable: &'a F,
252255
pub options: &'a F::Options,
253256
}
254257

255-
impl<F: functions::VTable> Deref for ScalarFnArrayView<'_, F> {
258+
impl<F: expr::VTable> Deref for ScalarFnArrayView<'_, F> {
256259
type Target = ArrayRef;
257260

258261
fn deref(&self) -> &Self::Target {

vortex-array/src/arrays/scalar_fn/vtable/operations.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::ArrayRef;
1111
use crate::IntoArray;
1212
use crate::arrays::scalar_fn::array::ScalarFnArray;
1313
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
14-
use crate::expr::functions::ExecutionArgs;
14+
use crate::expr::ExecutionArgs;
1515
use crate::vtable::OperationsVTable;
1616

1717
impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable {
@@ -24,7 +24,7 @@ impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable {
2424

2525
ScalarFnArray {
2626
vtable: array.vtable.clone(),
27-
scalar_fn: array.scalar_fn.clone(),
27+
bound: array.bound.clone(),
2828
dtype: array.dtype.clone(),
2929
len: range.len(),
3030
children,
@@ -42,16 +42,16 @@ impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable {
4242
.map(|scalar| Datum::from(scalar.to_vector_scalar()))
4343
.collect();
4444

45-
let ctx = ExecutionArgs::new(
46-
1,
47-
array.dtype.clone(),
48-
array.children().iter().map(|s| s.dtype().clone()).collect(),
49-
input_datums,
50-
);
45+
let ctx = ExecutionArgs {
46+
datums: input_datums,
47+
dtypes: array.children().iter().map(|c| c.dtype().clone()).collect(),
48+
row_count: 1,
49+
return_dtype: array.dtype.clone(),
50+
};
5151

5252
let _result = array
53-
.scalar_fn
54-
.execute(&ctx)
53+
.bound
54+
.execute(ctx)
5555
.vortex_expect("Scalar function execution should be fallible")
5656
.into_scalar()
5757
.vortex_expect("Scalar function execution should return scalar");

vortex-array/src/arrays/scalar_fn/vtable/validity.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use crate::Array;
88
use crate::arrays::scalar_fn::array::ScalarFnArray;
99
use crate::arrays::scalar_fn::vtable::SCALAR_FN_SESSION;
1010
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
11-
use crate::expr::functions::NullHandling;
1211
use crate::vtable::ValidityVTable;
1312

1413
impl ValidityVTable<ScalarFnVTable> for ScalarFnVTable {
@@ -17,28 +16,32 @@ impl ValidityVTable<ScalarFnVTable> for ScalarFnVTable {
1716
}
1817

1918
fn all_valid(array: &ScalarFnArray) -> bool {
20-
match array.scalar_fn.signature().null_handling() {
21-
NullHandling::Propagate | NullHandling::AbsorbsNull => {
22-
// Requires all children to guarantee all_valid
23-
array.children().iter().all(|child| child.all_valid())
24-
}
25-
NullHandling::Custom => {
26-
// We cannot guarantee that the array is all valid without evaluating the function
19+
match array.bound.signature().is_null_sensitive() {
20+
true => {
21+
// If the function is null sensitive, we cannot guarantee all valid without evaluating
22+
// the function
2723
false
2824
}
25+
false => {
26+
// If the function is not null sensitive, we can guarantee all valid if all children
27+
// are all valid
28+
array.children().iter().all(|child| child.all_valid())
29+
}
2930
}
3031
}
3132

3233
fn all_invalid(array: &ScalarFnArray) -> bool {
33-
match array.scalar_fn.signature().null_handling() {
34-
NullHandling::Propagate => {
35-
// All null if any child is all null
36-
array.children().iter().any(|child| child.all_invalid())
37-
}
38-
NullHandling::AbsorbsNull | NullHandling::Custom => {
39-
// We cannot guarantee that the array is all valid without evaluating the function
34+
match array.bound.signature().is_null_sensitive() {
35+
true => {
36+
// If the function is null sensitive, we cannot guarantee all invalid without evaluating
37+
// the function
4038
false
4139
}
40+
false => {
41+
// If the function is not null sensitive, we can guarantee all invalid if any child
42+
// is all invalid
43+
array.children().iter().any(|child| child.all_invalid())
44+
}
4245
}
4346
}
4447

0 commit comments

Comments
 (0)