Skip to content

[CUDA Plugin EP] Expose kernel sync stream for scratch allocation#29244

Open
tianleiwu wants to merge 4 commits into
mainfrom
tlwu/concurrent_stream
Open

[CUDA Plugin EP] Expose kernel sync stream for scratch allocation#29244
tianleiwu wants to merge 4 commits into
mainfrom
tlwu/concurrent_stream

Conversation

@tianleiwu

Copy link
Copy Markdown
Contributor

Description

This PR adds a kernel-context C API accessor for the framework OrtSyncStream* and uses it in the CUDA plugin EP so scratch allocations can be tagged with the actual compute stream selected for the kernel. It is stacked on #29221 and turns the previously documented concurrent multi-stream limitation into a gated capability: older runtimes keep the conservative fallback, while runtimes with the new API can safely advertise concurrent runs when EP-level unified stream mode is not forced.

Summary of Changes

Public API and Adapters

File Change
include/onnxruntime/core/session/onnxruntime_c_api.h Adds KernelContext_GetSyncStream to expose the borrowed framework stream for stream-aware allocation and synchronization bookkeeping.
onnxruntime/core/session/custom_ops.cc Implements the API by retrieving the kernel's OpKernelContext::GetComputeStream() inside ORT core.
onnxruntime/core/session/ort_apis.h and onnxruntime/core/session/onnxruntime_c_api.cc Declares and wires the new API entry.
include/onnxruntime/core/session/onnxruntime_cxx_api.h and include/onnxruntime/core/session/onnxruntime_cxx_inline.h Adds the C++ Ort::KernelContext::GetSyncStream() wrapper.
include/onnxruntime/ep/adapter/op_kernel.h Adds a version-gated EP adapter accessor so plugins can use the API when available and fall back safely otherwise.

CUDA Plugin EP

  • Tracks the framework stream corresponding to both raw CUDA stream handles and OrtStreamAdapter stream arguments.
  • Passes the framework stream to scratch allocation so arena chunks are stream-tagged instead of using a null stream tag.
  • Re-enables concurrent run support only when KernelContext_GetSyncStream is available and EP-level unified stream mode is not forced.

Tests and Docs

  • Extends the shared-lib custom-op test helper to exercise Ort::KernelContext::GetSyncStream().
  • Updates CUDA plugin EP docs to describe stream-tagged scratch allocation, compatibility fallback, and the new API audit entry.

Why a C API is needed

The implementation of KernelContext_GetSyncStream is intentionally small, but the API boundary is the important part. ORT core can safely cast OrtKernelContext* back to onnxruntime::OpKernelContext* because it owns both the opaque C handle and the private C++ implementation. A plugin kernel should not perform that cast directly: it would make the plugin depend on ORT-core private C++ layout, vtables, and exact build compatibility.

The new API keeps that private cast inside ORT core and gives plugin kernels a stable ABI entry point:

plugin kernel -> opaque OrtKernelContext* -> OrtApi::KernelContext_GetSyncStream -> ORT core retrieves the actual framework stream

This also lets the plugin use runtime version gating. When loaded by an older ORT runtime that does not expose the API, the adapter returns null, scratch allocation uses the conservative fallback, and concurrent runs are not advertised.

Testing

  • lintrunner -a
  • ninja -C build/cu130_plugin/Debug onnxruntime_providers_cuda_plugin
  • ninja -C build/cu130_plugin/Debug onnxruntime_shared_lib_test
  • cd build/cu130_plugin/Debug && ./onnxruntime_shared_lib_test --gtest_filter=CApiTest.custom_op_handler --gtest_color=no
  • VS Code diagnostics on touched C++ and header files

Checklist

  • Tests added/updated
  • Documentation updated
  • Backward compatibility guarded by runtime API-version checks
  • CI passes

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends the kernel-context C API to expose the framework OrtSyncStream* for the current kernel invocation, and updates the CUDA plugin EP to use that stream for stream-aware scratch allocation bookkeeping so it can safely advertise concurrent Session::Run() when supported by the host runtime.

