Skip to content

Commit 50ddc86

Browse files
feat[vortex-array]: expr array that represents lazy computation (#5400)
Added an ExprArray and a few reduce rules ---- Signed-off-by: Joe Isaacs <[email protected]>
1 parent 4321359 commit 50ddc86

File tree

13 files changed

+914
-2
lines changed

13 files changed

+914
-2
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::DType;
5+
use vortex_error::{VortexResult, vortex_ensure};
6+
7+
use crate::expr::Expression;
8+
use crate::stats::ArrayStats;
9+
use crate::{Array, ArrayRef};
10+
11+
/// An array that represents an expression to be evaluated over a child array.
12+
///
13+
/// `ExprArray` enables deferred evaluation of expressions by wrapping a child array
14+
/// with an expression that operates on it. The expression is not evaluated until the
15+
/// array is canonicalized/executed.
16+
///
17+
/// # Examples
18+
///
19+
/// ```ignore
20+
/// // Create an expression that compares each element of an integer array against 3
21+
/// let data = PrimitiveArray::from_iter([1, 2, 3, 4, 5]);
22+
/// let expr = gt(root(), lit(3)); // $ > 3
23+
/// let expr_array = ExprArray::new_infer_dtype(data.into_array(), expr)?;
24+
///
25+
/// // The expression is evaluated when canonicalized
26+
/// let result = expr_array.to_canonical(); // Returns BoolArray([false, false, false, true, true])
27+
/// ```
28+
///
29+
/// # Type Safety
30+
///
31+
/// The `dtype` field must match `expr.return_dtype(child.dtype())`. This invariant
32+
/// is enforced by the safe constructors ([`try_new`](ExprArray::try_new) and
33+
/// [`new_infer_dtype`](ExprArray::new_infer_dtype)) but can be bypassed
34+
/// with [`unchecked_new`](ExprArray::unchecked_new) for performance-critical code.
35+
#[derive(Clone, Debug)]
36+
pub struct ExprArray {
37+
/// The underlying array that the expression will operate on.
38+
pub(super) child: ArrayRef,
39+
/// The expression to evaluate over the child array.
40+
pub(super) expr: Expression,
41+
/// The data type of the result after evaluating the expression.
///
/// The checked constructors require this to equal `expr.return_dtype(child.dtype())`.
42+
pub(super) dtype: DType,
43+
/// Statistics about the resulting array (may be computed lazily).
///
/// `unchecked_new` initialises this to `ArrayStats::default()`; nothing is propagated
/// from the child yet.
44+
pub(super) stats: ArrayStats,
45+
}
46+
47+
impl ExprArray {
48+
/// Creates a new ExprArray with the dtype validated to match the expression's return type.
49+
pub fn try_new(child: ArrayRef, expr: Expression, dtype: DType) -> VortexResult<Self> {
50+
let expected_dtype = expr.return_dtype(child.dtype())?;
51+
vortex_ensure!(
52+
dtype == expected_dtype,
53+
"ExprArray dtype mismatch: expected {}, got {}",
54+
expected_dtype,
55+
dtype
56+
);
57+
Ok(unsafe { Self::unchecked_new(child, expr, dtype) })
58+
}
59+
60+
/// Create a new ExprArray without validating that the dtype matches the expression's return type.
61+
///
62+
/// # Safety
63+
///
64+
/// The caller must ensure that `dtype` matches `expr.return_dtype(child.dtype())`.
65+
/// Violating this invariant may lead to incorrect results or panics when the array is used.
66+
pub unsafe fn unchecked_new(child: ArrayRef, expr: Expression, dtype: DType) -> Self {
67+
Self {
68+
child,
69+
expr,
70+
dtype,
71+
// TODO(joe): Propagate or compute statistics from the child array and expression.
72+
stats: ArrayStats::default(),
73+
}
74+
}
75+
76+
/// Creates a new ExprArray with the dtype inferred from the expression and child.
77+
pub fn new_infer_dtype(child: ArrayRef, expr: Expression) -> VortexResult<Self> {
78+
let dtype = expr.return_dtype(child.dtype())?;
79+
Ok(unsafe { Self::unchecked_new(child, expr, dtype) })
80+
}
81+
82+
pub fn child(&self) -> &ArrayRef {
83+
&self.child
84+
}
85+
86+
pub fn expr(&self) -> &Expression {
87+
&self.expr
88+
}
89+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod array;
5+
pub use array::ExprArray;
6+
7+
mod vtable;
8+
pub use vtable::{ExprEncoding, ExprVTable};
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::hash::Hash;
5+
6+
use vortex_dtype::DType;
7+
8+
use crate::Precision;
9+
use crate::arrays::expr::{ExprArray, ExprVTable};
10+
use crate::hash::{ArrayEq, ArrayHash};
11+
use crate::stats::StatsSetRef;
12+
use crate::vtable::ArrayVTable;
13+
14+
impl ArrayVTable<ExprVTable> for ExprVTable {
15+
fn len(array: &ExprArray) -> usize {
16+
array.child.len()
17+
}
18+
19+
fn dtype(array: &ExprArray) -> &DType {
20+
&array.dtype
21+
}
22+
23+
fn stats(array: &ExprArray) -> StatsSetRef<'_> {
24+
array.stats.to_ref(array.as_ref())
25+
}
26+
27+
fn array_hash<H: std::hash::Hasher>(array: &ExprArray, state: &mut H, precision: Precision) {
28+
array.child.array_hash(state, precision);
29+
array.dtype.hash(state);
30+
array.expr.hash(state)
31+
}
32+
33+
fn array_eq(array: &ExprArray, other: &ExprArray, precision: Precision) -> bool {
34+
array.child.array_eq(&other.child, precision)
35+
&& array.dtype == other.dtype
36+
&& array.expr == other.expr
37+
}
38+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexExpect;
5+
6+
use crate::Canonical;
7+
use crate::arrays::expr::{ExprArray, ExprVTable};
8+
use crate::vtable::CanonicalVTable;
9+
10+
impl CanonicalVTable<ExprVTable> for ExprVTable {
11+
fn canonicalize(array: &ExprArray) -> Canonical {
12+
array
13+
.expr
14+
.evaluate(&array.child)
15+
.vortex_expect("Failed to evaluate expression")
16+
.to_canonical()
17+
}
18+
}
19+
20+
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, PType};

    use crate::arrays::expr::ExprArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::expr::binary::checked_add;
    use crate::expr::literal::lit;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, assert_arrays_eq};

    /// Canonicalizing an `ExprArray` evaluates its expression once per child row.
    #[test]
    fn test_expr_array_canonicalize() {
        let child = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::NonNullable).into_array();

        // Built from literals only: the expression ignores the child's values, which is
        // enough to exercise the ExprArray mechanics.
        let expr = checked_add(lit(10), lit(5));
        let result_dtype = DType::Primitive(PType::I32, NonNullable);

        let expr_array = ExprArray::try_new(child, expr, result_dtype).unwrap();
        let actual = expr_array.to_canonical().into_array();

        let expect = (0..3).map(|_| 15i32).collect::<PrimitiveArray>();
        assert_arrays_eq!(expect, actual);
    }
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod array;
5+
mod canonical;
6+
mod operations;
7+
mod operator;
8+
mod visitor;
9+
10+
use std::fmt::Debug;
11+
12+
use vortex_buffer::ByteBuffer;
13+
use vortex_dtype::DType;
14+
use vortex_error::{VortexResult, vortex_bail};
15+
16+
use crate::arrays::expr::ExprArray;
17+
use crate::expr::Expression;
18+
use crate::serde::ArrayChildren;
19+
use crate::vtable::{NotSupported, VTable};
20+
use crate::{EncodingId, EncodingRef, vtable};
21+
22+
// NOTE(review): presumably this macro declares `ExprVTable` (it is re-exported from this
// module but not defined by hand here) — confirm against the `vtable!` definition.
vtable!(Expr);
23+
24+
/// Encoding marker for `ExprArray`; `ExprVTable::id` reports it as `"vortex.expr"`.
#[derive(Clone, Debug)]
25+
pub struct ExprEncoding;
26+
27+
impl VTable for ExprVTable {
28+
type Array = ExprArray;
29+
type Encoding = ExprEncoding;
30+
type Metadata = ExprArrayMetadata;
31+
32+
type ArrayVTable = Self;
33+
type CanonicalVTable = Self;
34+
type OperationsVTable = Self;
35+
type ValidityVTable = NotSupported;
36+
type VisitorVTable = Self;
37+
type ComputeVTable = NotSupported;
38+
type EncodeVTable = NotSupported;
39+
type OperatorVTable = Self;
40+
41+
fn id(_encoding: &Self::Encoding) -> EncodingId {
42+
EncodingId::new_ref("vortex.expr")
43+
}
44+
45+
fn encoding(_array: &Self::Array) -> EncodingRef {
46+
EncodingRef::new_ref(ExprEncoding.as_ref())
47+
}
48+
49+
fn metadata(array: &ExprArray) -> VortexResult<Self::Metadata> {
50+
Ok(ExprArrayMetadata((array.expr.clone(), array.dtype.clone())))
51+
}
52+
53+
fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
54+
Ok(None)
55+
}
56+
57+
fn deserialize(_bytes: &[u8]) -> VortexResult<Self::Metadata> {
58+
vortex_bail!("unsupported")
59+
}
60+
61+
fn build(
62+
_encoding: &ExprEncoding,
63+
dtype: &DType,
64+
len: usize,
65+
ExprArrayMetadata((expr, root_dtype)): &Self::Metadata,
66+
buffers: &[ByteBuffer],
67+
children: &dyn ArrayChildren,
68+
) -> VortexResult<ExprArray> {
69+
if !buffers.is_empty() {
70+
vortex_bail!("Expected 0 buffers, got {}", buffers.len());
71+
}
72+
73+
let Ok(child) = children.get(0, root_dtype, len) else {
74+
vortex_bail!("Expected 1 child, got {}", children.len());
75+
};
76+
77+
ExprArray::try_new(child, expr.clone(), dtype.clone())
78+
}
79+
}
80+
81+
pub struct ExprArrayMetadata((Expression, DType));
82+
83+
impl Debug for ExprArrayMetadata {
84+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85+
// Since this is used in display method we can omit the dtype.
86+
self.0.0.fmt_sql(f)
87+
}
88+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::ops::Range;
5+
6+
use vortex_error::VortexExpect;
7+
use vortex_scalar::Scalar;
8+
9+
use crate::arrays::ConstantArray;
10+
use crate::arrays::expr::{ExprArray, ExprVTable};
11+
use crate::stats::ArrayStats;
12+
use crate::vtable::OperationsVTable;
13+
use crate::{Array, ArrayRef, IntoArray};
14+
15+
impl OperationsVTable<ExprVTable> for ExprVTable {
16+
fn slice(array: &ExprArray, range: Range<usize>) -> ArrayRef {
17+
let child = array.child.slice(range);
18+
19+
ExprArray {
20+
child,
21+
expr: array.expr.clone(),
22+
dtype: array.dtype.clone(),
23+
stats: ArrayStats::default(),
24+
}
25+
.into_array()
26+
}
27+
28+
fn scalar_at(array: &ExprArray, index: usize) -> Scalar {
29+
// TODO(joe): this is unchecked
30+
array
31+
.expr
32+
.evaluate(&ConstantArray::new(array.child.scalar_at(index), 1).into_array())
33+
.vortex_expect("cannot fail")
34+
.as_constant()
35+
.vortex_expect("expr are scalar so cannot fail")
36+
}
37+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
6+
use crate::ArrayRef;
7+
use crate::arrays::expr::{ExprArray, ExprVTable};
8+
use crate::expr::root;
9+
use crate::expr::session::ExprSession;
10+
use crate::expr::transform::ExprOptimizer;
11+
use crate::vtable::OperatorVTable;
12+
13+
impl OperatorVTable<ExprVTable> for ExprVTable {
14+
fn reduce(array: &ExprArray) -> VortexResult<Option<ArrayRef>> {
15+
// Get the default expression session
16+
let session = ExprSession::default();
17+
let optimizer = ExprOptimizer::new(&session);
18+
19+
// Try to optimize the expression with type information
20+
let optimized_expr =
21+
optimizer.optimize_typed(array.expr().clone(), array.child().dtype())?;
22+
23+
if optimized_expr != *array.expr() {
24+
// If the expression simplified to just root(), return the child directly
25+
if optimized_expr == root() {
26+
return Ok(Some(array.child().clone()));
27+
}
28+
29+
let new_dtype = optimized_expr.return_dtype(array.child().dtype())?;
30+
Ok(Some(
31+
ExprArray::try_new(array.child().clone(), optimized_expr, new_dtype)?.into(),
32+
))
33+
} else {
34+
Ok(None)
35+
}
36+
}
37+
}
38+
39+
#[cfg(test)]
mod tests {

