
Conversation

Contributor

@Amir-19 commented Nov 21, 2025

Purpose

With NCCL 2.28+, window registration during CUDA graph capture crashes, with NCCL complaining about the registration. Since memory allocated with ncclMemAlloc is tied to a stream, running warm-up and CUDA graph capture on separate streams causes new memory allocations and therefore new window registrations (when the mempool associated with ncclMemAlloc and ncclCommWindowRegister is used on a different stream and has no available segments, new allocations, and thus new registrations, are needed). In this PR, we explicitly set the stream for CUDA graph capture, forcing it to be the same stream as the warm-up iterations.
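
As a rough illustration of the intended pattern (the tiny Linear model below is a placeholder, not vLLM code): warm-up and graph capture share one side stream, so the mempool already holds registered segments when capture starts.

import torch

# placeholder model and input, standing in for the real forward pass
model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(8, 16, device="cuda")

stream = torch.cuda.Stream()  # one side stream for both warm-up and capture

# warm-up on the side stream: allocations from the ncclMemAlloc-backed pool
# (and the corresponding window registrations) happen here, before any capture
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
    for _ in range(3):
        model(static_input)
stream.synchronize()

# capture on the same stream: the pool already holds registered segments,
# so no new registrations are triggered inside the capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
    static_output = model(static_input)

# replays reuse the captured buffers; refresh inputs in place before each replay
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()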

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request makes the CUDA stream explicit when capturing CUDA graphs, which is a good practice for clarity and correctness. I've found one area for improvement in vllm/compilation/cuda_graph.py regarding the use of torch.cuda.current_stream(), where a more performant, cached version from vLLM's utilities should be used instead.

Comment on lines 172 to 176
with torch.cuda.graph(
cudagraph,
pool=self.graph_pool,
stream=torch.cuda.current_stream(),
):
Contributor

high

For performance reasons, it's better to use current_stream from vllm.utils.torch_utils instead of torch.cuda.current_stream(). The vLLM version is a cached version that avoids the overhead of creating a new stream object on each call, as documented in vllm/utils/torch_utils.py.

You'll need to update the import on line 20:

from vllm.utils.torch_utils import current_stream, weak_ref_tensors
Suggested change
with torch.cuda.graph(
cudagraph,
pool=self.graph_pool,
stream=torch.cuda.current_stream(),
):
with torch.cuda.graph(
cudagraph,
pool=self.graph_pool,
stream=current_stream(),
):

@ProExpertProg added the ready label (ONLY add when PR is ready to merge/full CI is needed) Nov 23, 2025
Signed-off-by: Amir Samani <[email protected]>
@mergify mergify bot added the v1 label Nov 24, 2025
Signed-off-by: Amir Samani <[email protected]>
Signed-off-by: Amir Samani <[email protected]>
with torch.cuda.graph(
cudagraph,
pool=self.graph_pool,
stream=torch.cuda.current_stream(),
Contributor

iiuc, graph capture should happen on a side stream instead of the current main stream?

We can add self.stream = torch.cuda.Stream() in the __init__ of CUDAGraphWrapper, and use this stream for both warm-up and graph capture.
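
A rough sketch of what that suggestion could look like (the real CUDAGraphWrapper constructor and call path are more involved; the names and arguments below are simplified assumptions):

import torch

class CUDAGraphWrapper:  # simplified sketch, not the actual vLLM class
    def __init__(self, runnable, graph_pool):
        self.runnable = runnable
        self.graph_pool = graph_pool
        # one dedicated side stream, reused for both warm-up and capture
        self.stream = torch.cuda.Stream()

    def capture(self, *args, num_warmup=3):
        # warm-up on the side stream, so allocations/registrations happen here
        self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            for _ in range(num_warmup):
                self.runnable(*args)
        self.stream.synchronize()

        # capture on the same side stream
        cudagraph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(cudagraph, pool=self.graph_pool, stream=self.stream):
            output = self.runnable(*args)
        return cudagraph, output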

Contributor Author

isn't the current_stream here already a non-default stream shared with warm-up iterations?

Contributor

do you have a code pointer to a non-default stream shared with warm-up iterations? Looks like in cuda_graph.py, there is no explicit cuda stream.

Contributor Author

I will check the flow again to see where a context manager sets a new stream.

@Amir-19 Amir-19 changed the title [do not merge yet] add stream to cuda graph catpure add stream to cuda graph catpure Dec 1, 2025
@Amir-19 Amir-19 requested a review from BoyuanFeng December 4, 2025 19:04
set_graph_pool_id(current_platform.graph_pool_handle())
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
with torch.cuda.graph(
Collaborator

@Amir-19 what's the context behind this PR?

Contributor Author

@zou3519 this came out of investigating #28901. With NCCL 2.28+, window registration during CUDA graph capture crashes, with NCCL complaining about the registration. Since memory allocated with ncclMemAlloc is tied to a stream, running warm-up and CUDA graph capture on separate streams causes new memory allocations and therefore new window registrations. In this PR, we explicitly set the stream for CUDA graph capture, forcing it to be the same stream as the warm-up iterations. Before this PR, the CUDA graph didn't have an explicit stream, so it would create a new side stream; is that intentional?

Contributor

@BoyuanFeng commented Dec 6, 2025

before this PR, the cuda graph didn't have an explicit stream, so it would create a new side stream; is that intentional?

yes. torch.cuda.graph(...) automatically creates a side stream, unless it is called with an explicit stream: torch.cuda.graph(..., stream=explicit_stream).
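
A quick way to see that side-stream behavior (an illustrative snippet, not part of this PR):

import torch

x = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()
# no stream argument: torch.cuda.graph captures on an internally created side stream
with torch.cuda.graph(g):
    capture_stream = torch.cuda.current_stream()  # host-side query, safe during capture
    y = x + 1  # some captured work so the graph is non-empty

assert capture_stream != torch.cuda.default_stream()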

With NCCL 2.28+, window registration during CUDA graph capture crashes, with NCCL complaining about the registration.

Could you elaborate on what window registration is?

we explicitly set the stream for CUDA graph capture, forcing it to be the same stream as the warm-up iterations

In general, warming up on one stream and capturing the graph on another is fine, apart from some extra memory consumption. So using the same stream for warm-up and capture is an optimization.

However, using different streams should not lead to an error. Could you elaborate a bit on why it errors? e.g., is window registration a cudagraph-unsafe op?

Contributor Author

@BoyuanFeng ncclCommWindowRegister is used to register local buffers into an NCCL window, which enables us to use symmetric kernels. This window registration also requires the memory to come from a VMM-based allocator like ncclMemAlloc. Since memory allocated with ncclMemAlloc is tied to a stream, when the mempool associated with ncclMemAlloc and ncclCommWindowRegister is used on a different stream and has no available segments, new allocations, and thus new registrations, are needed.

Intuitively, graph capture means "do this every time the graph replays", so even if registration were allowed during CUDA graph capture, it would create new window handles on each replay, which would then have to be destroyed later. This is neither efficient nor useful. The proper pattern is to register the window once before capture, then reuse it.

Starting with NCCL 2.28, there is a restriction that ncclCommWindowRegister must not be called during graph capture, which caused the failure reported in #28901.

To fix this, we need to make sure that the warm-up and the CUDA graph capture run on the same side stream.

graph_pool = torch.cuda.graph_pool_handle()
set_graph_pool_id(graph_pool)
with torch.cuda.graph(graph, pool=graph_pool):
with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
Contributor

what would be the issue without this change?

Contributor Author

this benchmark would fail with NCCL 2.28, because the CUDA graph would create a new side stream for capture.

batch_descriptor=batch_descriptor,
),
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
torch.cuda.stream(stream),
Contributor

setting context via torch.cuda.stream(stream) does not pass the stream to cudagraph capture. this is a no-op.

Would there be an issue w/o the change?

Contributor Author

without this, the stream=torch.cuda.current_stream() call in cuda_graph.py would return the default stream, and capturing on the default stream is not allowed.

Member

it should not be necessary after 73094c7


@Amir-19 Amir-19 requested review from BoyuanFeng and zou3519 December 7, 2025 06:57
Member

@youkaichao left a comment

Sorry for the long delay.

Taking a closer look at the error stack in #28901, the error happens at https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/dev_runtime.cc#L591, and the stream nccl uses to synchronize is created by nccl itself at https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/dev_runtime.cc#L584, which is a local variable defined at https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/dev_runtime.cc#L553. It should not interfere with anything outside nccl.

My hypothesis is that there might be some caching inside the driver, so that when we create many streams, some of them will actually be the same stream, causing the trouble here. The solution, then, is to create fewer streams to reduce the chance of stream collision.

NOTE: I think pytorch has some stream-caching:

import torch

assert torch.cuda.is_available(), "CUDA is required for CUDA streams"

streams = [torch.cuda.Stream() for _ in range(1000)]

# Collect stream pointers
ptrs = []
for i, s in enumerate(streams):
    # cuda_stream is an integer pointer (uintptr_t)
    ptr = s.cuda_stream
    ptrs.append(ptr)

# Check for duplicates
unique_ptrs = set(ptrs)
print("\nSummary:")
print(f"Total streams created: {len(ptrs)}")
print(f"Unique stream pointers: {len(unique_ptrs)}")

if len(unique_ptrs) != len(ptrs):
    print("⚠️ Duplicate streams detected!")
else:
    print("✅ No duplicate streams detected.")

It only prints:

Summary:
Total streams created: 1000
Unique stream pointers: 32
⚠️ Duplicate streams detected!

Which means I can create at most 32 unique streams from pytorch.

If I don't use pytorch, I can create 10k unique streams from cuda directly. It still seems to be a bug somewhere.

I'll try to dig further with the driver / nccl team, but for now the PR looks good to me. Reducing the number of streams created by pytorch can reduce the chance of stream collision (since pytorch can only create up to 32 unique streams).

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Dec 24, 2025
@youkaichao
Member

Diving deeper, after turning on CUDA_LOG_FILE=stderr (only valid for cuda 12.9+), I can get more error information:

[06:56:50.922][136234295043968][CUDA][E] API call conflicts with a stream capture sequence initiated from the calling thread

The limitation imposed by this PR ("forcing the cudagraph capture stream to be the same stream as the warm-up iterations") seems to point to either improper use of nccl or a driver bug.

Nevertheless, I will stop here and let nccl/driver team continue the investigation.

The fix in this PR is simple, and we can go ahead with this workaround.

@youkaichao youkaichao linked an issue Dec 24, 2025 that may be closed by this pull request
@youkaichao
Member

there's an error in ci:

RuntimeError: CUDA graphs must be captured on a non-default stream. (However, after capture, it's ok to replay them on the default stream.)

@Amir-19 I think we need to update

def current_stream() -> torch.cuda.Stream:
too, using logic similar to rocm, to avoid returning the default stream.
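
One possible shape for such a helper (a sketch only, not vLLM's actual implementation in vllm/utils/torch_utils.py):

import torch

_current_stream: torch.cuda.Stream | None = None

def current_stream() -> torch.cuda.Stream:
    """Cached stream accessor that never hands back the default stream,
    since CUDA graphs cannot be captured on it."""
    global _current_stream
    if _current_stream is None:
        stream = torch.cuda.current_stream()
        if stream == torch.cuda.default_stream():
            # switch to a dedicated side stream so later graph capture is legal
            stream = torch.cuda.Stream()
            torch.cuda.set_stream(stream)
        _current_stream = stream
    return _current_stream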

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@youkaichao youkaichao changed the title add stream to cuda graph catpure use the same stream for cuda graph catpure and replay for NCCL Dec 25, 2025
@youkaichao
Member

the only failing test comes from main. merging.

@youkaichao youkaichao merged commit 030fc44 into vllm-project:main Dec 25, 2025
47 of 49 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Dec 25, 2025
twjww pushed a commit to twjww/vllm that referenced this pull request Dec 28, 2025
…project#29207)

Signed-off-by: Amir Samani <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: tianwenjing <[email protected]>
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Dec 30, 2025
aykoppol pushed a commit to aykoppol/vllm that referenced this pull request Jan 4, 2026
@pkousha

pkousha commented Jan 6, 2026

@youkaichao Pouya from NCCL team here.

The limitation imposed by this PR ("forcing the cudagraph capture stream to be the same stream as the warm-up iterations") seems to point to either improper use of nccl or a driver bug.
Nevertheless, I will stop here and let nccl/driver team continue the investigation.

I am not sure whether this is a bug in nccl vs. a usage issue. Let me explain: symmetric memory window registration is tied to a CUDA stream. Warm-up is capturing that, and putting things in a graph means "do this every time the graph is executed". Re-registering the same memory might be OK, we cache that at the physical layer, but we don't guarantee that the window handle written out will be the same if the memory is already registered. So it literally means overwriting the window handle with a new window (that needs destroying later) on every replay. Then who will free up that handle?
Do you still think there is a bug?

dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…project#29207)

Signed-off-by: Amir Samani <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: dsuhinin <[email protected]>
daje0601 pushed a commit to daje0601/vllm that referenced this pull request Jan 22, 2026
…project#29207)

Signed-off-by: Amir Samani <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: daje0601 <[email protected]>

Labels

nvidia · performance (Performance-related issues) · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: nccl symmem causes nccl error on nccl 2.28+

6 participants