Skip to content

Commit 30f0907

Browse files
committed
wip
Signed-off-by: Joe Isaacs <[email protected]>
1 parent f7025c4 commit 30f0907

File tree

1 file changed

+104
-36
lines changed

1 file changed

+104
-36
lines changed

encodings/alp/src/alp/compute/expr_pushdown.rs

Lines changed: 104 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -155,8 +155,8 @@ mod tests {
155155
use vortex_array::compute::{Operator as ComputeOp, compare};
156156
use vortex_array::expr::session::ExprSession;
157157
use vortex_array::expr::transform::ExprOptimizer;
158-
use vortex_array::expr::{gt, lit, root};
159-
use vortex_array::{Array, ArraySession, IntoArray, ToCanonical};
158+
use vortex_array::expr::{Binary, Literal, Root, gt, lit, root};
159+
use vortex_array::{Array, ArraySession, IntoArray, ToCanonical, assert_arrays_eq};
160160

161161
use super::*;
162162
use crate::alp_encode;
@@ -202,9 +202,18 @@ mod tests {
202202
let encoded_child = optimized_expr.child().to_primitive();
203203
assert_eq!(encoded_child.as_slice::<i32>(), vec![1234; 100]);
204204

205-
// Verify the expression is comparing against encoded value (1.0 * 10^3 = 1000)
206-
// The expression should be: $ > 1000 (since 1.0 encodes to 1000)
207-
println!("Optimized expression: {:?}", optimized_expr.expr());
205+
// Verify the expression structure
206+
let binary_view = optimized_expr.expr().as_::<Binary>();
207+
assert!(binary_view.lhs().is::<Root>(), "Left side should be root()");
208+
assert!(
209+
binary_view.rhs().is::<Literal>(),
210+
"Right side should be a literal"
211+
);
212+
assert_eq!(
213+
binary_view.operator(),
214+
vortex_array::expr::Operator::Gt,
215+
"Operator should be Gt (1.0 encodes exactly to 1000, so the operator remains unchanged)"
216+
);
208217

209218
// Verify correctness by comparing with the eager comparison kernel
210219
let expected = compare(
@@ -215,10 +224,8 @@ mod tests {
215224
.unwrap();
216225
let actual = optimized.to_canonical().into_array();
217226

218-
assert_eq!(actual.len(), expected.len());
219-
for i in 0..actual.len() {
220-
assert_eq!(actual.scalar_at(i), expected.scalar_at(i));
221-
}
227+
// Use assert_arrays_eq to validate the canonical form
228+
assert_arrays_eq!(actual.clone(), expected.clone());
222229

223230
// Result should be all true (1.234 > 1.0)
224231
for i in 0..actual.len() {
@@ -240,6 +247,29 @@ mod tests {
240247
let optimizer = session.optimizer(ExprOptimizer::new(&expr_session));
241248
let optimized = optimizer.optimize_array(expr_array.into_array()).unwrap();
242249

250+
// Verify the pushdown happened: should be ExprArray wrapping encoded integers
251+
let optimized_expr = optimized.as_::<ExprVTable>();
252+
assert!(
253+
optimized_expr
254+
.child()
255+
.is::<vortex_array::arrays::PrimitiveVTable>(),
256+
"Pushdown failed: child is not PrimitiveArray, it's {:?}",
257+
optimized_expr.child().encoding().id()
258+
);
259+
260+
// Verify the expression structure
261+
let binary_view = optimized_expr.expr().as_::<Binary>();
262+
assert!(binary_view.lhs().is::<Root>(), "Left side should be root()");
263+
assert!(
264+
binary_view.rhs().is::<Literal>(),
265+
"Right side should be a literal"
266+
);
267+
assert_eq!(
268+
binary_view.operator(),
269+
vortex_array::expr::Operator::Eq,
270+
"Operator should be Eq"
271+
);
272+
243273
// Verify correctness matches the eager comparison
244274
let expected = compare(
245275
alp.as_ref(),
@@ -249,10 +279,8 @@ mod tests {
249279
.unwrap();
250280
let actual = optimized.to_canonical().into_array();
251281

252-
assert_eq!(actual.len(), expected.len());
253-
for i in 0..actual.len() {
254-
assert_eq!(actual.scalar_at(i), expected.scalar_at(i));
255-
}
282+
// Use assert_arrays_eq to validate the canonical form
283+
assert_arrays_eq!(actual.clone(), expected.clone());
256284
}
257285

258286
#[test]
@@ -282,12 +310,24 @@ mod tests {
282310
optimized.encoding().id()
283311
);
284312

285-
// Should return constant false (value can't be equal to any encoded value)
313+
// Downcast to ConstantArray and verify structure
314+
let constant_array = optimized.as_::<vortex_array::arrays::ConstantVTable>();
315+
assert_eq!(constant_array.len(), 100);
316+
let false_scalar: Scalar = false.into();
317+
assert_eq!(*constant_array.scalar(), false_scalar);
318+
319+
// Verify correctness matches the eager comparison
320+
#[allow(clippy::excessive_precision)]
321+
let expected = compare(
322+
alp.as_ref(),
323+
ConstantArray::new(1.234444f32, 100).as_ref(),
324+
ComputeOp::Eq,
325+
)
326+
.unwrap();
286327
let actual = optimized.to_canonical().into_array();
287-
assert_eq!(actual.len(), 100);
288-
for i in 0..actual.len() {
289-
assert_eq!(actual.scalar_at(i), false.into());
290-
}
328+
329+
// Use assert_arrays_eq to validate the canonical form
330+
assert_arrays_eq!(actual.clone(), expected.clone());
291331
}
292332

293333
#[test]
@@ -299,20 +339,38 @@ mod tests {
299339
assert!(alp.patches().is_some());
300340

301341
let expr = gt(root(), lit(1.0f32));
302-
let expr_array = ExprArray::new_infer_dtype(alp.into_array(), expr.clone()).unwrap();
342+
let expr_array =
343+
ExprArray::new_infer_dtype(alp.clone().into_array(), expr.clone()).unwrap();
303344

304345
let session = ArraySession::default();
305346
crate::register_alp_rules(&session);
306347
let expr_session = ExprSession::default();
307348
let optimizer = session.optimizer(ExprOptimizer::new(&expr_session));
308349
let optimized = optimizer.optimize_array(expr_array.into_array()).unwrap();
309350

310-
// Optimization should not apply - expression should be unchanged
351+
// Optimization should not apply - child should still be ALPArray
311352
let optimized_expr = optimized.as_::<ExprVTable>();
353+
assert!(
354+
optimized_expr.child().is::<ALPVTable>(),
355+
"When patches exist, pushdown should not apply - child should still be ALPArray"
356+
);
312357
assert_eq!(optimized_expr.expr(), &expr);
358+
359+
// Verify correctness still holds even without pushdown
360+
let expected = compare(
361+
alp.as_ref(),
362+
ConstantArray::new(1.0f32, 5).as_ref(),
363+
ComputeOp::Gt,
364+
)
365+
.unwrap();
366+
let actual = optimized.to_canonical().into_array();
367+
368+
// Use assert_arrays_eq to validate the canonical form
369+
assert_arrays_eq!(actual.clone(), expected.clone());
313370
}
314371

315372
#[test]
373+
#[allow(clippy::use_debug)]
316374
fn test_alp_pushdown_all_operators() {
317375
let array = PrimitiveArray::from_iter([0.0605f32; 10]);
318376
let alp = alp_encode(&array, None).unwrap();
@@ -372,20 +430,43 @@ mod tests {
372430
opt_expr.expr()
373431
);
374432
pushdown_count += 1;
433+
434+
// Verify the expression structure for ExprArray optimizations
435+
let binary_view = opt_expr.expr().as_::<Binary>();
436+
assert!(
437+
binary_view.lhs().is::<Root>(),
438+
"Left side should be root() for operator {:?}",
439+
compute_op
440+
);
441+
assert!(
442+
binary_view.rhs().is::<Literal>(),
443+
"Right side should be a literal for operator {:?}",
444+
compute_op
445+
);
375446
} else {
376447
println!(
377448
"✗ Operator {:?}: Still ExprArray but child is {:?}",
378449
compute_op,
379450
opt_expr.child().encoding().id()
380451
);
381452
}
382-
} else {
453+
} else if optimized.is::<vortex_array::arrays::ConstantVTable>() {
383454
println!(
384455
"✓ Operator {:?}: Optimized to {:?} (constant result)",
385456
compute_op,
386457
optimized.encoding().id()
387458
);
388459
pushdown_count += 1;
460+
461+
// Verify ConstantArray structure
462+
let constant_array = optimized.as_::<vortex_array::arrays::ConstantVTable>();
463+
assert_eq!(constant_array.len(), 10);
464+
} else {
465+
println!(
466+
"✗ Operator {:?}: Optimized to unexpected type {:?}",
467+
compute_op,
468+
optimized.encoding().id()
469+
);
389470
}
390471

391472
// Verify correctness matches the eager comparison kernel
@@ -397,21 +478,8 @@ mod tests {
397478
.unwrap();
398479
let actual = optimized.to_canonical().into_array();
399480

400-
assert_eq!(
401-
actual.len(),
402-
expected.len(),
403-
"Failed for operator {:?}",
404-
compute_op
405-
);
406-
for i in 0..actual.len() {
407-
assert_eq!(
408-
actual.scalar_at(i),
409-
expected.scalar_at(i),
410-
"Mismatch at index {} for operator {:?}",
411-
i,
412-
compute_op
413-
);
414-
}
481+
// Use assert_arrays_eq to validate the canonical form
482+
assert_arrays_eq!(actual.clone(), expected.clone());
415483
}
416484

417485
// Verify that all operators were optimized

0 commit comments

Comments
 (0)