@@ -6,13 +6,11 @@ use std::sync::LazyLock;
66
77use arcref:: ArcRef ;
88use vortex_dtype:: DType ;
9- use vortex_error:: { VortexResult , vortex_err, vortex_panic} ;
9+ use vortex_error:: { VortexError , VortexResult , vortex_bail , vortex_err, vortex_panic} ;
1010use vortex_scalar:: Scalar ;
1111
1212use crate :: Array ;
13- use crate :: compute:: {
14- ComputeFn , ComputeFnVTable , InvocationArgs , Kernel , Options , Output , UnaryArgs ,
15- } ;
13+ use crate :: compute:: { ComputeFn , ComputeFnVTable , InvocationArgs , Kernel , Options , Output } ;
1614use crate :: stats:: { Precision , Stat , StatsProvider } ;
1715use crate :: vtable:: VTable ;
1816
@@ -45,11 +43,11 @@ pub(crate) fn warm_up_vtable() -> usize {
4543/// If the sum is not supported for the array's dtype, an error will be raised.
4644/// If the array is all-invalid, the sum will be the initial_value.
4745/// The initial_value must have a dtype compatible with the sum result dtype.
48- pub ( crate ) fn sum_with_initial ( array : & dyn Array , initial_value : Scalar ) -> VortexResult < Scalar > {
46+ pub ( crate ) fn sum_with_initial ( array : & dyn Array , initial_value : & Scalar ) -> VortexResult < Scalar > {
4947 SUM_FN
5048 . invoke ( & InvocationArgs {
51- inputs : & [ array. into ( ) ] ,
52- options : & SumOptions { initial_value } ,
49+ inputs : & [ array. into ( ) , initial_value . into ( ) ] ,
50+ options : & ( ) ,
5351 } ) ?
5452 . unwrap_scalar ( )
5553}
@@ -64,7 +62,30 @@ pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
6462 . dtype ( array. dtype ( ) )
6563 . ok_or_else ( || vortex_err ! ( "Sum not supported for dtype: {}" , array. dtype( ) ) ) ?;
6664 let zero = Scalar :: zero_value ( sum_dtype) ;
67- sum_with_initial ( array, zero)
65+ sum_with_initial ( array, & zero)
66+ }
67+
68+ /// For unary compute functions, it's useful to just have this short-cut.
69+ pub struct SumArgs < ' a > {
70+ pub array : & ' a dyn Array ,
71+ pub accumulator : & ' a Scalar ,
72+ }
73+
74+ impl < ' a > TryFrom < & InvocationArgs < ' a > > for SumArgs < ' a > {
75+ type Error = VortexError ;
76+
77+ fn try_from ( value : & InvocationArgs < ' a > ) -> Result < Self , Self :: Error > {
78+ if value. inputs . len ( ) != 2 {
79+ vortex_bail ! ( "Expected 2 inputs, found {}" , value. inputs. len( ) ) ;
80+ }
81+ let array = value. inputs [ 0 ]
82+ . array ( )
83+ . ok_or_else ( || vortex_err ! ( "Expected input 0 to be an array" ) ) ?;
84+ let accumulator = value. inputs [ 1 ]
85+ . scalar ( )
86+ . ok_or_else ( || vortex_err ! ( "Expected input 1 to be a scalar" ) ) ?;
87+ Ok ( SumArgs { array, accumulator } )
88+ }
6889}
6990
7091struct Sum ;
@@ -75,8 +96,7 @@ impl ComputeFnVTable for Sum {
7596 args : & InvocationArgs ,
7697 kernels : & [ ArcRef < dyn Kernel > ] ,
7798 ) -> VortexResult < Output > {
78- let UnaryArgs { array, options } = UnaryArgs :: < SumOptions > :: try_from ( args) ?;
79- let initial_value = & options. initial_value ;
99+ let SumArgs { array, accumulator } = args. try_into ( ) ?;
80100
81101 // Compute the expected dtype of the sum.
82102 let sum_dtype = self . return_dtype ( args) ?;
@@ -86,7 +106,7 @@ impl ComputeFnVTable for Sum {
86106 return Ok ( sum. into ( ) ) ;
87107 }
88108
89- let sum_scalar = sum_impl ( array, sum_dtype, initial_value , kernels) ?;
109+ let sum_scalar = sum_impl ( array, sum_dtype, accumulator , kernels) ?;
90110
91111 // Update the statistics with the computed sum.
92112 array
@@ -97,7 +117,7 @@ impl ComputeFnVTable for Sum {
97117 }
98118
99119 fn return_dtype ( & self , args : & InvocationArgs ) -> VortexResult < DType > {
100- let UnaryArgs { array, .. } = UnaryArgs :: < SumOptions > :: try_from ( args) ?;
120+ let SumArgs { array, .. } = args. try_into ( ) ?;
101121 Stat :: Sum
102122 . dtype ( array. dtype ( ) )
103123 . ok_or_else ( || vortex_err ! ( "Sum not supported for dtype: {}" , array. dtype( ) ) )
@@ -136,11 +156,14 @@ impl<V: VTable + SumKernel> SumKernelAdapter<V> {
136156
137157impl < V : VTable + SumKernel > Kernel for SumKernelAdapter < V > {
138158 fn invoke ( & self , args : & InvocationArgs ) -> VortexResult < Option < Output > > {
139- let UnaryArgs { array, options } = UnaryArgs :: < SumOptions > :: try_from ( args) ?;
159+ let SumArgs {
160+ array,
161+ accumulator : initial_value,
162+ } = args. try_into ( ) ?;
140163 let Some ( array) = array. as_opt :: < V > ( ) else {
141164 return Ok ( None ) ;
142165 } ;
143- Ok ( Some ( V :: sum ( & self . 0 , array, & options . initial_value ) ?. into ( ) ) )
166+ Ok ( Some ( V :: sum ( & self . 0 , array, initial_value) ?. into ( ) ) )
144167 }
145168}
146169
@@ -161,10 +184,8 @@ pub fn sum_impl(
161184
162185 // Try to find a sum kernel
163186 let args = InvocationArgs {
164- inputs : & [ array. into ( ) ] ,
165- options : & SumOptions {
166- initial_value : initial_value. clone ( ) ,
167- } ,
187+ inputs : & [ array. into ( ) , initial_value. into ( ) ] ,
188+ options : & ( ) ,
168189 } ;
169190 for kernel in kernels {
170191 if let Some ( output) = kernel. invoke ( & args) ? {
@@ -184,7 +205,7 @@ pub fn sum_impl(
184205 array. encoding_id( )
185206 ) ;
186207 }
187- sum_with_initial ( array. to_canonical ( ) . as_ref ( ) , initial_value. clone ( ) )
208+ sum_with_initial ( array. to_canonical ( ) . as_ref ( ) , initial_value)
188209}
189210
190211#[ cfg( test) ]
0 commit comments