[CUDA] Support user compute stream with CUDA graph in CUDA plugin EP #29221
Pull request overview
This PR brings the CUDA plugin Execution Provider to parity with the bundled CUDA EP by allowing CUDA Graph capture/replay (enable_cuda_graph) to be used together with a user-provided compute stream (user_compute_stream). It also includes supporting changes to make capture stable (avoiding capture-time allocations and cross-EP Memcpy nodes) and adds coverage for the new behavior.
Changes:
- Allow `user_compute_stream` + `enable_cuda_graph` together in the CUDA plugin EP, capturing/replaying on the user stream.
- Make CUDA graph capture more stable in the plugin EP by routing scratch allocations through the EP allocator and ensuring `Shape` is available on the CUDA EP to avoid Memcpy nodes.
- Add a new plugin EP test that validates session creation and correctness across capture/replay and in-place input updates on the user stream.
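The last item, correctness across in-place input updates, rests on a replayed graph reading the same buffer addresses it saw at capture time. A CPU-only analogy is sketched below. It makes no CUDA calls, and `ToyGraph` and `CaptureThenReplayWithNewInput` are hypothetical names that are not part of ONNX Runtime or of this PR's test.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for a captured graph: it stores operations once
// and re-executes them unchanged on every replay.
class ToyGraph {
 public:
  void Record(std::function<void()> op) { ops_.push_back(std::move(op)); }
  void Replay() const {
    for (const auto& op : ops_) op();
  }

 private:
  std::vector<std::function<void()>> ops_;
};

// Captures "output = input * 2" against fixed addresses, then updates the
// input in place and replays. Returns the output seen after the second replay.
float CaptureThenReplayWithNewInput() {
  float input = 1.0f;
  float output = 0.0f;
  float* in = &input;    // addresses are fixed at capture time
  float* out = &output;

  ToyGraph graph;
  graph.Record([in, out] { *out = *in * 2.0f; });

  graph.Replay();   // first replay uses the original input value
  input = 5.0f;     // in-place update: same address, new contents
  graph.Replay();   // replay picks up the new contents without re-capture
  return output;
}
```

The point of the analogy is only that the caller must write new inputs into the same buffer rather than swap in a different one. How the real test arranges this on the user stream is not shown here.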
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc | Removes the invalid-argument restriction and documents the combined user-stream + CUDA-graph behavior. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | Updates per-thread CUDA-graph context to optionally use a user-owned stream and avoid destroying it. |
| onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h | Routes scratch/workspace allocations through the EP allocator instead of raw CUDA malloc paths. |
| onnxruntime/core/providers/cuda/tensor/shape_op.cc | Adds a plugin-build adapter-based Shape kernel while keeping output on CPU memory to prevent graph-breaking Memcpy nodes. |
| cmake/onnxruntime_providers_cuda_plugin.cmake | Includes shape_op.cc in the plugin build to enable the new plugin Shape implementation. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Adds null-allocator fallback in PrePack for plugin boundary robustness. |
| onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc | Adds null-allocator fallback in PrePack for plugin boundary robustness. |
| onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc | Adds null-allocator fallback in PrePack for plugin boundary robustness. |
| onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc | Adds coverage for combined user stream + CUDA graph capture/replay correctness. |
| onnxruntime/core/framework/session_state.cc | Pure formatting (line wrap). |
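The null-allocator fallback listed for the three `PrePack` kernels can be illustrated with a minimal sketch. All types here (`ToyAllocator`, `ToyKernelInfo`, `ResolvePrePackAllocator`) are hypothetical stand-ins, not the ONNX Runtime classes. The sketch only shows the pattern of preferring the passed allocator and falling back to the kernel's own default when the passed one is null.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for an allocator; only the name matters here.
struct ToyAllocator {
  std::string name;
};
using ToyAllocatorPtr = std::shared_ptr<ToyAllocator>;

// Hypothetical stand-in for the kernel info object, which is assumed to
// always be able to hand out a valid default-memory allocator.
struct ToyKernelInfo {
  ToyAllocatorPtr default_allocator =
      std::make_shared<ToyAllocator>(ToyAllocator{"kernel-default"});
  ToyAllocatorPtr GetDefaultAllocator() const { return default_allocator; }
};

// Use the allocator handed to PrePack when it is non-null; otherwise fall
// back to the kernel's own default allocator.
ToyAllocatorPtr ResolvePrePackAllocator(const ToyAllocatorPtr& passed,
                                        const ToyKernelInfo& info) {
  return passed ? passed : info.GetDefaultAllocator();
}
```

For example, `ResolvePrePackAllocator(nullptr, ToyKernelInfo{})` yields the allocator named `kernel-default`, while a non-null argument is returned unchanged.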
The pre-packing algo expects the allocator to be CPU-based and in some cases attempts to externalize them to disk. (Resolved) Refers to: onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc:144 in 789c771.
…atch Stream*
- PerThreadContext: derive graph-stream ownership from explicit use_external_stream intent instead of (external_stream == nullptr), so a user-selected CUDA default stream (cudaStream_t(0)) combined with CUDA graph is treated as external/user-owned and not destroyed.
- GetScratchBuffer: stop forwarding a stack-temporary PluginStreamShim Stream* to the stream-aware arena. A plugin kernel only has the raw cudaStream_t, not the framework OrtSyncStream* the arena persists per chunk and later dereferences, so the temporary would dangle and be type-confused. Pass a null stream; capture stability comes from arena chunk reuse, and the CUDA graph path runs on a single unified stream.
- Reconcile arena/cuda-graph plugin docs with the null-stream scratch behavior.
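The ownership rule in the first item can be sketched without any CUDA dependency. The example below is an illustration under stated assumptions: `FakeStream` stands in for `cudaStream_t`, the create/destroy counters stand in for the real stream calls, and `GraphStreamHolder` is a hypothetical simplification rather than the actual `PerThreadContext`. It shows why an explicit flag is needed: a null handle is a valid user choice (the default stream), so a null check alone cannot decide ownership.

```cpp
#include <cassert>

// Stand-in for cudaStream_t. A null handle plays the role of the CUDA
// default stream, cudaStream_t(0).
using FakeStream = void*;

// Counters replacing real stream creation/destruction in this sketch.
static int g_created = 0;
static int g_destroyed = 0;
static int g_stream_storage = 0;

FakeStream FakeStreamCreate() {
  ++g_created;
  return &g_stream_storage;
}
void FakeStreamDestroy(FakeStream) { ++g_destroyed; }

// Hypothetical holder: ownership follows the explicit flag, never the
// value of the handle.
class GraphStreamHolder {
 public:
  GraphStreamHolder(bool use_external_stream, FakeStream external_stream)
      : owns_stream_(!use_external_stream),
        stream_(use_external_stream ? external_stream : FakeStreamCreate()) {}
  ~GraphStreamHolder() {
    if (owns_stream_) FakeStreamDestroy(stream_);
  }
  GraphStreamHolder(const GraphStreamHolder&) = delete;
  GraphStreamHolder& operator=(const GraphStreamHolder&) = delete;

  FakeStream stream() const { return stream_; }

 private:
  bool owns_stream_;
  FakeStream stream_;
};

struct Counts {
  int created;
  int destroyed;
};

// Builds and tears down one holder, reporting how many streams were
// created and destroyed along the way.
Counts RunScope(bool use_external_stream, FakeStream external_stream) {
  g_created = 0;
  g_destroyed = 0;
  {
    GraphStreamHolder holder(use_external_stream, external_stream);
    (void)holder.stream();
  }
  return Counts{g_created, g_destroyed};
}
```

With `RunScope(true, nullptr)`, which models the user-selected default stream, nothing is created or destroyed. With `RunScope(false, nullptr)` the holder creates and later destroys its own stream.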
Associated Issues / Duplicate Check
Missing negative tests
Verdict
The PR is functionally correct, well-documented, and conservative in its safety guarantees (disabling concurrent runs until the proper API exists). The code changes are minimal and well-targeted. The main concern is minor missing negative test coverage. The review comments were adequately addressed. Ready for merge pending the final approval from yuslepukhin (who already approved an earlier revision).
Description
The CUDA plugin EP previously rejected combining a user-provided compute stream
(`user_compute_stream`) with CUDA graph capture (`enable_cuda_graph`), returning
`ORT_INVALID_ARGUMENT`. This PR removes that restriction so the two options can
be used together: when both are set, graph capture and replay run on the
user-owned stream (the same stream the kernels are issued to), matching the
bundled (non-plugin) CUDA EP behavior. Several supporting fixes make capture on a
shared stream stable and Memcpy-free.
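One of those supporting fixes, serving scratch memory from an arena so that the capture run triggers no new device allocations, can be illustrated with a toy model. The sketch below is CPU-only and makes several assumptions: `ToyArena` is a hypothetical, greatly simplified arena, `std::malloc` stands in for a device allocation such as `cudaMalloc`, and `RunOnce` stands in for one inference run that needs two scratch buffers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <vector>

// Hypothetical, greatly simplified arena: freed chunks are kept and
// reused, and only a miss reaches the backing allocator.
class ToyArena {
 public:
  ToyArena() = default;
  ToyArena(const ToyArena&) = delete;
  ToyArena& operator=(const ToyArena&) = delete;
  ~ToyArena() {
    for (auto& c : chunks_) std::free(c.ptr);
  }

  void* Alloc(std::size_t bytes) {
    for (auto& c : chunks_) {
      if (!c.in_use && c.size >= bytes) {
        c.in_use = true;  // reuse: no backing allocation
        return c.ptr;
      }
    }
    void* p = std::malloc(bytes);  // stands in for a device allocation
    ++backend_allocs_;
    chunks_.push_back(Chunk{p, bytes, true});
    return p;
  }

  void Free(void* p) {
    for (auto& c : chunks_) {
      if (c.ptr == p) c.in_use = false;  // chunk stays reserved
    }
  }

  int backend_allocs() const { return backend_allocs_; }

 private:
  struct Chunk {
    void* ptr;
    std::size_t size;
    bool in_use;
  };
  std::vector<Chunk> chunks_;
  int backend_allocs_ = 0;
};

// One simulated run that needs two scratch buffers.
void RunOnce(ToyArena& arena) {
  void* a = arena.Alloc(1024);
  void* b = arena.Alloc(4096);
  arena.Free(a);
  arena.Free(b);
}

// Number of backing allocations made during the run after warmup,
// which models the capture run.
int BackendAllocsDuringCapture() {
  ToyArena arena;
  RunOnce(arena);  // warmup: arena reserves its chunks
  const int before = arena.backend_allocs();
  RunOnce(arena);  // "capture" run: served from reserved chunks
  return arena.backend_allocs() - before;
}
```

In this model the warmup run makes two backing allocations and the following run makes none, which is the steady state the capture run relies on.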
Summary of Changes
Allow user stream + CUDA graph
- Remove the validation that rejected `user_compute_stream` + `enable_cuda_graph` together.
- `PerThreadContext` accepts an optional external graph stream. When both options are set it captures/replays on the user stream and does not create or destroy it (the user owns its lifetime); otherwise it owns a dedicated graph stream as before.
Stable, Memcpy-free CUDA graph capture
- Route scratch allocations through the EP allocator instead of raw `cudaMallocAsync`/`cudaMalloc`. After warmup the arena reaches steady state, so the capture run serves scratch from already-reserved chunks and the device free-memory footprint stays stable, which is required for correct capture. Matches the built-in CUDA EP.
- Add an adapter-based `Shape` kernel under `#ifdef BUILD_CUDA_EP_AS_PLUGIN` with identical semantics to the CPU `Shape`. Registering `Shape` on the EP keeps it off the CPU EP and avoids the Memcpy nodes that would otherwise break CUDA graph capture.
- Stop excluding `shape_op.cc` from the plugin build so the adapter-based `Shape` kernel is compiled in.
Null-allocator fallback in PrePack (plugin boundary)
In the plugin build the `AllocatorPtr` passed to `PrePack` can arrive null across
the library boundary. Each kernel now falls back to its own default-memory
allocator (`Info().GetAllocator(OrtMemTypeDefault)`), which is always valid.
Misc
Testing
- Session creation succeeds with both `user_compute_stream` and `enable_cuda_graph` set (regression for the removed validation).
- Tests are gated on `ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP` and skip gracefully when no CUDA device or plugin library is available.
Motivation and Context
Users that drive ORT from their own CUDA stream (e.g. to interleave ORT inference
with their own kernels) previously could not also benefit from CUDA graph capture
on the plugin EP. This change brings the plugin EP to parity with the bundled
CUDA EP for that workflow.
Checklist