Skip to content

Commit 63c9bd2

Browse files
committed
refactor(cuda): single-pass DispatchPlan builder
Refactor the dynamic dispatch plan builder to walk the encoding tree exactly once, discovering unfusable subtrees and computing shared memory requirements in the same pass. The result is a 3-variant enum (`Fused`, `PartiallyFused`, `Unfusable`) that replaces the previous `Result<Option<>>` API and eliminates the separate `find_unfusable_nodes` traversal. Shared memory is now validated upfront in `DispatchPlan::new` — before any subtree kernels are executed — so we never pay GPU cost for a plan that will not fit. The plan stages are split into `smem_stages` (fully decoded into persistent shared memory) and `output_stage` (tiled through a scratch region), making the two-phase kernel execution model explicit in the host-side data structures. Shared memory allocation invariants are documented on `FusedPlan`. Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 5369859 commit 63c9bd2

File tree

4 files changed

+337
-264
lines changed

4 files changed

+337
-264
lines changed

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ use vortex_cuda::CudaDeviceBuffer;
4040
use vortex_cuda::CudaExecutionCtx;
4141
use vortex_cuda::CudaSession;
4242
use vortex_cuda::dynamic_dispatch::CudaDispatchPlan;
43+
use vortex_cuda::dynamic_dispatch::DispatchPlan;
4344
use vortex_cuda::dynamic_dispatch::MaterializedPlan;
44-
use vortex_cuda::dynamic_dispatch::UnmaterializedPlan;
4545
use vortex_cuda_macros::cuda_available;
4646
use vortex_cuda_macros::cuda_not_available;
4747

