Skip to content

Commit 903bca0

Browse files
committed
vector batch execute all the way down
Signed-off-by: Alexander Droste <[email protected]>
1 parent 4054cf8 commit 903bca0

File tree

5 files changed

+342
-190
lines changed

5 files changed

+342
-190
lines changed

encodings/alp/benches/alp_compress.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use vortex_alp::ALPFloat;
1111
use vortex_alp::ALPRDFloat;
1212
use vortex_alp::RDEncoder;
1313
use vortex_alp::alp_encode;
14-
use vortex_alp::decompress;
14+
use vortex_alp::decompress_into_array;
1515
use vortex_array::arrays::PrimitiveArray;
1616
use vortex_array::compute::warm_up_vtables;
1717
use vortex_array::validity::Validity;
@@ -97,7 +97,7 @@ fn decompress_alp<T: ALPFloat + NativePType>(bencher: Bencher, args: (usize, f64
9797
)
9898
.unwrap()
9999
})
100-
.bench_values(decompress);
100+
.bench_values(decompress_into_array);
101101
}
102102

103103
#[divan::bench(types = [f32, f64], args = [10_000, 100_000])]

encodings/alp/src/alp/array.rs

Lines changed: 234 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,12 @@ use vortex_error::VortexResult;
4242
use vortex_error::vortex_bail;
4343
use vortex_error::vortex_ensure;
4444
use vortex_vector::Vector;
45-
use vortex_vector::VectorMutOps;
4645

4746
use crate::ALPFloat;
4847
use crate::alp::Exponents;
4948
use crate::alp::alp_encode;
50-
use crate::alp::decompress::decompress;
51-
use crate::alp::decompress::decompress_to_pvector;
49+
use crate::alp::decompress::decompress_into_array;
50+
use crate::alp::decompress::decompress_into_vector;
5251
use crate::match_each_alp_float_ptype;
5352

5453
vtable!(ALP);
@@ -142,9 +141,28 @@ impl VTable for ALPVTable {
142141
)
143142
}
144143

