@@ -55,6 +55,8 @@ use crate::expr::stats::Precision;
5555use crate :: expr:: stats:: Stat ;
5656use crate :: expr:: stats:: StatsProviderExt ;
5757use crate :: hash;
58+ use crate :: kernel:: KernelRef ;
59+ use crate :: kernel:: ValidateKernel ;
5860use crate :: serde:: ArrayChildren ;
5961use crate :: stats:: StatsSetRef ;
6062use crate :: vtable:: ArrayId ;
@@ -194,7 +196,7 @@ pub trait Array:
194196 -> VortexResult < Option < Output > > ;
195197
196198 /// Invoke the batch execution function for the array to produce a canonical vector.
197- fn batch_execute ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < Vector > ;
199+ fn bind_kernel ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < KernelRef > ;
198200}
199201
200202impl Array for Arc < dyn Array > {
@@ -302,8 +304,8 @@ impl Array for Arc<dyn Array> {
302304 self . as_ref ( ) . invoke ( compute_fn, args)
303305 }
304306
305- fn batch_execute ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < Vector > {
306- self . as_ref ( ) . batch_execute ( ctx)
307+ fn bind_kernel ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < KernelRef > {
308+ self . as_ref ( ) . bind_kernel ( ctx)
307309 }
308310}
309311
@@ -377,7 +379,11 @@ impl dyn Array + '_ {
377379 pub fn execute ( & self , session : & VortexSession ) -> VortexResult < Vector > {
378380 let mut ctx = ExecutionCtx :: new ( session. clone ( ) ) ;
379381
380- let result = self . batch_execute ( & mut ctx) ?;
382+ // NOTE(ngates): in the future we can choose a different mode of execution, or run
383+ // optimization here, etc.
384+ let kernel = self . bind_kernel ( & mut ctx) ?;
385+ let result = kernel. execute ( ) ?;
386+
381387 vortex_ensure ! (
382388 result. len( ) == self . len( ) ,
383389 "Result length mismatch for {}" ,
@@ -698,18 +704,17 @@ impl<V: VTable> Array for ArrayAdapter<V> {
698704 <V :: ComputeVTable as ComputeVTable < V > >:: invoke ( & self . 0 , compute_fn, args)
699705 }
700706
701- fn batch_execute ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < Vector > {
702- let result = V :: batch_execute ( & self . 0 , ctx) ?;
703-
704- // This check is so cheap we always run it. Whereas DType checks we only do in debug builds.
705- vortex_ensure ! ( result. len( ) == self . len( ) , "Result length mismatch" ) ;
706- #[ cfg( debug_assertions) ]
707- vortex_ensure ! (
708- vortex_vector:: vector_matches_dtype( & result, self . dtype( ) ) ,
709- "Executed vector dtype mismatch" ,
710- ) ;
711-
712- Ok ( result)
707+ fn bind_kernel ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < KernelRef > {
708+ let kernel = V :: bind_kernel ( & self . 0 , ctx) ?;
709+ if cfg ! ( debug_assertions) {
710+ Ok ( Box :: new ( ValidateKernel :: new (
711+ kernel,
712+ self . dtype ( ) . clone ( ) ,
713+ self . len ( ) ,
714+ ) ) )
715+ } else {
716+ Ok ( kernel)
717+ }
713718 }
714719}
715720
0 commit comments