    use vortex_dtype::Nullability;
    use vortex_error::VortexExpect;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::{PrimitiveArray, PrimitiveVTable};
    use crate::expr::{get_item, pack, root};

    /// `pack(a: $).a` should reduce to `$`, leaving just the primitive child.
    #[test]
    fn test_expr_array_reduce_pack_unpack() -> VortexResult<()> {
        let child = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array();

        let packed = pack([("a", root())], Nullability::NonNullable);
        let expr = get_item("a", packed);

        let expr_array = ExprArray::new_infer_dtype(child, expr)?;

        // Reducing must yield a replacement array, and it must be the primitive child.
        let reduced = expr_array.reduce()?.vortex_expect("reduce failed");
        assert!(reduced.is::<PrimitiveVTable>());

        Ok(())
    }
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use crate::arrays::expr::{ExprArray, ExprVTable};
5+
use crate::vtable::VisitorVTable;
6+
use crate::{ArrayBufferVisitor, ArrayChildVisitor};
7+
8+
// Visitor hooks for `ExprArray`: no buffers are visited, and one child named "child" is.
impl VisitorVTable<ExprVTable> for ExprVTable {
9+
// Intentionally empty: nothing is passed to the buffer visitor.
fn visit_buffers(_array: &ExprArray, _visitor: &mut dyn ArrayBufferVisitor) {}
10+
11+
// Reports the wrapped child array to the visitor under the name "child".
fn visit_children(array: &ExprArray, visitor: &mut dyn ArrayChildVisitor) {
12+
visitor.visit_child("child", array.child.as_ref());
13+
}
14+
}

0 commit comments

Comments
 (0)