145-
fn execute(array: &ALPArray, _ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
144+
fn execute(array: &ALPArray, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> {
145+
let encoded_vector = array.encoded().execute_batch(ctx)?;
146+
147+
let patches_vectors = if let Some(patches) = array.patches() {
148+
Some((
149+
patches.indices().execute_batch(ctx)?,
150+
patches.values().execute_batch(ctx)?,
151+
patches
152+
.chunk_offsets()
153+
.as_ref()
154+
.map(|co| co.execute_batch(ctx))
155+
.transpose()?,
156+
))
157+
} else {
158+
None
159+
};
160+
161+
let patches_offset = array.patches().map(|p| p.offset()).unwrap_or(0);
162+
let exponents = array.exponents();
163+
146164
match_each_alp_float_ptype!(array.dtype().as_ptype(), |T| {
147-
Ok(decompress_to_pvector::<T>(array.clone()).freeze().into())
165+
decompress_into_vector::<T>(encoded_vector, exponents, patches_vectors, patches_offset)
148166
})
149167
}
150168
}
@@ -404,7 +422,7 @@ impl BaseArrayVTable<ALPVTable> for ALPVTable {
404422

405423
impl CanonicalVTable<ALPVTable> for ALPVTable {
406424
fn canonicalize(array: &ALPArray) -> Canonical {
407-
Canonical::Primitive(decompress(array.clone()))
425+
Canonical::Primitive(decompress_into_array(array.clone()))
408426
}
409427
}
410428

@@ -432,3 +450,213 @@ impl VisitorVTable<ALPVTable> for ALPVTable {
432450
}
433451
}
434452
}
453+
454+
#[cfg(test)]
455+
mod tests {
456+
use std::f64::consts::PI;
457+
458+
use rstest::rstest;
459+
use vortex_array::arrays::PrimitiveArray;
460+
use vortex_array::vtable::ValidityHelper;
461+
use vortex_dtype::PTypeDowncast;
462+
use vortex_vector::VectorOps;
463+
464+
use super::*;
465+
466+
#[rstest]
467+
#[case(0)]
468+
#[case(1)]
469+
#[case(100)]
470+
#[case(1023)]
471+
#[case(1024)]
472+
#[case(1025)]
473+
#[case(2047)]
474+
#[case(2048)]
475+
#[case(2049)]
476+
fn test_execute_f32(#[case] size: usize) {
477+
let values = PrimitiveArray::from_iter((0..size).map(|i| i as f32));
478+
let encoded = alp_encode(&values, None).unwrap();
479+
480+
let result_vector = encoded.to_array().execute().unwrap();
481+
// Compare against the traditional array-based decompress path
482+
let expected = decompress_into_array(encoded);
483+
484+
assert_eq!(result_vector.len(), size);
485+
486+
let result_primitive = result_vector.into_primitive().into_f32();
487+
assert_eq!(result_primitive.as_ref(), expected.as_slice::<f32>());
488+
}
489+
490+
#[rstest]
491+
#[case(0)]
492+
#[case(1)]
493+
#[case(100)]
494+
#[case(1023)]
495+
#[case(1024)]
496+
#[case(1025)]
497+
#[case(2047)]
498+
#[case(2048)]
499+
#[case(2049)]
500+
fn test_execute_f64(#[case] size: usize) {
501+
let values = PrimitiveArray::from_iter((0..size).map(|i| i as f64));
502+
let encoded = alp_encode(&values, None).unwrap();
503+
504+
let result_vector = encoded.to_array().execute().unwrap();
505+
// Compare against the traditional array-based decompress path
506+
let expected = decompress_into_array(encoded);
507+
508+
assert_eq!(result_vector.len(), size);
509+
510+
let result_primitive = result_vector.into_primitive().into_f64();
511+
assert_eq!(result_primitive.as_ref(), expected.as_slice::<f64>());
512+
}
513+
514+
#[rstest]
515+
#[case(100)]
516+
#[case(1023)]
517+
#[case(1024)]
518+
#[case(1025)]
519+
#[case(2047)]
520+
#[case(2048)]
521+
#[case(2049)]
522+
fn test_execute_with_patches(#[case] size: usize) {
523+
let values: Vec<f64> = (0..size)
524+
.map(|i| match i % 4 {
525+
0..=2 => 1.0,
526+
_ => PI,
527+
})
528+
.collect();
529+
530+
let array = PrimitiveArray::from_iter(values);
531+
let encoded = alp_encode(&array, None).unwrap();
532+
assert!(encoded.patches().unwrap().array_len() > 0);
533+
534+
let result_vector = encoded.to_array().execute().unwrap();
535+
// Compare against the traditional array-based decompress path
536+
let expected = decompress_into_array(encoded);
537+
538+
assert_eq!(result_vector.len(), size);
539+
540+
let result_primitive = result_vector.into_primitive().into_f64();
541+
assert_eq!(result_primitive.as_ref(), expected.as_slice::<f64>());
542+
}
543+
544+
#[rstest]
545+
#[case(0)]
546+
#[case(1)]
547+
#[case(100)]
548+
#[case(1023)]
549+
#[case(1024)]
550+
#[case(1025)]
551+
#[case(2047)]
552+
#[case(2048)]
553+
#[case(2049)]
554+
fn test_execute_with_validity(#[case] size: usize) {
555+
let values: Vec<Option<f32>> = (0..size)
556+
.map(|i| if i % 2 == 1 { None } else { Some(1.0) })
557+
.collect();
558+
559+
let array = PrimitiveArray::from_option_iter(values);
560+
let encoded = alp_encode(&array, None).unwrap();
561+
562+
let result_vector = encoded.to_array().execute().unwrap();
563+
// Compare against the traditional array-based decompress path
564+
let expected = decompress_into_array(encoded);
565+
566+
assert_eq!(result_vector.len(), size);
567+
568+
let result_primitive = result_vector.into_primitive().into_f32();
569+
assert_eq!(result_primitive.as_ref(), expected.as_slice::<f32>());
570+
571+
// Test validity masks match
572+
for idx in 0..size {
573+
assert_eq!(
574+
result_primitive.validity().value(idx),
575+
expected.validity().is_valid(idx)
576+
);
577+
}
578+
}
579+
580+
#[rstest]
581+
#[case(100)]
582+
#[case(1023)]
583+
#[case(1024)]
584+
#[case(1025)]
585+
#[case(2047)]
586+
#[case(2048)]
587+
#[case(2049)]
588+
fn test_execute_with_patches_and_validity(#[case] size: usize) {
589+
let values: Vec<Option<f64>> = (0..size)
590+
.map(|idx| match idx % 3 {
591+
0 => Some(1.0),
592+
1 => None,
593+
_ => Some(PI),
594+
})
595+
.collect();
596+
597+
let array = PrimitiveArray::from_option_iter(values);
598+
let encoded = alp_encode(&array, None).unwrap();
599+
assert!(encoded.patches().unwrap().array_len() > 0);
600+
601+
let result_vector = encoded.to_array().execute().unwrap();
602+
// Compare against the traditional array-based decompress path
603+
let expected = decompress_into_array(encoded);
604+
605+
assert_eq!(result_vector.len(), size);
606+
607+
let result_primitive = result_vector.into_primitive().into_f64();
608+
assert_eq!(result_primitive.as_ref(), expected.as_slice::<f64>());
609+
610+
// Test validity masks match
611+
for idx in 0..size {
612+
assert_eq!(
613+
result_primitive.validity().value(idx),
614+
expected.validity().is_valid(idx)
615+
);
616+
}
617+
}
618+
619+
#[rstest]
620+
#[case(500, 100)]
621+
#[case(1000, 200)]
622+
#[case(2048, 512)]
623+
fn test_execute_sliced_vector(#[case] size: usize, #[case] slice_start: usize) {
624+
let values: Vec<Option<f64>> = (0..size)
625+
.map(|i| {
626+
if i % 5 == 0 {
627+
None
628+
} else if i % 4 == 3 {
629+
Some(PI)
630+
} else {
631+
Some(1.0)
632+
}
633+
})
634+
.collect();
635+
636+
let array = PrimitiveArray::from_option_iter(values.clone());
637+
let encoded = alp_encode(&array, None).unwrap();
638+
639+
let slice_end = size - slice_start;
640+
let slice_len = slice_end - slice_start;
641+
let sliced_encoded = encoded.slice(slice_start..slice_end);
642+
643+
let result_vector = sliced_encoded.execute().unwrap();
644+
let result_primitive = result_vector.into_primitive().into_f64();
645+
646+
for idx in 0..slice_len {
647+
let expected_value = values[slice_start + idx];
648+
649+
let result_valid = result_primitive.validity().value(idx);
650+
assert_eq!(
651+
result_valid,
652+
expected_value.is_some(),
653+
"Validity mismatch at idx={idx}",
654+
);
655+
656+
if let Some(expected_val) = expected_value {
657+
let result_val = result_primitive.as_ref()[idx];
658+
assert_eq!(result_val, expected_val, "Value mismatch at idx={idx}",);
659+
}
660+
}
661+
}
662+
}