Changes:

  • Adds OrtApi::KernelContext_GetSyncStream (plus C++ and adapter wrappers) to retrieve the framework stream wrapper associated with a kernel context.
  • Updates the CUDA plugin kernel adapter to associate scratch/workspace allocations with the framework stream (instead of a null stream tag).
  • Re-enables CUDA plugin EP concurrent-run support when the host runtime supports the new API and unified-stream mode is not forced.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
include/onnxruntime/core/session/onnxruntime_c_api.h Adds the public C API entry and documentation for KernelContext_GetSyncStream.
onnxruntime/core/session/ort_apis.h Declares the new OrtApis implementation entry point.
onnxruntime/core/session/onnxruntime_c_api.cc Wires the new function pointer into the OrtApi table.
onnxruntime/core/session/custom_ops.cc Implements KernelContext_GetSyncStream by returning the kernel’s framework compute stream wrapper.
include/onnxruntime/core/session/onnxruntime_cxx_api.h Adds Ort::KernelContext::GetSyncStream() declaration.
include/onnxruntime/core/session/onnxruntime_cxx_inline.h Implements the C++ wrapper calling into the C API.
include/onnxruntime/ep/adapter/op_kernel.h Adds version-gated adapter access to GetSyncStream() for plugin kernels.
onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h Tracks/uses the framework stream wrapper for stream-aware scratch allocation tagging.
onnxruntime/core/providers/cuda/plugin/cuda_ep.cc Gates IsConcurrentRunSupported on API availability and unified-stream configuration.
onnxruntime/test/shared_lib/custom_op_utils.cc Extends shared-lib custom-op tests to exercise GetSyncStream().
docs/cuda_plugin_ep/cuda_plugin_ep_design.md Updates plugin design docs for the new gated stream-tagged scratch capability.
docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md Updates CUDA graph docs to reflect stream-tagged scratch allocation and concurrent-run conditions.
docs/cuda_plugin_ep/arena_allocator_migration_design.md Updates allocator migration design docs to reflect stream-tagged scratch allocation and compatibility behavior.

Comment thread onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h
Comment thread include/onnxruntime/core/session/onnxruntime_c_api.h
Base automatically changed from tlwu/20260623/cuda_plugin_ep_cuda_graph_stream to main June 25, 2026 00:20
Comment thread .github/workflows/linux_cuda_plugin_ci.yml Outdated
@tianleiwu tianleiwu force-pushed the tlwu/concurrent_stream branch from e3d1cc9 to 39c5fe2 Compare June 25, 2026 00:45

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 2 comments.

Comment on lines 958 to +961
inline void* GetComputeStream(OpKernelContext* ctx) const {
return ctx->GetGPUComputeStream();
void* cuda_stream = ctx->GetGPUComputeStream();
cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream());
return cuda_stream;

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — fixed. Stream(ctx) now calls RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream()) for the current Compute call, so kernels that call Stream(ctx) and then GetTransientScratchBuffer()/GetScratchBuffer(..., nullptr) before ever calling GetComputeStream()/GetOrtStream() (e.g. conv algo search) get correctly stream-tagged scratch instead of a stale/null framework stream.

Comment on lines +63 to +76
inline void RegisterFrameworkStreamForCudaStream(void* cuda_stream, OrtSyncStream* framework_stream) {
current_cuda_stream = cuda_stream;
current_framework_stream = reinterpret_cast<onnxruntime::Stream*>(framework_stream);

if (current_framework_stream == nullptr) {
return;
}

stream_to_framework_stream[current_framework_stream] = current_framework_stream;

if (cuda_stream != nullptr) {
stream_to_framework_stream[cuda_stream] = current_framework_stream;
}
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Removed the stream_to_framework_stream[current_framework_stream] = current_framework_stream self-entry. GetFrameworkStreamForStreamArg already handles stream == current_framework_stream directly, so the entry was unused and only risked unbounded thread-local map growth and retaining framework stream pointers past the Session::Run() teardown lifetime. The map now keys only off raw cudaStream_t handles.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants