@@ -108,7 +108,8 @@ register_kernel!(SumKernelAdapter(ConstantVTable).lift());
108108
109109#[ cfg( test) ]
110110mod tests {
111- use vortex_dtype:: { DType , DecimalDType , Nullability , PType } ;
111+ use vortex_dtype:: Nullability :: Nullable ;
112+ use vortex_dtype:: { DType , DecimalDType , Nullability , PType , i256} ;
112113 use vortex_scalar:: { DecimalValue , Scalar } ;
113114
114115 use crate :: arrays:: ConstantArray ;
@@ -132,13 +133,10 @@ mod tests {
132133
133134 #[ test]
134135 fn test_sum_nullable_value ( ) {
135- let array = ConstantArray :: new (
136- Scalar :: null ( DType :: Primitive ( PType :: U32 , Nullability :: Nullable ) ) ,
137- 10 ,
138- )
139- . into_array ( ) ;
136+ let array = ConstantArray :: new ( Scalar :: null ( DType :: Primitive ( PType :: U32 , Nullable ) ) , 10 )
137+ . into_array ( ) ;
140138 let result = sum ( & array) . unwrap ( ) ;
141- assert ! ( result. is_null ( ) ) ;
139+ assert_eq ! ( result, Scalar :: primitive ( 0u64 , Nullable ) ) ;
142140 }
143141
144142 #[ test]
@@ -157,10 +155,9 @@ mod tests {
157155
158156 #[ test]
159157 fn test_sum_bool_null ( ) {
160- let array =
161- ConstantArray :: new ( Scalar :: null ( DType :: Bool ( Nullability :: Nullable ) ) , 10 ) . into_array ( ) ;
158+ let array = ConstantArray :: new ( Scalar :: null ( DType :: Bool ( Nullable ) ) , 10 ) . into_array ( ) ;
162159 let result = sum ( & array) . unwrap ( ) ;
163- assert ! ( result. is_null ( ) ) ;
160+ assert_eq ! ( result, Scalar :: primitive ( 0u64 , Nullable ) ) ;
164161 }
165162
166163 #[ test]
@@ -180,22 +177,26 @@ mod tests {
180177
181178 assert_eq ! (
182179 result. as_decimal( ) . decimal_value( ) ,
183- Some ( DecimalValue :: I256 ( vortex_scalar :: i256:: from_i128( 500 ) ) )
180+ Some ( DecimalValue :: I256 ( i256:: from_i128( 500 ) ) )
184181 ) ;
185182 assert_eq ! ( result. dtype( ) , & Stat :: Sum . dtype( array. dtype( ) ) . unwrap( ) ) ;
186183 }
187184
188185 #[ test]
189186 fn test_sum_decimal_null ( ) {
190187 let decimal_dtype = DecimalDType :: new ( 10 , 2 ) ;
191- let array = ConstantArray :: new (
192- Scalar :: null ( DType :: Decimal ( decimal_dtype, Nullability :: Nullable ) ) ,
193- 10 ,
194- )
195- . into_array ( ) ;
188+ let array = ConstantArray :: new ( Scalar :: null ( DType :: Decimal ( decimal_dtype, Nullable ) ) , 10 )
189+ . into_array ( ) ;
196190
197191 let result = sum ( & array) . unwrap ( ) ;
198- assert ! ( result. is_null( ) ) ;
192+ assert_eq ! (
193+ result,
194+ Scalar :: decimal(
195+ DecimalValue :: I256 ( i256:: ZERO ) ,
196+ DecimalDType :: new( 20 , 2 ) ,
197+ Nullable
198+ )
199+ ) ;
199200 }
200201
201202 #[ test]
@@ -214,9 +215,7 @@ mod tests {
214215 let result = sum ( & array) . unwrap ( ) ;
215216 assert_eq ! (
216217 result. as_decimal( ) . decimal_value( ) ,
217- Some ( DecimalValue :: I256 ( vortex_scalar:: i256:: from_i128(
218- 99_999_999_900
219- ) ) )
218+ Some ( DecimalValue :: I256 ( i256:: from_i128( 99_999_999_900 ) ) )
220219 ) ;
221220 }
222221}
0 commit comments