2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_device_communicators.py
@@ -293,7 +293,7 @@ def benchmark_allreduce_single(
graph = torch.cuda.CUDAGraph()
graph_pool = torch.cuda.graph_pool_handle()
set_graph_pool_id(graph_pool)
-with torch.cuda.graph(graph, pool=graph_pool):
+with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
Collaborator:

What would be the issue without this change?

Contributor Author:

This benchmark would fail with NCCL 2.28, because torch.cuda.graph would otherwise create a new side stream for capture.

for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
allreduce_fn(graph_input)

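The failure mode discussed in this thread can be sketched outside the benchmark. The snippet below is an illustration, not code from this PR: the helper name, warm-up count, and tensor shapes are made up, but the `torch.cuda.graph(..., stream=...)` usage matches the change.

```python
import torch

def warmup_and_capture(fn, inp, n_warmup=3):
    """Warm up and capture a CUDA graph on one shared side stream.

    Without stream=..., torch.cuda.graph() captures on an internal side
    stream of its own, so stream-ordered allocations made during warm-up
    (e.g. from an ncclMemAlloc-backed pool) are not reused, and fresh
    allocations (and thus window registrations) can occur mid-capture.
    """
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        for _ in range(n_warmup):
            fn(inp)
    stream.synchronize()

    graph = torch.cuda.CUDAGraph()
    # Capture on the same (non-default) stream used for warm-up.
    with torch.cuda.graph(graph, stream=stream):
        out = fn(inp)
    return graph, out
```

Calling `graph.replay()` afterwards re-runs the captured work; the key point is that the warm-up loop and the capture share the same explicitly created stream.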
16 changes: 10 additions & 6 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -165,13 +165,16 @@ def test_capture_and_replay(self):
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
batch_descriptor = BatchDescriptor(num_tokens=10)

stream = torch.cuda.Stream()
# 0. global warmup
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None,
with (
set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None,
),
torch.cuda.stream(stream),
):
wrapper(self.input_tensor)

@@ -184,6 +187,7 @@
batch_descriptor=batch_descriptor,
),
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
torch.cuda.stream(stream),
Collaborator:

Setting the context via torch.cuda.stream(stream) does not pass the stream to cudagraph capture; this is a no-op.

Would there be an issue without the change?

Contributor Author:

Without this, the stream=torch.cuda.current_stream() call in cuda_graph.py would return the default stream, and capturing on the default stream is not allowed.
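The behaviour being discussed, that torch.cuda.stream changes what torch.cuda.current_stream() returns inside its scope, can be checked directly with a minimal standalone sketch (not code from this PR):

```python
import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    # Outside the context manager, the default stream is current.
    assert torch.cuda.current_stream() == torch.cuda.default_stream()
    with torch.cuda.stream(side):
        # Inside it, current_stream() reports the side stream, so a callee
        # doing torch.cuda.graph(..., stream=torch.cuda.current_stream())
        # captures on this stream rather than the disallowed default one.
        assert torch.cuda.current_stream() == side
    # On exit, the previous stream is restored.
    assert torch.cuda.current_stream() == torch.cuda.default_stream()
```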

Member:

It should not be necessary after 73094c7.

):
output1 = wrapper(self.input_tensor)
# capturing phase should generate a zero output
6 changes: 5 additions & 1 deletion vllm/compilation/cuda_graph.py
@@ -169,7 +169,11 @@ def __call__(self, *args, **kwargs):
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
with torch.cuda.graph(
Collaborator:

@Amir-19 what's the context behind this PR?

Contributor Author:

@zou3519 This came out of investigating #28901. With NCCL 2.28+, window registration during cuda graph capture crashes, with NCCL complaining about the registration. Since memory from ncclMemAlloc is tied to a stream, running warm-up and cuda graph capture on separate streams causes new memory allocations and thus new window registrations. In this PR, we explicitly set the stream for cuda graph capture, forcing it onto the same stream as the warm-up iterations. Before this PR, the cuda graph didn't have an explicit stream, so it would create a new side stream; was that intentional?

Collaborator (@BoyuanFeng, Dec 6, 2025):

> before this PR, the cuda graph didn't have an explicit stream, so it would create a new side stream; was that intentional?

Yes. torch.cuda.graph(...) automatically creates a side stream unless it is called with an explicit stream, as in torch.cuda.graph(..., stream=explicit_stream).

> With NCCL 2.28+, window registration during cuda graph capture crashes, with NCCL complaining about the registration.

Could you elaborate on what window registration is?

> we explicitly set the stream for cuda graph capture, forcing it onto the same stream as the warm-up iterations

In general, warming up on one stream and capturing the graph on another is fine, apart from some extra memory consumption, so using the same stream for warmup and capture is an optimization.

However, using different streams should not lead to an error. Could you elaborate a bit on why it errors? E.g., is the window registration a cudagraph-unsafe op?

Contributor Author:

@BoyuanFeng ncclCommWindowRegister is used to register local buffers into an NCCL window, which enables us to use symmetric kernels. Window registration also requires the memory to come from a VMM-based allocator like ncclMemAlloc. Since memory allocated with ncclMemAlloc is tied to a stream, using the mempool associated with ncclMemAlloc and ncclCommWindowRegister on different streams means that, when there are no available segments, new allocations are needed, and thus new registrations.

Intuitively, graph capture means "do this every time the graph replays", so even if registration were allowed during cuda graph capture, it would have led to creating new window handles on each replay, which would then need to be destroyed later. This is neither efficient nor useful; the proper pattern is to register the window once before capture, then reuse it.

Starting with NCCL 2.28, there is a restriction that ncclCommWindowRegister must not be called during graph capture, which caused the failure reported in #28901.

To fix this, we need to make sure that the warm-up and the cuda graph capture run on the same side stream.

cudagraph,
pool=self.graph_pool,
stream=torch.cuda.current_stream(),
Collaborator:

IIUC, graph capture should happen on a side stream instead of the current main stream?

We can add self.stream = torch.cuda.Stream() in the __init__ of CUDAGraphWrapper, and use this stream for both warmup and graph capture.

Contributor Author:

Isn't the current_stream here already a non-default stream shared with the warm-up iterations?

Collaborator:

Do you have a code pointer to a non-default stream shared with warm-up iterations? It looks like there is no explicit CUDA stream in cuda_graph.py.

Contributor Author:

I will check the flow again to see where a context manager sets a new stream.

):
Comment on lines +266 to +270
Contributor:

high

For performance reasons, it's better to use current_stream from vllm.utils.torch_utils instead of torch.cuda.current_stream(). The vLLM helper caches the stream, avoiding the overhead of creating a new stream object on each call, as documented in vllm/utils/torch_utils.py.

You'll need to update the import on line 20:

from vllm.utils.torch_utils import current_stream, weak_ref_tensors
Suggested change:
-with torch.cuda.graph(
-    cudagraph,
-    pool=self.graph_pool,
-    stream=torch.cuda.current_stream(),
-):
+with torch.cuda.graph(
+    cudagraph,
+    pool=self.graph_pool,
+    stream=current_stream(),
+):

# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.cudagraph_options.weak_ref_output:
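The caching idea behind the suggested current_stream helper can be sketched as follows. This is a simplified illustration, not vLLM's implementation: the real helper in vllm/utils/torch_utils.py also invalidates the cache when the active stream changes, which this sketch deliberately omits.

```python
import torch

_cached_stream = None  # module-level cache, populated on first use

def current_stream() -> "torch.cuda.Stream":
    # torch.cuda.current_stream() constructs a new Stream object on every
    # call; caching the result avoids that overhead on hot paths. This
    # sketch assumes the active stream is never switched after first use.
    global _cached_stream
    if _cached_stream is None:
        _cached_stream = torch.cuda.current_stream()
    return _cached_stream
```

Repeated calls then return the same cached Stream object instead of constructing a fresh wrapper each time.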