encodings/alp/src/alp/compress.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ mod tests {
139139
use vortex_dtype::NativePType;
140140

141141
use super::*;
142-
use crate::decompress;
142+
use crate::decompress_into_array;
143143

144144
#[test]
145145
fn test_compress() {
@@ -150,7 +150,7 @@ mod tests {
150150
assert_arrays_eq!(encoded.encoded(), expected_encoded);
151151
assert_eq!(encoded.exponents(), Exponents { e: 9, f: 6 });
152152

153-
let decoded = decompress(encoded);
153+
let decoded = decompress_into_array(encoded);
154154
assert_arrays_eq!(decoded, array);
155155
}
156156

@@ -163,7 +163,7 @@ mod tests {
163163
assert_arrays_eq!(encoded.encoded(), expected_encoded);
164164
assert_eq!(encoded.exponents(), Exponents { e: 9, f: 6 });
165165

166-
let decoded = decompress(encoded);
166+
let decoded = decompress_into_array(encoded);
167167
let expected = PrimitiveArray::from_option_iter(vec![None, Some(1.234f32), None]);
168168
assert_arrays_eq!(decoded, expected);
169169
}
@@ -179,7 +179,7 @@ mod tests {
179179
assert_arrays_eq!(encoded.encoded(), expected_encoded);
180180
assert_eq!(encoded.exponents(), Exponents { e: 16, f: 13 });
181181

182-
let decoded = decompress(encoded);
182+
let decoded = decompress_into_array(encoded);
183183
let expected_decoded = PrimitiveArray::new(values, Validity::NonNullable);
184184
assert_arrays_eq!(decoded, expected_decoded);
185185
}
@@ -196,7 +196,7 @@ mod tests {
196196
assert_arrays_eq!(encoded.encoded(), expected_encoded);
197197
assert_eq!(encoded.exponents(), Exponents { e: 16, f: 13 });
198198

199-
let decoded = decompress(encoded);
199+
let decoded = decompress_into_array(encoded);
200200
assert_arrays_eq!(decoded, array);
201201
}
202202

@@ -217,7 +217,7 @@ mod tests {
217217

218218
assert_arrays_eq!(encoded, array);
219219

220-
let _decoded = decompress(encoded);
220+
let _decoded = decompress_into_array(encoded);
221221
}
222222

223223
#[test]
@@ -444,7 +444,7 @@ mod tests {
444444
let encoded = alp_encode(&array, None).unwrap();
445445

446446
assert!(encoded.patches().is_none());
447-
let decoded = decompress(encoded);
447+
let decoded = decompress_into_array(encoded);
448448
assert_eq!(array.as_slice::<f32>(), decoded.as_slice::<f32>());
449449
}
450450

@@ -455,7 +455,7 @@ mod tests {
455455
let encoded = alp_encode(&array, None).unwrap();
456456

457457
assert!(encoded.patches().is_none());
458-
let decoded = decompress(encoded);
458+
let decoded = decompress_into_array(encoded);
459459
assert_eq!(array.as_slice::<f64>(), decoded.as_slice::<f64>());
460460
}
461461

@@ -472,7 +472,7 @@ mod tests {
472472
let encoded = alp_encode(&array, None).unwrap();
473473

474474
assert!(encoded.patches().is_some());
475-
let decoded = decompress(encoded);
475+
let decoded = decompress_into_array(encoded);
476476
assert_eq!(values.as_slice(), decoded.as_slice::<f32>());
477477
}
478478

@@ -493,7 +493,7 @@ mod tests {
493493
let encoded = alp_encode(&array, None).unwrap();
494494

495495
assert!(encoded.patches().is_some());
496-
let decoded = decompress(encoded);
496+
let decoded = decompress_into_array(encoded);
497497

498498
for idx in 0..size {
499499
let decoded_val = decoded.as_slice::<f64>()[idx];
@@ -520,7 +520,7 @@ mod tests {
520520

521521
let array = PrimitiveArray::from_option_iter(values);
522522
let encoded = alp_encode(&array, None).unwrap();
523-
let decoded = decompress(encoded);
523+
let decoded = decompress_into_array(encoded);
524524

525525
assert_arrays_eq!(decoded, array);
526526
}
@@ -540,7 +540,7 @@ mod tests {
540540

541541
let array = PrimitiveArray::new(Buffer::from(values), validity);
542542
let encoded = alp_encode(&array, None).unwrap();
543-
let decoded = decompress(encoded);
543+
let decoded = decompress_into_array(encoded);
544544

545545
assert_arrays_eq!(decoded, array);
546546
}

0 commit comments

Comments
 (0)