@@ -123,13 +123,15 @@ struct BenchRunner {
123123

124124
impl BenchRunner {
125125
fn new(array: &vortex::array::ArrayRef, len: usize, cuda_ctx: &CudaExecutionCtx) -> Self {
126+
let plan = match DispatchPlan::new(array).vortex_expect("build_dyn_dispatch_plan") {
127+
DispatchPlan::Fused(plan) => plan,
128+
_ => panic!("encoding not fusable"),
129+
};
126130
let MaterializedPlan {
127131
dispatch_plan,
128132
device_buffers,
129133
shared_mem_bytes,
130-
} = UnmaterializedPlan::new(array)
131-
.and_then(|p| p.materialize(cuda_ctx))
132-
.vortex_expect("build_dyn_dispatch_plan");
134+
} = plan.materialize(cuda_ctx).vortex_expect("materialize plan");
133135

134136
let device_plan = Arc::new(
135137
cuda_ctx

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,13 @@
33

44
//! Host interface for dynamic CUDA kernel dispatch.
55
//!
6-
//! An [`UnmaterializedPlan`] walks an encoding tree (e.g., `ALP(FoR(BitPacked))`)
7-
//! and flattens it into a linear sequence of stages. Call
8-
//! [`materialize`](UnmaterializedPlan::materialize) to copy source buffers to
9-
//! the device, producing a [`MaterializedPlan`] ready for kernel launch.
6+
//! [`DispatchPlan::new`] walks an encoding tree (e.g., `ALP(FoR(BitPacked))`)
7+
//! in a single pass and returns one of three variants:
108
//!
11-
//! For partially-fusable trees, [`find_unfusable_nodes`] identifies nodes
12-
//! that need separate kernels, and [`UnmaterializedPlan::new_with_subtree_inputs`] builds a plan
13-
//! that incorporates their pre-executed arrays.
14-
//!
15-
//! Shared memory is dynamically sized at launch time via
16-
//! [`UnmaterializedPlan::shared_mem_bytes`].
9+
//! - [`Fused`](DispatchPlan::Fused) — call [`FusedPlan::materialize`].
10+
//! - [`PartiallyFused`](DispatchPlan::PartiallyFused) — execute pending
11+
//! subtrees, then call [`FusedPlan::materialize_with_subtrees`].
12+
//! - [`Unfused`](DispatchPlan::Unfused) — fall back to single-kernel dispatch.
1713
1814
#![allow(non_upper_case_globals)]
1915
#![allow(non_camel_case_types)]
@@ -47,9 +43,9 @@ use crate::CudaDeviceBuffer;
4743
use crate::executor::CudaExecutionCtx;
4844

4945
pub(crate) mod plan_builder;
46+
pub use plan_builder::DispatchPlan;
47+
pub use plan_builder::FusedPlan;
5048
pub use plan_builder::MaterializedPlan;
51-
pub use plan_builder::UnmaterializedPlan;
52-
pub use plan_builder::find_unfusable_nodes;
5349

5450
include!(concat!(env!("OUT_DIR"), "/dynamic_dispatch.rs"));
5551

@@ -449,17 +445,18 @@ mod tests {
449445
use vortex::session::VortexSession;
450446

451447
use super::CudaDispatchPlan;
448+
use super::DispatchPlan;
452449
use super::SMEM_TILE_SIZE;
453450
use super::ScalarOp;
454451
use super::SourceOp;
455452
use super::Stage;
456-
use super::UnmaterializedPlan;
453+
use super::*;
457454
use crate::CudaBufferExt;
458455
use crate::CudaDeviceBuffer;
459456
use crate::CudaExecutionCtx;
460457
use crate::session::CudaSession;
461458

462-
fn make_bitpacked_array_u32(bit_width: u8, len: usize) -> BitPackedArray {
459+
fn bitpacked_array_u32(bit_width: u8, len: usize) -> BitPackedArray {
463460
let max_val = (1u64 << bit_width).saturating_sub(1);
464461
let values: Vec<u32> = (0..len)
465462
.map(|i| ((i as u64) % (max_val + 1)) as u32)
@@ -469,6 +466,16 @@ mod tests {
469466
.vortex_expect("failed to create BitPacked array")
470467
}
471468

469+
fn dispatch_plan(
470+
array: &vortex::array::ArrayRef,
471+
ctx: &CudaExecutionCtx,
472+
) -> VortexResult<MaterializedPlan> {
473+
match DispatchPlan::new(array)? {
474+
DispatchPlan::Fused(plan) => plan.materialize(ctx),
475+
_ => vortex_bail!("array encoding not fusable"),
476+
}
477+
}
478+
472479
#[crate::test]
473480
fn test_max_scalar_ops() -> VortexResult<()> {
474481
let bit_width: u8 = 6;
@@ -481,7 +488,7 @@ mod tests {
481488
.map(|i| ((i as u64) % (max_val + 1)) as u32 + total_reference)
482489
.collect();
483490

484-
let bitpacked = make_bitpacked_array_u32(bit_width, len);
491+
let bitpacked = bitpacked_array_u32(bit_width, len);
485492
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
486493
let packed = bitpacked.packed().clone();
487494
let device_input = futures::executor::block_on(cuda_ctx.ensure_on_device(packed))?;
@@ -669,9 +676,9 @@ mod tests {
669676
.map(|i| ((i as u64) % (max_val + 1)) as u32)
670677
.collect();
671678

672-
let bp = make_bitpacked_array_u32(bit_width, len);
679+
let bp = bitpacked_array_u32(bit_width, len);
673680
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
674-
let plan = UnmaterializedPlan::new(&bp.into_array())?.materialize(&cuda_ctx)?;
681+
let plan = dispatch_plan(&bp.into_array(), &cuda_ctx)?;
675682

676683
let actual =
677684
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -692,11 +699,11 @@ mod tests {
692699
.collect();
693700
let expected: Vec<u32> = raw.iter().map(|&v| v + reference).collect();
694701

695-
let bp = make_bitpacked_array_u32(bit_width, len);
702+
let bp = bitpacked_array_u32(bit_width, len);
696703
let for_arr = FoRArray::try_new(bp.into_array(), Scalar::from(reference))?;
697704

698705
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
699-
let plan = UnmaterializedPlan::new(&for_arr.into_array())?.materialize(&cuda_ctx)?;
706+
let plan = dispatch_plan(&for_arr.into_array(), &cuda_ctx)?;
700707

701708
let actual =
702709
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -722,7 +729,7 @@ mod tests {
722729
let re = RunEndArray::new(ends_arr, values_arr);
723730

724731
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
725-
let plan = UnmaterializedPlan::new(&re.into_array())?.materialize(&cuda_ctx)?;
732+
let plan = dispatch_plan(&re.into_array(), &cuda_ctx)?;
726733

727734
let actual =
728735
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -755,7 +762,7 @@ mod tests {
755762
let dict = DictArray::try_new(codes_bp.into_array(), dict_for.into_array())?;
756763

757764
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
758-
let plan = UnmaterializedPlan::new(&dict.into_array())?.materialize(&cuda_ctx)?;
765+
let plan = dispatch_plan(&dict.into_array(), &cuda_ctx)?;
759766

760767
let actual =
761768
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -787,7 +794,7 @@ mod tests {
787794
);
788795

789796
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
790-
let plan = UnmaterializedPlan::new(&tree.into_array())?.materialize(&cuda_ctx)?;
797+
let plan = dispatch_plan(&tree.into_array(), &cuda_ctx)?;
791798

792799
let actual =
793800
run_dispatch_plan_f32(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -816,7 +823,7 @@ mod tests {
816823
let zz = ZigZagArray::try_new(bp.into_array())?;
817824

818825
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
819-
let plan = UnmaterializedPlan::new(&zz.into_array())?.materialize(&cuda_ctx)?;
826+
let plan = dispatch_plan(&zz.into_array(), &cuda_ctx)?;
820827

821828
let actual =
822829
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -845,7 +852,7 @@ mod tests {
845852
let for_arr = FoRArray::try_new(re.into_array(), Scalar::from(reference))?;
846853

847854
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
848-
let plan = UnmaterializedPlan::new(&for_arr.into_array())?.materialize(&cuda_ctx)?;
855+
let plan = dispatch_plan(&for_arr.into_array(), &cuda_ctx)?;
849856

850857
let actual =
851858
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -874,7 +881,7 @@ mod tests {
874881
let for_arr = FoRArray::try_new(dict.into_array(), Scalar::from(reference))?;
875882

876883
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
877-
let plan = UnmaterializedPlan::new(&for_arr.into_array())?.materialize(&cuda_ctx)?;
884+
let plan = dispatch_plan(&for_arr.into_array(), &cuda_ctx)?;
878885

879886
let actual =
880887
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -902,7 +909,7 @@ mod tests {
902909
let dict = DictArray::try_new(codes_for.into_array(), values_prim.into_array())?;
903910

904911
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
905-
let plan = UnmaterializedPlan::new(&dict.into_array())?.materialize(&cuda_ctx)?;
912+
let plan = dispatch_plan(&dict.into_array(), &cuda_ctx)?;
906913

907914
let actual =
908915
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -927,7 +934,7 @@ mod tests {
927934
let dict = DictArray::try_new(codes_bp.into_array(), values_prim.into_array())?;
928935

929936
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
930-
let plan = UnmaterializedPlan::new(&dict.into_array())?.materialize(&cuda_ctx)?;
937+
let plan = dispatch_plan(&dict.into_array(), &cuda_ctx)?;
931938

932939
let actual =
933940
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -946,8 +953,11 @@ mod tests {
946953
let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable);
947954
let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?;
948955

949-
// UnmaterializedPlan::new should fail because u8 codes != u32 values in byte width.
950-
assert!(UnmaterializedPlan::new(&dict.into_array()).is_err());
956+
// DispatchPlan::new should return Unfused because u8 codes != u32 values in byte width.
957+
assert!(matches!(
958+
DispatchPlan::new(&dict.into_array())?,
959+
DispatchPlan::Unfused
960+
));
951961

952962
Ok(())
953963
}
@@ -961,8 +971,11 @@ mod tests {
961971
let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array();
962972
let re = RunEndArray::new(ends_arr, values_arr);
963973

964-
// UnmaterializedPlan::new should fail because u64 ends != i32 values in byte width.
965-
assert!(UnmaterializedPlan::new(&re.into_array()).is_err());
974+
// DispatchPlan::new should return Unfused because u64 ends != i32 values in byte width.
975+
assert!(matches!(
976+
DispatchPlan::new(&re.into_array())?,
977+
DispatchPlan::Unfused
978+
));
966979

967980
Ok(())
968981
}
@@ -997,7 +1010,7 @@ mod tests {
9971010
let expected: Vec<u32> = data[slice_start..slice_end].to_vec();
9981011

9991012
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1000-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1013+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
10011014

10021015
let actual = run_dynamic_dispatch_plan(
10031016
&cuda_ctx,
@@ -1048,7 +1061,7 @@ mod tests {
10481061
let expected: Vec<u32> = all_decoded[slice_start..slice_end].to_vec();
10491062

10501063
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1051-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1064+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
10521065

10531066
let actual = run_dynamic_dispatch_plan(
10541067
&cuda_ctx,
@@ -1098,7 +1111,7 @@ mod tests {
10981111
.collect();
10991112

11001113
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1101-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1114+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
11021115

11031116
let actual = run_dynamic_dispatch_plan(
11041117
&cuda_ctx,
@@ -1143,7 +1156,7 @@ mod tests {
11431156
let expected: Vec<u32> = data[slice_start..slice_end].to_vec();
11441157

11451158
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1146-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1159+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
11471160

11481161
let actual = run_dynamic_dispatch_plan(
11491162
&cuda_ctx,
@@ -1192,7 +1205,7 @@ mod tests {
11921205
let expected: Vec<u32> = all_decoded[slice_start..slice_end].to_vec();
11931206

11941207
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1195-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1208+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
11961209

11971210
let actual = run_dynamic_dispatch_plan(
11981211
&cuda_ctx,
@@ -1244,7 +1257,7 @@ mod tests {
12441257
let expected: Vec<u32> = all_decoded[slice_start..slice_end].to_vec();
12451258

12461259
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1247-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1260+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
12481261

12491262
let actual = run_dynamic_dispatch_plan(
12501263
&cuda_ctx,
@@ -1301,7 +1314,7 @@ mod tests {
13011314
let expected: Vec<u32> = all_decoded[slice_start..slice_end].to_vec();
13021315

13031316
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1304-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1317+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
13051318

13061319
let actual = run_dynamic_dispatch_plan(
13071320
&cuda_ctx,
@@ -1333,7 +1346,7 @@ mod tests {
13331346
let seq = SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, len)?;
13341347

13351348
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1336-
let plan = UnmaterializedPlan::new(&seq.into_array())?.materialize(&cuda_ctx)?;
1349+
let plan = dispatch_plan(&seq.into_array(), &cuda_ctx)?;
13371350

13381351
let actual = run_dynamic_dispatch_plan(
13391352
&cuda_ctx,
@@ -1366,7 +1379,7 @@ mod tests {
13661379
let seq = SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, len)?;
13671380

13681381
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1369-
let plan = UnmaterializedPlan::new(&seq.into_array())?.materialize(&cuda_ctx)?;
1382+
let plan = dispatch_plan(&seq.into_array(), &cuda_ctx)?;
13701383

13711384
let actual_u32 = run_dynamic_dispatch_plan(
13721385
&cuda_ctx,

0 commit comments

Comments
 (0)