@@ -255,6 +255,13 @@ fn decompress_batch(
     unpack_10(bitpacked, bitpacked_output);
     for_decompress(cast_u32_as_i32(bitpacked_output), reference, for_decoded);
     alp_decompress(for_decoded, exponents, alp_decoded);
+
+    // Cast f32 output to u32 for filtering.
+    // SAFETY: f32 and u32 have the same size and alignment.
+    let alp_as_u32 = unsafe {
+        std::slice::from_raw_parts_mut(alp_decoded.as_mut_ptr() as *mut u32, alp_decoded.len())
+    };
+    let _kept = filter_scalar(alp_as_u32);
 }

 /// In-place batch decompression that reuses a single buffer for all stages.
@@ -280,6 +287,12 @@ fn decompress_in_place_batch(

     // Stage 3: ALP decode in-place (transmute i32 → f32).
     f32::decode_slice_inplace(buffer_i32, exponents);
+
+    // Cast f32 output to u32 for filtering.
+    // SAFETY: f32 and u32 have the same size and alignment.
+    let output_as_u32 =
+        unsafe { std::slice::from_raw_parts_mut(output.as_mut_ptr() as *mut u32, output.len()) };
+    let _kept = filter_scalar(output_as_u32);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -308,7 +321,7 @@ fn decompress_pipeline(
     let for_chunk = &mut for_buffer[..N];

     let mut input_offset = 0;
-    let mut output_offset = 0;
+    let mut output_write_offset = 0; // Track where to write filtered output.

     // Process each 1024-element chunk.
     while input_offset < bitpacked.len() {
@@ -328,18 +341,27 @@ fn decompress_pipeline(
             }
         }

-        // Stage 3: ALP decompression.
+        // Stage 3: ALP decompression directly into the output buffer.
+        // We decompress into the output buffer starting at output_write_offset.
         // SAFETY: Buffer sizes and output bounds are verified.
         unsafe {
-            let output_chunk = output.get_unchecked_mut(output_offset..output_offset + N);
+            let output_chunk =
+                output.get_unchecked_mut(output_write_offset..output_write_offset + N);
             for i in 0..N {
                 let for_decoded = *for_chunk.get_unchecked(i);
                 *output_chunk.get_unchecked_mut(i) = f32::decode_single(for_decoded, exponents);
             }
         }

+        // Stage 4: Filter the chunk in the output buffer.
+        // Note: filter_scalar modifies the data in-place, compacting it.
+        let output_chunk =
+            unsafe { output.get_unchecked_mut(output_write_offset..output_write_offset + N) };
+        let kept_count = filter_scalar(output_chunk);
+
+        // The filtered data is now compacted at output_write_offset.
+        output_write_offset += kept_count;
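+        // The next chunk is decoded starting at output_write_offset, so already-compacted
+        // results are never overwritten, and kept_count <= N keeps every write within `output`.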
         input_offset += S;
-        output_offset += N;
     }
 }

@@ -368,7 +390,7 @@ fn decompress_pipeline_extra_copy(
     let alp_chunk = &mut alp_buffer[..N];

     let mut input_offset = 0;
-    let mut output_offset = 0;
+    let mut output_write_offset = 0; // Track where to write filtered output.

     // Process each 1024-element chunk.
     while input_offset < bitpacked.len() {
@@ -397,13 +419,18 @@ fn decompress_pipeline_extra_copy(
             }
         }

-        // Stage 4: Copy from intermediate ALP buffer to final output.
-        // SAFETY: Buffer sizes are verified to be N.
-        let output_chunk = unsafe { output.get_unchecked_mut(output_offset..output_offset + N) };
-        output_chunk.copy_from_slice(alp_chunk);
+        // Stage 4: Filter the intermediate ALP buffer.
+        let kept_count = filter_scalar(alp_chunk);
+
+        // Stage 5: Copy filtered data from intermediate ALP buffer to final output.
+        // SAFETY: Buffer sizes are verified and kept_count <= N.
+        let output_chunk = unsafe {
+            output.get_unchecked_mut(output_write_offset..output_write_offset + kept_count)
+        };
+        output_chunk.copy_from_slice(&alp_chunk[..kept_count]);
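+        // Only the kept_count surviving elements are copied out, so the cost of the extra copy
+        // shrinks with the filter's selectivity.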

+        output_write_offset += kept_count;
         input_offset += S;
-        output_offset += N;
     }
 }

@@ -420,12 +447,13 @@ fn decompress_in_place_pipeline(
     debug_assert_eq!(output.len(), bitpacked.len() * T / W);

     let mut input_offset = 0;
-    let mut output_offset = 0;
+    let mut output_write_offset = 0; // Track where to write filtered output.

     while input_offset < bitpacked.len() {
         // Get the current chunk of the output buffer to work on.
         // SAFETY: Output bounds are verified by debug_assert.
-        let output_chunk = unsafe { output.get_unchecked_mut(output_offset..output_offset + N) };
+        let output_chunk =
+            unsafe { output.get_unchecked_mut(output_write_offset..output_write_offset + N) };

         // Reinterpret the output chunk as u32 for unpacking.
         // SAFETY: f32 and u32 have the same size and alignment.
@@ -457,11 +485,52 @@ fn decompress_in_place_pipeline(
             }
         }

+        // Stage 4: Filter the chunk in-place.
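+        // filter_scalar is generic over T: Copy, so the f32 chunk can be filtered directly,
+        // without the u32 cast used by the batch paths.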
+        let kept_count = filter_scalar(output_chunk);
+
+        output_write_offset += kept_count;
         input_offset += S;
-        output_offset += N;
     }
 }

+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Filter Functions
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Hardcoded mask for now.
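+// A set bit keeps the corresponding element. 0xDEADBEEF fits in 32 bits, so on 64-bit targets at
+// most 32 of every 64 elements survive; black_box keeps the compiler from folding the mask away.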
+
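+/// Compact `data` in place, keeping only the elements selected by the hardcoded mask, and return
+/// the number of elements kept. `data.len()` must be a multiple of `usize::BITS`.
+///
+/// For example, with a mask word of 0b0110_1100 the loop copies elements 2..4 and then elements
+/// 5..7 of that word-sized block, so four elements survive.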
+fn filter_scalar<T: Copy>(data: &mut [T]) -> usize {
+    let len = data.len();
+    assert!(len.is_multiple_of(usize::BITS as usize));
+
+    let iters = len / usize::BITS as usize;
+
+    let mut read_ptr = data.as_ptr();
+    let mut write_ptr = data.as_mut_ptr();
+    let initial_write_ptr = write_ptr;
+
+    for _ in 0..iters {
+        let mut word: usize = std::hint::black_box(0xDEADBEEF);
+
+        while word != 0 {
+            // The next run of set bits starts at `bit_pos` and is `span` bits long.
+            let bit_pos = word.trailing_zeros();
+            let span = (word >> bit_pos).trailing_ones();
+
+            // Copy the whole run of kept elements at once. `ptr::copy` tolerates the overlap
+            // between the trailing write region and the read region.
+            unsafe {
+                std::ptr::copy(read_ptr.add(bit_pos as usize), write_ptr, span as usize);
+                write_ptr = write_ptr.add(span as usize);
+            }
+
+            // Clear the run we just consumed; `checked_shr` guards against a shift by
+            // `usize::BITS` when the run reaches the top bit.
+            word = match word.checked_shr(bit_pos + span) {
+                Some(rest) => rest << (bit_pos + span),
+                None => 0,
+            };
+        }
+
+        unsafe { read_ptr = read_ptr.add(usize::BITS as usize) };
+    }
+
+    // Return the number of elements kept.
+    unsafe { write_ptr.offset_from(initial_write_ptr) as usize }
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // Bitpacking Functions
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -650,22 +719,46 @@ fn verify(
 /// Compare outputs from different decompression functions.
 ///
 /// Ensures that all decompression strategies produce identical results.
-fn compare_outputs(function_name: &str, expected: &[f32], actual: &[f32]) {
-    assert_eq!(
-        expected.len(),
-        actual.len(),
-        "{}: Output length mismatch: expected={}, actual={}",
-        function_name,
-        expected.len(),
-        actual.len()
-    );
-
-    for i in 0..expected.len() {
-        assert_eq!(
-            expected[i], actual[i],
-            "{}: Output mismatch at index {}: expected={}, actual={}",
-            function_name, i, expected[i], actual[i]
-        );
+/// Filtering should produce the same results whether applied chunk-by-chunk
+/// or all at once. Both expected and actual should already be filtered.
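+/// Only the first `expected_len` elements are compared; past that point both buffers hold
+/// leftover, unfiltered data.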
+fn compare_outputs(function_name: &str, expected: &[f32], actual: &[f32], expected_len: usize) {
+    // Both buffers should have the same allocated size.
+    assert_eq!(actual.len(), expected.len());
+
+    // Only compare the filtered portion of the data.
+    let expected_slice = &expected[..expected_len];
+    let actual_slice = &actual[..expected_len];
+
+    for i in 0..expected_len {
+        if expected_slice[i] != actual_slice[i] {
+            // Debug output to understand the mismatch.
+            eprintln!(
+                "Mismatch at index {}: expected={}, actual={}",
+                i, expected_slice[i], actual_slice[i]
+            );
+            if i > 0 {
+                eprintln!(
+                    " Previous values: expected[{}]={}, actual[{}]={}",
+                    i - 1,
+                    expected_slice[i - 1],
+                    i - 1,
+                    actual_slice[i - 1]
+                );
+            }
+            if i + 1 < expected_len {
+                eprintln!(
+                    " Next values: expected[{}]={}, actual[{}]={}",
+                    i + 1,
+                    expected_slice[i + 1],
+                    i + 1,
+                    actual_slice[i + 1]
+                );
+            }
+            panic!(
+                "{}: Output mismatch at index {}: expected={}, actual={}",
+                function_name, i, expected_slice[i], actual_slice[i]
+            );
+        }
     }
 }

@@ -767,7 +860,17 @@ mod correctness_verification {
     #[divan::bench(consts = VERIFICATION_SIZES)]
     fn verify_all_methods<const SIZE: usize>(bencher: Bencher) {
         bencher.bench_local(|| {
-            let (input_data, mut buffers) = setup(SIZE);
+            let (mut input_data, mut buffers) = setup(SIZE);
+
+            // Create a filtered version of the original values for comparison.
+            // SAFETY: f32 and u32 have the same size and alignment.
+            let original_as_u32 = unsafe {
+                std::slice::from_raw_parts_mut(
+                    input_data.original.as_mut_ptr() as *mut u32,
+                    input_data.original.len(),
+                )
+            };
+            let expected_filtered_len = filter_scalar(original_as_u32);
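+            // This compacts input_data.original in place; only its first expected_filtered_len
+            // values are meaningful from here on.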

             // Run batch decompression (our reference implementation).
             decompress_batch(
@@ -780,12 +883,13 @@ mod correctness_verification {
             );

             // Verify batch decompression is correct.
+            // Note: for_decoded is not filtered, but alp_decoded is filtered.
             verify(
                 "batch",
                 &buffers.for_decoded,
                 &buffers.alp_decoded,
                 &input_data.alp_encoded,
-                &input_data.original,
+                &input_data.original, // This is now filtered.
                 &input_data.patches,
             );

@@ -798,7 +902,12 @@ mod correctness_verification {
                 &mut buffers.for_decoded,
                 &mut buffers.pipeline_output,
             );
-            compare_outputs("pipeline", &buffers.alp_decoded, &buffers.pipeline_output);
+            compare_outputs(
+                "pipeline",
+                &buffers.alp_decoded,
+                &buffers.pipeline_output,
+                expected_filtered_len,
+            );

             // Run in-place batch decompression and compare with batch.
             decompress_in_place_batch(
@@ -811,6 +920,7 @@ mod correctness_verification {
                 "in_place_batch",
                 &buffers.alp_decoded,
                 &buffers.alp_decoded_inplace_batch,
+                expected_filtered_len,
             );

             // Run in-place pipeline decompression and compare with batch.
@@ -824,6 +934,7 @@ mod correctness_verification {
                 "in_place_pipeline",
                 &buffers.alp_decoded,
                 &buffers.alp_decoded_inplace_pipeline,
+                expected_filtered_len,
             );
         });
     }