@@ -155,8 +155,8 @@ mod tests {
155155 use vortex_array:: compute:: { Operator as ComputeOp , compare} ;
156156 use vortex_array:: expr:: session:: ExprSession ;
157157 use vortex_array:: expr:: transform:: ExprOptimizer ;
158- use vortex_array:: expr:: { gt, lit, root} ;
159- use vortex_array:: { Array , ArraySession , IntoArray , ToCanonical } ;
158+ use vortex_array:: expr:: { Binary , Literal , Root , gt, lit, root} ;
159+ use vortex_array:: { Array , ArraySession , IntoArray , ToCanonical , assert_arrays_eq } ;
160160
161161 use super :: * ;
162162 use crate :: alp_encode;
@@ -202,9 +202,18 @@ mod tests {
202202 let encoded_child = optimized_expr. child ( ) . to_primitive ( ) ;
203203 assert_eq ! ( encoded_child. as_slice:: <i32 >( ) , vec![ 1234 ; 100 ] ) ;
204204
205- // Verify the expression is comparing against encoded value (1.0 * 10^3 = 1000)
206- // The expression should be: $ > 1000 (since 1.0 encodes to 1000)
207- println ! ( "Optimized expression: {:?}" , optimized_expr. expr( ) ) ;
205+ // Verify the expression structure
206+ let binary_view = optimized_expr. expr ( ) . as_ :: < Binary > ( ) ;
207+ assert ! ( binary_view. lhs( ) . is:: <Root >( ) , "Left side should be root()" ) ;
208+ assert ! (
209+ binary_view. rhs( ) . is:: <Literal >( ) ,
210+ "Right side should be a literal"
211+ ) ;
212+ assert_eq ! (
213+ binary_view. operator( ) ,
214+ vortex_array:: expr:: Operator :: Gt ,
215+ "Operator should be Gt (1.0 encodes exactly to 1000, so the operator remains unchanged)"
216+ ) ;
208217
209218 // Verify correctness by comparing with the eager comparison kernel
210219 let expected = compare (
@@ -215,10 +224,8 @@ mod tests {
215224 . unwrap ( ) ;
216225 let actual = optimized. to_canonical ( ) . into_array ( ) ;
217226
218- assert_eq ! ( actual. len( ) , expected. len( ) ) ;
219- for i in 0 ..actual. len ( ) {
220- assert_eq ! ( actual. scalar_at( i) , expected. scalar_at( i) ) ;
221- }
227+ // Use assert_arrays_eq to validate the canonical form
228+ assert_arrays_eq ! ( actual. clone( ) , expected. clone( ) ) ;
222229
223230 // Result should be all true (1.234 > 1.0)
224231 for i in 0 ..actual. len ( ) {
@@ -240,6 +247,29 @@ mod tests {
240247 let optimizer = session. optimizer ( ExprOptimizer :: new ( & expr_session) ) ;
241248 let optimized = optimizer. optimize_array ( expr_array. into_array ( ) ) . unwrap ( ) ;
242249
250+ // Verify the pushdown happened: should be ExprArray wrapping encoded integers
251+ let optimized_expr = optimized. as_ :: < ExprVTable > ( ) ;
252+ assert ! (
253+ optimized_expr
254+ . child( )
255+ . is:: <vortex_array:: arrays:: PrimitiveVTable >( ) ,
256+ "Pushdown failed: child is not PrimitiveArray, it's {:?}" ,
257+ optimized_expr. child( ) . encoding( ) . id( )
258+ ) ;
259+
260+ // Verify the expression structure
261+ let binary_view = optimized_expr. expr ( ) . as_ :: < Binary > ( ) ;
262+ assert ! ( binary_view. lhs( ) . is:: <Root >( ) , "Left side should be root()" ) ;
263+ assert ! (
264+ binary_view. rhs( ) . is:: <Literal >( ) ,
265+ "Right side should be a literal"
266+ ) ;
267+ assert_eq ! (
268+ binary_view. operator( ) ,
269+ vortex_array:: expr:: Operator :: Eq ,
270+ "Operator should be Eq"
271+ ) ;
272+
243273 // Verify correctness matches the eager comparison
244274 let expected = compare (
245275 alp. as_ref ( ) ,
@@ -249,10 +279,8 @@ mod tests {
249279 . unwrap ( ) ;
250280 let actual = optimized. to_canonical ( ) . into_array ( ) ;
251281
252- assert_eq ! ( actual. len( ) , expected. len( ) ) ;
253- for i in 0 ..actual. len ( ) {
254- assert_eq ! ( actual. scalar_at( i) , expected. scalar_at( i) ) ;
255- }
282+ // Use assert_arrays_eq to validate the canonical form
283+ assert_arrays_eq ! ( actual. clone( ) , expected. clone( ) ) ;
256284 }
257285
258286 #[ test]
@@ -282,12 +310,24 @@ mod tests {
282310 optimized. encoding( ) . id( )
283311 ) ;
284312
285- // Should return constant false (value can't be equal to any encoded value)
313+ // Downcast to ConstantArray and verify structure
314+ let constant_array = optimized. as_ :: < vortex_array:: arrays:: ConstantVTable > ( ) ;
315+ assert_eq ! ( constant_array. len( ) , 100 ) ;
316+ let false_scalar: Scalar = false . into ( ) ;
317+ assert_eq ! ( * constant_array. scalar( ) , false_scalar) ;
318+
319+ // Verify correctness matches the eager comparison
320+ #[ allow( clippy:: excessive_precision) ]
321+ let expected = compare (
322+ alp. as_ref ( ) ,
323+ ConstantArray :: new ( 1.234444f32 , 100 ) . as_ref ( ) ,
324+ ComputeOp :: Eq ,
325+ )
326+ . unwrap ( ) ;
286327 let actual = optimized. to_canonical ( ) . into_array ( ) ;
287- assert_eq ! ( actual. len( ) , 100 ) ;
288- for i in 0 ..actual. len ( ) {
289- assert_eq ! ( actual. scalar_at( i) , false . into( ) ) ;
290- }
328+
329+ // Use assert_arrays_eq to validate the canonical form
330+ assert_arrays_eq ! ( actual. clone( ) , expected. clone( ) ) ;
291331 }
292332
293333 #[ test]
@@ -299,20 +339,38 @@ mod tests {
299339 assert ! ( alp. patches( ) . is_some( ) ) ;
300340
301341 let expr = gt ( root ( ) , lit ( 1.0f32 ) ) ;
302- let expr_array = ExprArray :: new_infer_dtype ( alp. into_array ( ) , expr. clone ( ) ) . unwrap ( ) ;
342+ let expr_array =
343+ ExprArray :: new_infer_dtype ( alp. clone ( ) . into_array ( ) , expr. clone ( ) ) . unwrap ( ) ;
303344
304345 let session = ArraySession :: default ( ) ;
305346 crate :: register_alp_rules ( & session) ;
306347 let expr_session = ExprSession :: default ( ) ;
307348 let optimizer = session. optimizer ( ExprOptimizer :: new ( & expr_session) ) ;
308349 let optimized = optimizer. optimize_array ( expr_array. into_array ( ) ) . unwrap ( ) ;
309350
310- // Optimization should not apply - expression should be unchanged
351+ // Optimization should not apply - child should still be ALPArray
311352 let optimized_expr = optimized. as_ :: < ExprVTable > ( ) ;
353+ assert ! (
354+ optimized_expr. child( ) . is:: <ALPVTable >( ) ,
355+ "When patches exist, pushdown should not apply - child should still be ALPArray"
356+ ) ;
312357 assert_eq ! ( optimized_expr. expr( ) , & expr) ;
358+
359+ // Verify correctness still holds even without pushdown
360+ let expected = compare (
361+ alp. as_ref ( ) ,
362+ ConstantArray :: new ( 1.0f32 , 5 ) . as_ref ( ) ,
363+ ComputeOp :: Gt ,
364+ )
365+ . unwrap ( ) ;
366+ let actual = optimized. to_canonical ( ) . into_array ( ) ;
367+
368+ // Use assert_arrays_eq to validate the canonical form
369+ assert_arrays_eq ! ( actual. clone( ) , expected. clone( ) ) ;
313370 }
314371
315372 #[ test]
373+ #[ allow( clippy:: use_debug) ]
316374 fn test_alp_pushdown_all_operators ( ) {
317375 let array = PrimitiveArray :: from_iter ( [ 0.0605f32 ; 10 ] ) ;
318376 let alp = alp_encode ( & array, None ) . unwrap ( ) ;
@@ -372,20 +430,43 @@ mod tests {
372430 opt_expr. expr( )
373431 ) ;
374432 pushdown_count += 1 ;
433+
434+ // Verify the expression structure for ExprArray optimizations
435+ let binary_view = opt_expr. expr ( ) . as_ :: < Binary > ( ) ;
436+ assert ! (
437+ binary_view. lhs( ) . is:: <Root >( ) ,
438+ "Left side should be root() for operator {:?}" ,
439+ compute_op
440+ ) ;
441+ assert ! (
442+ binary_view. rhs( ) . is:: <Literal >( ) ,
443+ "Right side should be a literal for operator {:?}" ,
444+ compute_op
445+ ) ;
375446 } else {
376447 println ! (
377448 "✗ Operator {:?}: Still ExprArray but child is {:?}" ,
378449 compute_op,
379450 opt_expr. child( ) . encoding( ) . id( )
380451 ) ;
381452 }
382- } else {
453+ } else if optimized . is :: < vortex_array :: arrays :: ConstantVTable > ( ) {
383454 println ! (
384455 "✓ Operator {:?}: Optimized to {:?} (constant result)" ,
385456 compute_op,
386457 optimized. encoding( ) . id( )
387458 ) ;
388459 pushdown_count += 1 ;
460+
461+ // Verify ConstantArray structure
462+ let constant_array = optimized. as_ :: < vortex_array:: arrays:: ConstantVTable > ( ) ;
463+ assert_eq ! ( constant_array. len( ) , 10 ) ;
464+ } else {
465+ println ! (
466+ "✗ Operator {:?}: Optimized to unexpected type {:?}" ,
467+ compute_op,
468+ optimized. encoding( ) . id( )
469+ ) ;
389470 }
390471
391472 // Verify correctness matches the eager comparison kernel
@@ -397,21 +478,8 @@ mod tests {
397478 . unwrap ( ) ;
398479 let actual = optimized. to_canonical ( ) . into_array ( ) ;
399480
400- assert_eq ! (
401- actual. len( ) ,
402- expected. len( ) ,
403- "Failed for operator {:?}" ,
404- compute_op
405- ) ;
406- for i in 0 ..actual. len ( ) {
407- assert_eq ! (
408- actual. scalar_at( i) ,
409- expected. scalar_at( i) ,
410- "Mismatch at index {} for operator {:?}" ,
411- i,
412- compute_op
413- ) ;
414- }
481+ // Use assert_arrays_eq to validate the canonical form
482+ assert_arrays_eq ! ( actual. clone( ) , expected. clone( ) ) ;
415483 }
416484
417485 // Verify that all operators were optimized
0 commit comments