@@ -734,46 +734,59 @@ fn ktSingleThreaded(comptime Variant: type, view: *const MultiSliceView, total_l
734734 final_state .final (output );
735735}
736736
737- /// Context for a thread task that processes leaves using SIMD
738- fn LeafThreadContext (comptime Variant : type ) type {
/// Builds the value type that a leaf-batch task returns for `Variant`.
///
/// Fields of the returned struct:
/// - `batch_idx`: index of the batch the result belongs to.
/// - `cv_len`: number of valid bytes at the start of `cvs`.
/// - `cvs`: fixed-capacity byte storage sized for a full batch, i.e.
///   `bytes_per_batch / chunk_size` leaves times `Variant.cv_size` bytes.
fn BatchResult(comptime Variant: type) type {
    // Capacity in bytes: one `Variant.cv_size`-byte entry per leaf of a full batch.
    const capacity = (bytes_per_batch / chunk_size) * Variant.cv_size;

    return struct {
        batch_idx: usize,
        cv_len: usize,
        cvs: [capacity]u8,
    };
}
748+
749+ fn SelectLeafContext (comptime Variant : type ) type {
750+ const cv_size = Variant .cv_size ;
751+ const Result = BatchResult (Variant );
752+
739753 return struct {
740754 view : * const MultiSliceView ,
741- start_offset : usize , // Byte offset in view (after first chunk)
742- num_leaves : usize , // Number of leaves to process
743- output_cvs : [] align ( @alignOf ( u64 )) u8 , // Where to store CVs
755+ batch_idx : usize ,
756+ start_offset : usize ,
757+ num_leaves : usize ,
744758
745- fn process (ctx : @This ()) void {
746- const cv_size = Variant .cv_size ;
759+ fn process (ctx : @This ()) Result {
760+ var result : Result = .{
761+ .batch_idx = ctx .batch_idx ,
762+ .cv_len = ctx .num_leaves * cv_size ,
763+ .cvs = undefined ,
764+ };
747765
748- // Thread-local scratch buffer for copying data that spans slice boundaries
749766 var leaf_buffer : [bytes_per_batch ]u8 align (cache_line_size ) = undefined ;
750-
751767 var leaves_processed : usize = 0 ;
752768 var byte_offset = ctx .start_offset ;
753769 var cv_offset : usize = 0 ;
754-
755- // Process leaves in SIMD batches
756770 const simd_batch_bytes = optimal_vector_len * chunk_size ;
757771 while (leaves_processed + optimal_vector_len <= ctx .num_leaves ) {
758772 if (ctx .view .tryGetSlice (byte_offset , byte_offset + simd_batch_bytes )) | leaf_data | {
759773 var leaf_cvs : [optimal_vector_len * Variant .cv_size ]u8 = undefined ;
760774 processLeaves (Variant , optimal_vector_len , leaf_data , & leaf_cvs );
761- @memcpy (ctx . output_cvs [cv_offset .. ][0.. leaf_cvs .len ], & leaf_cvs );
775+ @memcpy (result . cvs [cv_offset .. ][0.. leaf_cvs .len ], & leaf_cvs );
762776 } else {
763777 ctx .view .copyRange (byte_offset , byte_offset + simd_batch_bytes , leaf_buffer [0.. simd_batch_bytes ]);
764778 var leaf_cvs : [optimal_vector_len * Variant .cv_size ]u8 = undefined ;
765779 processLeaves (Variant , optimal_vector_len , leaf_buffer [0.. simd_batch_bytes ], & leaf_cvs );
766- @memcpy (ctx . output_cvs [cv_offset .. ][0.. leaf_cvs .len ], & leaf_cvs );
780+ @memcpy (result . cvs [cv_offset .. ][0.. leaf_cvs .len ], & leaf_cvs );
767781 }
768782 leaves_processed += optimal_vector_len ;
769783 byte_offset += optimal_vector_len * chunk_size ;
770784 cv_offset += optimal_vector_len * cv_size ;
771785 }
772786
773- // Process remaining leaves one at a time (should be less than optimal_vector_len)
774787 while (leaves_processed < ctx .num_leaves ) {
775788 const leaf_end = byte_offset + chunk_size ;
776- var cv_buffer : [64 ]u8 = undefined ; // Max CV size is 64 bytes
789+ var cv_buffer : [64 ]u8 = undefined ;
777790
778791 if (ctx .view .tryGetSlice (byte_offset , leaf_end )) | leaf_data | {
779792 const cv_slice = MultiSliceView .init (leaf_data , &[_ ]u8 {}, &[_ ]u8 {});
@@ -783,22 +796,23 @@ fn LeafThreadContext(comptime Variant: type) type {
783796 const cv_slice = MultiSliceView .init (leaf_buffer [0.. chunk_size ], &[_ ]u8 {}, &[_ ]u8 {});
784797 Variant .turboShakeToBuffer (& cv_slice , 0x0B , cv_buffer [0.. cv_size ]);
785798 }
786- @memcpy (ctx . output_cvs [cv_offset .. ][0.. cv_size ], cv_buffer [0.. cv_size ]);
799+ @memcpy (result . cvs [cv_offset .. ][0.. cv_size ], cv_buffer [0.. cv_size ]);
787800
788801 leaves_processed += 1 ;
789802 byte_offset += chunk_size ;
790803 cv_offset += cv_size ;
791804 }
805+
806+ return result ;
792807 }
793808 };
794809}
795810
796- /// Context for the final partial leaf (may be smaller than chunk_size)
797811fn FinalLeafContext (comptime Variant : type ) type {
798812 return struct {
799813 view : * const MultiSliceView ,
800814 start_offset : usize ,
801- leaf_len : usize , // May be less than chunk_size
815+ leaf_len : usize ,
802816 output_cv : []align (@alignOf (u64 )) u8 ,
803817
804818 fn process (ctx : @This ()) void {
@@ -819,7 +833,6 @@ fn FinalLeafContext(comptime Variant: type) type {
819833 };
820834}
821835
822- /// Generic multi-threaded implementation with bounded heap allocation.
823836fn ktMultiThreaded (
824837 comptime Variant : type ,
825838 allocator : Allocator ,
@@ -853,40 +866,66 @@ fn ktMultiThreaded(
853866 const has_partial_leaf = (remaining_bytes % chunk_size ) != 0 ;
854867 const partial_leaf_size = if (has_partial_leaf ) remaining_bytes % chunk_size else 0 ;
855868
856- const max_concurrent_batches = 256 ;
857- const cvs_per_super_batch = max_concurrent_batches * leaves_per_batch * cv_size ;
858-
859- const cv_buf = try allocator .alignedAlloc (u8 , std .mem .Alignment .of (u64 ), cvs_per_super_batch );
860- defer allocator .free (cv_buf );
861-
862- var leaves_processed : usize = 0 ;
863- while (leaves_processed < full_leaves ) {
864- const leaves_in_super_batch = @min (max_concurrent_batches * leaves_per_batch , full_leaves - leaves_processed );
865- const num_batches = std .math .divCeil (usize , leaves_in_super_batch , leaves_per_batch ) catch unreachable ;
866-
867- var group : Io.Group = .init ;
869+ if (full_leaves > 0 ) {
870+ const total_batches = std .math .divCeil (usize , full_leaves , leaves_per_batch ) catch unreachable ;
871+ const max_concurrent : usize = @min (256 , total_batches );
872+
873+ const Result = BatchResult (Variant );
874+ const SelectResult = union (enum ) { batch : Result };
875+ const Select = Io .Select (SelectResult );
876+
877+ const select_buf = try allocator .alloc (SelectResult , max_concurrent );
878+ defer allocator .free (select_buf );
879+
880+ // Buffer for out-of-order results (select_buf slots get reused)
881+ const pending_cv_buf = try allocator .alloc ([leaves_per_batch * cv_size ]u8 , max_concurrent );
882+ defer allocator .free (pending_cv_buf );
883+ var pending_cv_lens : [256 ]usize = .{0 } ** 256 ;
884+
885+ var select : Select = .init (io , select_buf );
886+ var batches_spawned : usize = 0 ;
887+ var next_to_process : usize = 0 ;
888+
889+ while (next_to_process < total_batches ) {
890+ while (batches_spawned < total_batches and batches_spawned - next_to_process < max_concurrent ) {
891+ const batch_start_leaf = batches_spawned * leaves_per_batch ;
892+ const batch_leaves = @min (leaves_per_batch , full_leaves - batch_start_leaf );
893+ const start_offset = chunk_size + batch_start_leaf * chunk_size ;
894+
895+ select .async (.batch , SelectLeafContext (Variant ).process , .{SelectLeafContext (Variant ){
896+ .view = view ,
897+ .batch_idx = batches_spawned ,
898+ .start_offset = start_offset ,
899+ .num_leaves = batch_leaves ,
900+ }});
901+ batches_spawned += 1 ;
902+ }
868903
869- for (0 .. num_batches ) | batch_idx | {
870- const batch_start_leaf = leaves_processed + batch_idx * leaves_per_batch ;
871- const batch_leaves = @min ( leaves_per_batch , full_leaves - batch_start_leaf ) ;
904+ const result = select . wait () catch unreachable ;
905+ const batch = result . batch ;
906+ const slot = batch . batch_idx % max_concurrent ;
872907
873- if (batch_leaves == 0 ) break ;
908+ if (batch .batch_idx == next_to_process ) {
909+ final_state .update (batch .cvs [0.. batch .cv_len ]);
910+ next_to_process += 1 ;
874911
875- const start_offset = chunk_size + batch_start_leaf * chunk_size ;
876- const cv_start = batch_idx * leaves_per_batch * cv_size ;
912+ // Drain pending batches that are now ready
913+ while (next_to_process < total_batches ) {
914+ const pending_slot = next_to_process % max_concurrent ;
915+ const pending_len = pending_cv_lens [pending_slot ];
916+ if (pending_len == 0 ) break ;
877917
878- group .async (io , LeafThreadContext (Variant ).process , .{LeafThreadContext (Variant ){
879- .view = view ,
880- .start_offset = start_offset ,
881- .num_leaves = batch_leaves ,
882- .output_cvs = @alignCast (cv_buf [cv_start .. ][0 .. batch_leaves * cv_size ]),
883- }});
918+ final_state .update (pending_cv_buf [pending_slot ][0.. pending_len ]);
919+ pending_cv_lens [pending_slot ] = 0 ;
920+ next_to_process += 1 ;
921+ }
922+ } else {
923+ @memcpy (pending_cv_buf [slot ][0.. batch .cv_len ], batch .cvs [0.. batch .cv_len ]);
924+ pending_cv_lens [slot ] = batch .cv_len ;
925+ }
884926 }
885927
886- group .wait (io );
887-
888- final_state .update (cv_buf [0 .. leaves_in_super_batch * cv_size ]);
889- leaves_processed += leaves_in_super_batch ;
928+ select .group .wait (io );
890929 }
891930
892931 if (has_partial_leaf ) {
0 commit comments