33
44//! Host interface for dynamic CUDA kernel dispatch.
55//!
6- //! An [`UnmaterializedPlan`] walks an encoding tree (e.g., `ALP(FoR(BitPacked))`)
7- //! and flattens it into a linear sequence of stages. Call
8- //! [`materialize`](UnmaterializedPlan::materialize) to copy source buffers to
9- //! the device, producing a [`MaterializedPlan`] ready for kernel launch.
6+ //! [`DispatchPlan::new`] walks an encoding tree (e.g., `ALP(FoR(BitPacked))`)
7+ //! in a single pass and returns one of three variants:
108//!
11- //! For partially-fusable trees, [`find_unfusable_nodes`] identifies nodes
12- //! that need separate kernels, and [`UnmaterializedPlan::new_with_subtree_inputs`] builds a plan
13- //! that incorporates their pre-executed arrays.
14- //!
15- //! Shared memory is dynamically sized at launch time via
16- //! [`UnmaterializedPlan::shared_mem_bytes`].
9+ //! - [`Fused`](DispatchPlan::Fused) — call [`FusedPlan::materialize`].
10+ //! - [`PartiallyFused`](DispatchPlan::PartiallyFused) — execute pending
11+ //! subtrees, then call [`FusedPlan::materialize_with_subtrees`].
12+ //! - [`Unfused`](DispatchPlan::Unfused) — fall back to single-kernel dispatch.
1713
1814#![ allow( non_upper_case_globals) ]
1915#![ allow( non_camel_case_types) ]
@@ -47,9 +43,9 @@ use crate::CudaDeviceBuffer;
4743use crate :: executor:: CudaExecutionCtx ;
4844
4945pub ( crate ) mod plan_builder;
46+ pub use plan_builder:: DispatchPlan ;
47+ pub use plan_builder:: FusedPlan ;
5048pub use plan_builder:: MaterializedPlan ;
51- pub use plan_builder:: UnmaterializedPlan ;
52- pub use plan_builder:: find_unfusable_nodes;
5349
5450include ! ( concat!( env!( "OUT_DIR" ) , "/dynamic_dispatch.rs" ) ) ;
5551
@@ -449,17 +445,18 @@ mod tests {
449445 use vortex:: session:: VortexSession ;
450446
451447 use super :: CudaDispatchPlan ;
448+ use super :: DispatchPlan ;
452449 use super :: SMEM_TILE_SIZE ;
453450 use super :: ScalarOp ;
454451 use super :: SourceOp ;
455452 use super :: Stage ;
456- use super :: UnmaterializedPlan ;
453+ use super :: * ;
457454 use crate :: CudaBufferExt ;
458455 use crate :: CudaDeviceBuffer ;
459456 use crate :: CudaExecutionCtx ;
460457 use crate :: session:: CudaSession ;
461458
462- fn make_bitpacked_array_u32 ( bit_width : u8 , len : usize ) -> BitPackedArray {
459+ fn bitpacked_array_u32 ( bit_width : u8 , len : usize ) -> BitPackedArray {
463460 let max_val = ( 1u64 << bit_width) . saturating_sub ( 1 ) ;
464461 let values: Vec < u32 > = ( 0 ..len)
465462 . map ( |i| ( ( i as u64 ) % ( max_val + 1 ) ) as u32 )
@@ -469,6 +466,16 @@ mod tests {
469466 . vortex_expect ( "failed to create BitPacked array" )
470467 }
471468
469+ fn dispatch_plan (
470+ array : & vortex:: array:: ArrayRef ,
471+ ctx : & CudaExecutionCtx ,
472+ ) -> VortexResult < MaterializedPlan > {
473+ match DispatchPlan :: new ( array) ? {
474+ DispatchPlan :: Fused ( plan) => plan. materialize ( ctx) ,
475+ _ => vortex_bail ! ( "array encoding not fusable" ) ,
476+ }
477+ }
478+
472479 #[ crate :: test]
473480 fn test_max_scalar_ops ( ) -> VortexResult < ( ) > {
474481 let bit_width: u8 = 6 ;
@@ -481,7 +488,7 @@ mod tests {
481488 . map ( |i| ( ( i as u64 ) % ( max_val + 1 ) ) as u32 + total_reference)
482489 . collect ( ) ;
483490
484- let bitpacked = make_bitpacked_array_u32 ( bit_width, len) ;
491+ let bitpacked = bitpacked_array_u32 ( bit_width, len) ;
485492 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
486493 let packed = bitpacked. packed ( ) . clone ( ) ;
487494 let device_input = futures:: executor:: block_on ( cuda_ctx. ensure_on_device ( packed) ) ?;
@@ -669,9 +676,9 @@ mod tests {
669676 . map ( |i| ( ( i as u64 ) % ( max_val + 1 ) ) as u32 )
670677 . collect ( ) ;
671678
672- let bp = make_bitpacked_array_u32 ( bit_width, len) ;
679+ let bp = bitpacked_array_u32 ( bit_width, len) ;
673680 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
674- let plan = UnmaterializedPlan :: new ( & bp. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
681+ let plan = dispatch_plan ( & bp. into_array ( ) , & cuda_ctx) ?;
675682
676683 let actual =
677684 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -692,11 +699,11 @@ mod tests {
692699 . collect ( ) ;
693700 let expected: Vec < u32 > = raw. iter ( ) . map ( |& v| v + reference) . collect ( ) ;
694701
695- let bp = make_bitpacked_array_u32 ( bit_width, len) ;
702+ let bp = bitpacked_array_u32 ( bit_width, len) ;
696703 let for_arr = FoRArray :: try_new ( bp. into_array ( ) , Scalar :: from ( reference) ) ?;
697704
698705 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
699- let plan = UnmaterializedPlan :: new ( & for_arr. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
706+ let plan = dispatch_plan ( & for_arr. into_array ( ) , & cuda_ctx) ?;
700707
701708 let actual =
702709 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -722,7 +729,7 @@ mod tests {
722729 let re = RunEndArray :: new ( ends_arr, values_arr) ;
723730
724731 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
725- let plan = UnmaterializedPlan :: new ( & re. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
732+ let plan = dispatch_plan ( & re. into_array ( ) , & cuda_ctx) ?;
726733
727734 let actual =
728735 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -755,7 +762,7 @@ mod tests {
755762 let dict = DictArray :: try_new ( codes_bp. into_array ( ) , dict_for. into_array ( ) ) ?;
756763
757764 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
758- let plan = UnmaterializedPlan :: new ( & dict. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
765+ let plan = dispatch_plan ( & dict. into_array ( ) , & cuda_ctx) ?;
759766
760767 let actual =
761768 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -787,7 +794,7 @@ mod tests {
787794 ) ;
788795
789796 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
790- let plan = UnmaterializedPlan :: new ( & tree. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
797+ let plan = dispatch_plan ( & tree. into_array ( ) , & cuda_ctx) ?;
791798
792799 let actual =
793800 run_dispatch_plan_f32 ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -816,7 +823,7 @@ mod tests {
816823 let zz = ZigZagArray :: try_new ( bp. into_array ( ) ) ?;
817824
818825 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
819- let plan = UnmaterializedPlan :: new ( & zz. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
826+ let plan = dispatch_plan ( & zz. into_array ( ) , & cuda_ctx) ?;
820827
821828 let actual =
822829 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -845,7 +852,7 @@ mod tests {
845852 let for_arr = FoRArray :: try_new ( re. into_array ( ) , Scalar :: from ( reference) ) ?;
846853
847854 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
848- let plan = UnmaterializedPlan :: new ( & for_arr. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
855+ let plan = dispatch_plan ( & for_arr. into_array ( ) , & cuda_ctx) ?;
849856
850857 let actual =
851858 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -874,7 +881,7 @@ mod tests {
874881 let for_arr = FoRArray :: try_new ( dict. into_array ( ) , Scalar :: from ( reference) ) ?;
875882
876883 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
877- let plan = UnmaterializedPlan :: new ( & for_arr. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
884+ let plan = dispatch_plan ( & for_arr. into_array ( ) , & cuda_ctx) ?;
878885
879886 let actual =
880887 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -902,7 +909,7 @@ mod tests {
902909 let dict = DictArray :: try_new ( codes_for. into_array ( ) , values_prim. into_array ( ) ) ?;
903910
904911 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
905- let plan = UnmaterializedPlan :: new ( & dict. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
912+ let plan = dispatch_plan ( & dict. into_array ( ) , & cuda_ctx) ?;
906913
907914 let actual =
908915 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -927,7 +934,7 @@ mod tests {
927934 let dict = DictArray :: try_new ( codes_bp. into_array ( ) , values_prim. into_array ( ) ) ?;
928935
929936 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
930- let plan = UnmaterializedPlan :: new ( & dict. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
937+ let plan = dispatch_plan ( & dict. into_array ( ) , & cuda_ctx) ?;
931938
932939 let actual =
933940 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -946,8 +953,11 @@ mod tests {
946953 let values_prim = PrimitiveArray :: new ( Buffer :: from ( dict_values) , NonNullable ) ;
947954 let dict = DictArray :: try_new ( codes_prim. into_array ( ) , values_prim. into_array ( ) ) ?;
948955
949- // UnmaterializedPlan::new should fail because u8 codes != u32 values in byte width.
950- assert ! ( UnmaterializedPlan :: new( & dict. into_array( ) ) . is_err( ) ) ;
956+ // DispatchPlan::new should return Unfused because u8 codes != u32 values in byte width.
957+ assert ! ( matches!(
958+ DispatchPlan :: new( & dict. into_array( ) ) ?,
959+ DispatchPlan :: Unfused
960+ ) ) ;
951961
952962 Ok ( ( ) )
953963 }
@@ -961,8 +971,11 @@ mod tests {
961971 let values_arr = PrimitiveArray :: new ( Buffer :: from ( values) , NonNullable ) . into_array ( ) ;
962972 let re = RunEndArray :: new ( ends_arr, values_arr) ;
963973
964- // UnmaterializedPlan::new should fail because u64 ends != i32 values in byte width.
965- assert ! ( UnmaterializedPlan :: new( & re. into_array( ) ) . is_err( ) ) ;
974+ // DispatchPlan::new should return Unfused because u64 ends != i32 values in byte width.
975+ assert ! ( matches!(
976+ DispatchPlan :: new( & re. into_array( ) ) ?,
977+ DispatchPlan :: Unfused
978+ ) ) ;
966979
967980 Ok ( ( ) )
968981 }
@@ -997,7 +1010,7 @@ mod tests {
9971010 let expected: Vec < u32 > = data[ slice_start..slice_end] . to_vec ( ) ;
9981011
9991012 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1000- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1013+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
10011014
10021015 let actual = run_dynamic_dispatch_plan (
10031016 & cuda_ctx,
@@ -1048,7 +1061,7 @@ mod tests {
10481061 let expected: Vec < u32 > = all_decoded[ slice_start..slice_end] . to_vec ( ) ;
10491062
10501063 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1051- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1064+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
10521065
10531066 let actual = run_dynamic_dispatch_plan (
10541067 & cuda_ctx,
@@ -1098,7 +1111,7 @@ mod tests {
10981111 . collect ( ) ;
10991112
11001113 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1101- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1114+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
11021115
11031116 let actual = run_dynamic_dispatch_plan (
11041117 & cuda_ctx,
@@ -1143,7 +1156,7 @@ mod tests {
11431156 let expected: Vec < u32 > = data[ slice_start..slice_end] . to_vec ( ) ;
11441157
11451158 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1146- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1159+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
11471160
11481161 let actual = run_dynamic_dispatch_plan (
11491162 & cuda_ctx,
@@ -1192,7 +1205,7 @@ mod tests {
11921205 let expected: Vec < u32 > = all_decoded[ slice_start..slice_end] . to_vec ( ) ;
11931206
11941207 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1195- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1208+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
11961209
11971210 let actual = run_dynamic_dispatch_plan (
11981211 & cuda_ctx,
@@ -1244,7 +1257,7 @@ mod tests {
12441257 let expected: Vec < u32 > = all_decoded[ slice_start..slice_end] . to_vec ( ) ;
12451258
12461259 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1247- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1260+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
12481261
12491262 let actual = run_dynamic_dispatch_plan (
12501263 & cuda_ctx,
@@ -1301,7 +1314,7 @@ mod tests {
13011314 let expected: Vec < u32 > = all_decoded[ slice_start..slice_end] . to_vec ( ) ;
13021315
13031316 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1304- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1317+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
13051318
13061319 let actual = run_dynamic_dispatch_plan (
13071320 & cuda_ctx,
@@ -1333,7 +1346,7 @@ mod tests {
13331346 let seq = SequenceArray :: try_new_typed ( base, multiplier, Nullability :: NonNullable , len) ?;
13341347
13351348 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1336- let plan = UnmaterializedPlan :: new ( & seq. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
1349+ let plan = dispatch_plan ( & seq. into_array ( ) , & cuda_ctx) ?;
13371350
13381351 let actual = run_dynamic_dispatch_plan (
13391352 & cuda_ctx,
@@ -1366,7 +1379,7 @@ mod tests {
13661379 let seq = SequenceArray :: try_new_typed ( base, multiplier, Nullability :: NonNullable , len) ?;
13671380
13681381 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1369- let plan = UnmaterializedPlan :: new ( & seq. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
1382+ let plan = dispatch_plan ( & seq. into_array ( ) , & cuda_ctx) ?;
13701383
13711384 let actual_u32 = run_dynamic_dispatch_plan (
13721385 & cuda_ctx,
0 commit comments