Skip to content

Commit 682289f

Browse files
authored
Scalar functions (#5561)
Introduces the traits required to define abstract scalar functions and wrap them up as either expressions or arrays.

This is one of the kinds of function into which we will split the generic ComputeFn trait, and it allows us to lazily defer scalar compute in the array tree.

We don't add them in this PR, but there may be some properties that are useful for ScalarFns in some form:

```
/// The identity element `e` where `f(e, x) = f(x, e) = x`.
///
/// When an argument is the identity element, the function can be
/// eliminated entirely, returning the other argument unchanged.
///
/// # Examples
/// - `AND`: `true` (AND(true, x) → x)
/// - `OR`: `false` (OR(false, x) → x)
/// - `+`: `0` (0 + x → x)
/// - `*`: `1` (1 * x → x)
/// - `COALESCE`: `NULL` (COALESCE(NULL, x) → x)
fn identity_element(&self, options: &Self::Options) -> Option<Scalar> {
    _ = options;
    None
}

/// The absorbing element `a` where `f(a, x) = f(x, a) = a`.
///
/// When any argument is the absorbing element, the function short-circuits
/// immediately, returning that element without evaluating other arguments.
/// Also known as the "annihilator" or "zero element".
///
/// # Examples
/// - `AND`: `false` (AND(false, x) → false)
/// - `OR`: `true` (OR(true, x) → true)
/// - `*`: `0` (0 * x → 0)
fn absorbing_element(&self, options: &Self::Options) -> Option<Scalar> {
    _ = options;
    None
}

/// Whether argument order is irrelevant: `f(a, b) = f(b, a)`.
///
/// Enables expression normalization (e.g., sorting arguments by column id)
/// for better common subexpression elimination and pattern matching.
///
/// # Examples
/// - Commutative: `+`, `*`, `AND`, `OR`, `=`, `!=`, `MIN`, `MAX`
/// - Non-commutative: `-`, `/`, `<`, `>`, `CONCAT`
fn is_commutative(&self, options: &Self::Options) -> bool {
    _ = options;
    false
}

/// Whether `f(x, x) = x`.
///
/// Enables simplification when the same expression appears multiple times
/// as arguments to the function.
///
/// # Examples
/// - Idempotent: `AND`, `OR`, `MIN`, `MAX`
/// - Non-idempotent: `+` (x + x = 2x), `*` (x * x = x²)
fn is_idempotent(&self, options: &Self::Options) -> bool {
    _ = options;
    false
}

/// Whether `f(f(x)) = x` for unary functions.
///
/// Enables cancellation of nested self-applications.
///
/// # Examples
/// - Involutions: `NOT`, `NEG` (for signed types), `REVERSE`
/// - Non-involutions: `ABS`, `UPPER`, `LOWER`
fn is_involution(&self, options: &Self::Options) -> bool {
    _ = options;
    false
}

/// How the function behaves when one or more arguments are NULL.
///
/// Most functions propagate NULL (any NULL argument produces NULL output).
/// Some functions have special NULL handling that can short-circuit
/// evaluation or treat NULL as a meaningful value.
///
/// Required for correct NULL semantics; may also enable optimizations
/// when argument nullability is known from schema or statistics.
fn null_handling(&self, options: &Self::Options) -> NullHandling {
    _ = options;
    NullHandling::default()
}

/// How a function handles NULL arguments.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum NullHandling {
    /// NULL in any argument produces NULL output.
    ///
    /// This is standard SQL behavior for most scalar functions.
    /// Enables simplification when any argument is known to be NULL.
    Propagate,

    /// NULL is short-circuited when paired with the absorbing element.
    ///
    /// This is a special case where the absorbing element "wins" over NULL.
    ///
    /// # Examples
    /// - `AND_KLEENE(false, NULL)` → `false` (false absorbs NULL)
    /// - `OR_KLEENE(true, NULL)` → `true` (true absorbs NULL)
    AbsorbsNull,

    /// The function has special NULL semantics that don't follow
    /// simple propagation rules.
    ///
    /// This prevents any simplifications based on NULL arguments.
    ///
    /// # Examples
    /// - `IS NULL`, `IS NOT NULL`: NULL → true/false
    /// - `COALESCE`: returns first non-NULL argument
    /// - `NULLIF`: conditionally produces NULL
    #[default]
    Custom,
}
```

---------

Signed-off-by: Nicholas Gates <[email protected]>
1 parent 3877839 commit 682289f

File tree

36 files changed

+1363
-69
lines changed

36 files changed

+1363
-69
lines changed

vortex-array/src/arrays/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,13 +29,13 @@ mod listview;
2929
mod masked;
3030
mod null;
3131
mod primitive;
32+
mod scalar_fn;
3233
mod struct_;
3334
mod varbin;
3435
mod varbinview;
3536

3637
#[cfg(feature = "arbitrary")]
3738
pub mod arbitrary;
38-
3939
// TODO(connor): Export exact types, not glob.
4040

4141
pub use bool::*;
Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
6+
use crate::ArrayRef;
7+
use crate::expr::functions::scalar::ScalarFn;
8+
use crate::stats::ArrayStats;
9+
use crate::vtable::ArrayVTable;
10+
11+
#[derive(Clone, Debug)]
12+
pub struct ScalarFnArray {
13+
// NOTE(ngates): we should fix vtables so we don't have to hold this
14+
pub(super) vtable: ArrayVTable,
15+
pub(super) scalar_fn: ScalarFn,
16+
pub(super) dtype: DType,
17+
pub(super) len: usize,
18+
pub(super) children: Vec<ArrayRef>,
19+
pub(super) stats: ArrayStats,
20+
}
Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,12 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
6+
use crate::expr::functions::scalar::ScalarFn;
7+
8+
#[derive(Clone, Debug)]
9+
pub struct ScalarFnMetadata {
10+
pub(super) scalar_fn: ScalarFn,
11+
pub(super) child_dtypes: Vec<DType>,
12+
}
Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod array;
5+
mod metadata;
6+
mod vtable;
Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,56 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::hash::Hash;
5+
use std::hash::Hasher;
6+
7+
use vortex_dtype::DType;
8+
9+
use crate::ArrayEq;
10+
use crate::ArrayHash;
11+
use crate::Precision;
12+
use crate::arrays::scalar_fn::array::ScalarFnArray;
13+
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
14+
use crate::stats::StatsSetRef;
15+
use crate::vtable::BaseArrayVTable;
16+
17+
impl BaseArrayVTable<ScalarFnVTable> for ScalarFnVTable {
18+
fn len(array: &ScalarFnArray) -> usize {
19+
array.len
20+
}
21+
22+
fn dtype(array: &ScalarFnArray) -> &DType {
23+
&array.dtype
24+
}
25+
26+
fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
27+
array.stats.to_ref(array.as_ref())
28+
}
29+
30+
fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
31+
array.len.hash(state);
32+
array.dtype.hash(state);
33+
array.scalar_fn.hash(state);
34+
for child in &array.children {
35+
child.array_hash(state, precision);
36+
}
37+
}
38+
39+
fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
40+
if array.len != other.len {
41+
return false;
42+
}
43+
if array.dtype != other.dtype {
44+
return false;
45+
}
46+
if array.scalar_fn != other.scalar_fn {
47+
return false;
48+
}
49+
for (child, other_child) in array.children.iter().zip(other.children.iter()) {
50+
if !child.array_eq(other_child, precision) {
51+
return false;
52+
}
53+
}
54+
true
55+
}
56+
}
Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,41 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use itertools::Itertools;
5+
use vortex_error::VortexExpect;
6+
use vortex_vector::Datum;
7+
8+
use crate::Array;
9+
use crate::Canonical;
10+
use crate::arrays::scalar_fn::array::ScalarFnArray;
11+
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
12+
use crate::expr::functions::ExecutionCtx;
13+
use crate::vectors::VectorIntoArray;
14+
use crate::vtable::CanonicalVTable;
15+
16+
impl CanonicalVTable<ScalarFnVTable> for ScalarFnVTable {
17+
fn canonicalize(array: &ScalarFnArray) -> Canonical {
18+
let child_dtypes: Vec<_> = array.children.iter().map(|c| c.dtype().clone()).collect();
19+
let child_datums: Vec<_> = array
20+
.children()
21+
.iter()
22+
// TODO(ngates): we could make all execution operate over datums
23+
.map(|child| child.execute().map(Datum::Vector))
24+
.try_collect()
25+
// FIXME(ngates): canonicalizing really ought to be fallible
26+
.vortex_expect(
27+
"Failed to execute child array during canonicalization of ScalarFnArray",
28+
);
29+
30+
let ctx = ExecutionCtx::new(array.len, array.dtype.clone(), child_dtypes, child_datums);
31+
32+
let result_vector = array
33+
.scalar_fn
34+
.execute(&ctx)
35+
.vortex_expect("Canonicalize should be fallible")
36+
.into_vector()
37+
.vortex_expect("Canonicalize should return a vector");
38+
39+
result_vector.into_array(&array.dtype).to_canonical()
40+
}
41+
}
Lines changed: 129 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,129 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod array;
5+
mod canonical;
6+
mod operations;
7+
mod validity;
8+
mod visitor;
9+
10+
use itertools::Itertools;
11+
use vortex_buffer::BufferHandle;
12+
use vortex_dtype::DType;
13+
use vortex_error::VortexExpect;
14+
use vortex_error::VortexResult;
15+
use vortex_error::vortex_bail;
16+
use vortex_vector::Vector;
17+
18+
use crate::Array;
19+
use crate::arrays::scalar_fn::array::ScalarFnArray;
20+
use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
21+
use crate::execution::ExecutionCtx;
22+
use crate::expr::functions;
23+
use crate::serde::ArrayChildren;
24+
use crate::vtable;
25+
use crate::vtable::ArrayId;
26+
use crate::vtable::ArrayVTable;
27+
use crate::vtable::ArrayVTableExt;
28+
use crate::vtable::NotSupported;
29+
use crate::vtable::VTable;
30+
31+
vtable!(ScalarFn);
32+
33+
#[derive(Clone, Debug)]
34+
pub struct ScalarFnVTable {
35+
vtable: functions::ScalarFnVTable,
36+
}
37+
38+
impl VTable for ScalarFnVTable {
39+
type Array = ScalarFnArray;
40+
type Metadata = ScalarFnMetadata;
41+
type ArrayVTable = Self;
42+
type CanonicalVTable = Self;
43+
type OperationsVTable = NotSupported;
44+
type ValidityVTable = Self;
45+
type VisitorVTable = Self;
46+
type ComputeVTable = NotSupported;
47+
type EncodeVTable = NotSupported;
48+
type OperatorVTable = NotSupported;
49+
50+
fn id(&self) -> ArrayId {
51+
self.vtable.id()
52+
}
53+
54+
fn encoding(array: &Self::Array) -> ArrayVTable {
55+
array.vtable.clone()
56+
}
57+
58+
fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
59+
let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
60+
Ok(ScalarFnMetadata {
61+
scalar_fn: array.scalar_fn.clone(),
62+
child_dtypes,
63+
})
64+
}
65+
66+
fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
67+
// Not supported
68+
Ok(None)
69+
}
70+
71+
fn deserialize(_bytes: &[u8]) -> VortexResult<Self::Metadata> {
72+
vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
73+
}
74+
75+
fn build(
76+
&self,
77+
dtype: &DType,
78+
len: usize,
79+
metadata: &ScalarFnMetadata,
80+
_buffers: &[BufferHandle],
81+
children: &dyn ArrayChildren,
82+
) -> VortexResult<Self::Array> {
83+
let children: Vec<_> = metadata
84+
.child_dtypes
85+
.iter()
86+
.enumerate()
87+
.map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
88+
.try_collect()?;
89+
90+
#[cfg(debug_assertions)]
91+
{
92+
let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
93+
vortex_error::vortex_ensure!(
94+
&metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
95+
"Return dtype mismatch when building ScalarFnArray"
96+
);
97+
}
98+
99+
Ok(ScalarFnArray {
100+
// This requires a new Arc, but we plan to remove this later anyway.
101+
vtable: self.to_vtable(),
102+
scalar_fn: metadata.scalar_fn.clone(),
103+
dtype: dtype.clone(),
104+
len,
105+
children,
106+
stats: Default::default(),
107+
})
108+
}
109+
110+
fn execute(array: &Self::Array, _ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
111+
let input_dtypes: Vec<_> = array.children().iter().map(|c| c.dtype().clone()).collect();
112+
let input_datums = array
113+
.children()
114+
.iter()
115+
.map(|child| child.execute())
116+
.try_collect()?;
117+
let ctx = functions::ExecutionCtx::new(
118+
array.len(),
119+
array.dtype.clone(),
120+
input_dtypes,
121+
input_datums,
122+
);
123+
Ok(array
124+
.scalar_fn
125+
.execute(&ctx)?
126+
.into_vector()
127+
.vortex_expect("Vector inputs should return vector outputs"))
128+
}
129+
}
Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,62 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::ops::Range;
5+
6+
use vortex_error::VortexExpect;
7+
use vortex_scalar::Scalar;
8+
use vortex_vector::Datum;
9+
10+
use crate::ArrayRef;
11+
use crate::IntoArray;
12+
use crate::arrays::scalar_fn::array::ScalarFnArray;
13+
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
14+
use crate::expr::functions::ExecutionCtx;
15+
use crate::vtable::OperationsVTable;
16+
17+
impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable {
18+
fn slice(array: &ScalarFnArray, range: Range<usize>) -> ArrayRef {
19+
let children: Vec<_> = array
20+
.children()
21+
.iter()
22+
.map(|c| c.slice(range.clone()))
23+
.collect();
24+
25+
ScalarFnArray {
26+
vtable: array.vtable.clone(),
27+
scalar_fn: array.scalar_fn.clone(),
28+
dtype: array.dtype.clone(),
29+
len: range.len(),
30+
children,
31+
stats: Default::default(),
32+
}
33+
.into_array()
34+
}
35+
36+
fn scalar_at(array: &ScalarFnArray, index: usize) -> Scalar {
37+
// TODO(ngates): we should evaluate the scalar function over the scalar inputs.
38+
let input_datums: Vec<_> = array
39+
.children()
40+
.iter()
41+
.map(|c| c.scalar_at(index))
42+
.map(|scalar| Datum::from(scalar.to_vector_scalar()))
43+
.collect();
44+
45+
let ctx = ExecutionCtx::new(
46+
1,
47+
array.dtype.clone(),
48+
array.children().iter().map(|s| s.dtype().clone()).collect(),
49+
input_datums,
50+
);
51+
52+
let _result = array
53+
.scalar_fn
54+
.execute(&ctx)
55+
.vortex_expect("Scalar function execution should be fallible")
56+
.into_scalar()
57+
.vortex_expect("Scalar function execution should return scalar");
58+
59+
// Convert the vector scalar back into a legacy Scalar for now.
60+
todo!("Implement legacy scalar conversion")
61+
}
62+
}
Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,50 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexExpect;
5+
use vortex_mask::Mask;
6+
7+
use crate::Array;
8+
use crate::arrays::scalar_fn::array::ScalarFnArray;
9+
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
10+
use crate::expr::functions::NullHandling;
11+
use crate::vtable::ValidityVTable;
12+
13+
impl ValidityVTable<ScalarFnVTable> for ScalarFnVTable {
14+
fn is_valid(array: &ScalarFnArray, index: usize) -> bool {
15+
array.scalar_at(index).is_valid()
16+
}
17+
18+
fn all_valid(array: &ScalarFnArray) -> bool {
19+
match array.scalar_fn.signature().null_handling() {
20+
NullHandling::Propagate | NullHandling::AbsorbsNull => {
21+
// Requires all children to guarantee all_valid
22+
array.children().iter().all(|child| child.all_valid())
23+
}
24+
NullHandling::Custom => {
25+
// We cannot guarantee that the array is all valid without evaluating the function
26+
false
27+
}
28+
}
29+
}
30+
31+
fn all_invalid(array: &ScalarFnArray) -> bool {
32+
match array.scalar_fn.signature().null_handling() {
33+
NullHandling::Propagate => {
34+
// All null if any child is all null
35+
array.children().iter().any(|child| child.all_invalid())
36+
}
37+
NullHandling::AbsorbsNull | NullHandling::Custom => {
38+
// We cannot guarantee that the array is all invalid without evaluating the function
39+
false
40+
}
41+
}
42+
}
43+
44+
fn validity_mask(array: &ScalarFnArray) -> Mask {
45+
let vector = array
46+
.execute()
47+
.vortex_expect("Validity mask computation should be fallible");
48+
Mask::from_buffer(vector.into_bool().into_bits())
49+
}
50+
}

0 commit comments

Comments
 (0)