@@ -9,18 +9,18 @@ use std::hash::Hash;
99use std:: ops:: Deref ;
1010use std:: sync:: Arc ;
1111
12- use itertools:: Itertools ;
1312use vortex_dtype:: DType ;
1413use vortex_error:: VortexExpect ;
1514use vortex_error:: VortexResult ;
1615use vortex_error:: vortex_ensure;
1716
1817use crate :: ArrayRef ;
18+ use crate :: expr:: ReturnDTypeCtx ;
1919use crate :: expr:: Root ;
2020use crate :: expr:: StatsCatalog ;
2121use crate :: expr:: VTable ;
22- use crate :: expr:: bound:: BoundExpression ;
2322use crate :: expr:: display:: DisplayTreeExpr ;
23+ use crate :: expr:: scalar_fn:: ScalarFn ;
2424use crate :: expr:: stats:: Stat ;
2525
2626/// A node in a Vortex expression tree.
@@ -29,36 +29,39 @@ use crate::expr::stats::Stat;
2929/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
3030#[ derive( Clone , Debug , PartialEq , Eq , Hash ) ]
3131pub struct Expression {
32- /// The bound expression for this node.
33- bound : BoundExpression ,
32+ /// The scalar function for this node.
33+ scalar_fn : ScalarFn ,
3434 /// Any children of this expression.
3535 children : Arc < [ Expression ] > ,
3636}
3737
3838impl Deref for Expression {
39- type Target = BoundExpression ;
39+ type Target = ScalarFn ;
4040
4141 fn deref ( & self ) -> & Self :: Target {
42- & self . bound
42+ & self . scalar_fn
4343 }
4444}
4545
4646impl Expression {
47- /// Create a new expression node from a bound expression and its children.
47+ /// Create a new expression node from a scalar function and children.
4848 pub fn try_new (
49- bound : BoundExpression ,
49+ scalar_fn : ScalarFn ,
5050 children : impl Into < Arc < [ Expression ] > > ,
5151 ) -> VortexResult < Self > {
5252 let children: Arc < [ Expression ] > = children. into ( ) ;
5353
5454 vortex_ensure ! (
55- bound . signature( ) . arity( ) . matches( children. len( ) ) ,
55+ scalar_fn . signature( ) . arity( ) . matches( children. len( ) ) ,
5656 "Expression arity mismatch: expected {} children but got {}" ,
57- bound . signature( ) . arity( ) ,
57+ scalar_fn . signature( ) . arity( ) ,
5858 children. len( )
5959 ) ;
6060
61- Ok ( Self { bound, children } )
61+ Ok ( Self {
62+ scalar_fn,
63+ children,
64+ } )
6265 }
6366
6467 /// Returns true if this expression is of the given vtable type.
@@ -102,24 +105,18 @@ impl Expression {
102105
103106 /// Computes the return dtype of this expression given the input dtype.
104107 pub fn return_dtype ( & self , scope : & DType ) -> VortexResult < DType > {
105- if self . is :: < Root > ( ) {
106- return Ok ( scope. clone ( ) ) ;
107- }
108-
109- let dtypes: Vec < _ > = self
110- . children
111- . iter ( )
112- . map ( |c| c. return_dtype ( scope) )
113- . try_collect ( ) ?;
114- self . bound . return_dtype ( & dtypes)
108+ self . scalar_fn . return_dtype ( & ExpressionReturnDTypeCtx {
109+ expr : self ,
110+ scope_dtype : scope,
111+ } )
115112 }
116113
117114 /// Evaluates the expression in the given scope, returning an array.
118115 pub fn evaluate ( & self , scope : & ArrayRef ) -> VortexResult < ArrayRef > {
119116 if self . is :: < Root > ( ) {
120117 return Ok ( scope. clone ( ) ) ;
121118 }
122- self . bound . evaluate ( self , scope)
119+ self . scalar_fn . evaluate ( self , scope)
123120 }
124121
125122 /// An expression over zone-statistics which implies all records in the zone evaluate to false.
@@ -237,3 +234,29 @@ impl Display for Expression {
237234 self . fmt_sql ( f)
238235 }
239236}
237+
238+ pub ( super ) struct ExpressionReturnDTypeCtx < ' a > {
239+ pub ( super ) expr : & ' a Expression ,
240+ pub ( super ) scope_dtype : & ' a DType ,
241+ }
242+
243+ impl ReturnDTypeCtx for ExpressionReturnDTypeCtx < ' _ > {
244+ fn child_count ( & self ) -> usize {
245+ self . expr . children ( ) . len ( )
246+ }
247+
248+ fn return_dtype ( & self , child_idx : usize ) -> VortexResult < DType > {
249+ let child = & self . expr . children ( ) [ child_idx] ;
250+
251+ if child. is :: < Root > ( ) {
252+ return Ok ( self . scope_dtype . clone ( ) ) ;
253+ }
254+
255+ let ctx = ExpressionReturnDTypeCtx {
256+ expr : child,
257+ scope_dtype : self . scope_dtype ,
258+ } ;
259+ let child_fn: & ScalarFn = child. deref ( ) ;
260+ child_fn. return_dtype ( & ctx)
261+ }
262+ }
0 commit comments