use the same stream for cuda graph capture and replay for NCCL #29207
```diff
@@ -165,13 +165,16 @@ def test_capture_and_replay(self):
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
         )
         batch_descriptor = BatchDescriptor(num_tokens=10)
+        stream = torch.cuda.Stream()
         # 0. global warmup
-        with set_forward_context(
-            attn_metadata=None,
-            vllm_config=self.vllm_config,
-            cudagraph_runtime_mode=CUDAGraphMode.NONE,
-            batch_descriptor=None,
+        with (
+            set_forward_context(
+                attn_metadata=None,
+                vllm_config=self.vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                batch_descriptor=None,
+            ),
+            torch.cuda.stream(stream),
         ):
             wrapper(self.input_tensor)
@@ -184,6 +187,7 @@ def test_capture_and_replay(self):
                 batch_descriptor=batch_descriptor,
             ),
             patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
+            torch.cuda.stream(stream),
         ):
             output1 = wrapper(self.input_tensor)
             # capturing phase should generate a zero output
```
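The rewritten test groups two context managers under a single parenthesized `with` statement (Python 3.10+ syntax), so the forward context and the stream context are entered and exited together. A minimal CPU-only sketch, using hypothetical stand-in managers rather than vLLM's real ones, shows the ordering this produces:

```python
# Sketch of the parenthesized multi-manager `with` used in the test above.
# `forward_context` and `stream_context` are hypothetical stand-ins.
from contextlib import contextmanager

events = []

@contextmanager
def forward_context():
    events.append("enter-forward")
    yield
    events.append("exit-forward")

@contextmanager
def stream_context():
    events.append("enter-stream")
    yield
    events.append("exit-stream")

with (
    forward_context(),
    stream_context(),
):
    events.append("body")

# Managers enter left-to-right and exit in reverse order,
# exactly as if they had been written as nested `with` blocks.
print(events)
```

This is why adding `torch.cuda.stream(stream)` as the last manager makes the wrapped call run on `stream` while still inside the forward context.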
```diff
@@ -169,7 +169,11 @@ def __call__(self, *args, **kwargs):
         else:
             set_graph_pool_id(current_platform.graph_pool_handle())
         # mind-exploding: carefully manage the reference and memory.
-        with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+        with torch.cuda.graph(
```
**zou3519 (Collaborator):** @Amir-19 what's the context behind this PR?

**Amir-19 (Contributor, Author):** @zou3519 this came out of investigating #28901. With NCCL 2.28+, window registration during CUDA graph capture crashes, with NCCL complaining about the registration. since

**BoyuanFeng (Collaborator):** yes. Could you elaborate what is

In general, warming up on one stream and capturing the graph on another stream is fine, apart from some extra memory consumption, so using the same stream for warmup and capture is an optimization. However, using different streams should not lead to an error. Could you elaborate a bit on why it errors? e.g., the

**Amir-19 (Contributor, Author):** @BoyuanFeng intuitively, graph capture means "do this every time the graph replays", so even if registration were allowed during CUDA graph capture, it would create new window handles on every replay, which would then have to be destroyed later. That is neither efficient nor useful; the proper pattern is to register the window once before capture and then reuse it. Starting with NCCL 2.28 there is this restriction, and to fix it we need to make sure that the warmup and the CUDA graph capture happen on the same side stream.
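The pattern the author describes can be sketched as below. This is a hypothetical helper, not vLLM's code: warmup (where one-time NCCL window registration would happen) and graph capture both run on the same explicitly chosen side stream, and the GPU-touching part is guarded because it requires a CUDA device.

```python
import torch

def capture_on_one_stream(fn, *args):
    """Warm up and capture `fn` on the SAME side stream (hypothetical helper).

    NCCL 2.28+ rejects window registration during graph capture, so any
    one-time registration must happen in a warmup run on the stream that
    will also be used for capture.
    """
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        fn(*args)  # warmup: one-time setup/registration happens here
    graph = torch.cuda.CUDAGraph()
    # Pass the warmup stream explicitly; by default torch.cuda.graph would
    # capture on an internally created side stream instead.
    with torch.cuda.graph(graph, stream=stream):
        out = fn(*args)
    return graph, out

if torch.cuda.is_available():  # guarded: needs a GPU to actually run
    g, out = capture_on_one_stream(lambda x: x * 2,
                                   torch.ones(4, device="cuda"))
    g.replay()
```

The key point is the `stream=` argument to `torch.cuda.graph`: without it, capture happens on a fresh internal side stream, which no longer matches the stream the warmup registration was done on.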
```diff
+            cudagraph,
+            pool=self.graph_pool,
+            stream=torch.cuda.current_stream(),
+        ):
```
Comment on lines +266 to +270:

**Contributor:** For performance reasons, it's better to use `current_stream()` from `vllm.utils.torch_utils` than `torch.cuda.current_stream()`. You'll need to update the import on line 20: `from vllm.utils.torch_utils import current_stream, weak_ref_tensors`
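The reviewer's suggestion rests on the fact that `torch.cuda.current_stream()` constructs a fresh `Stream` wrapper object on every call, whereas a caching accessor can reuse one. A simplified, hypothetical version of such an accessor (not vLLM's actual implementation, which also handles multiple devices and stream switches):

```python
import torch

_cached_stream = None

def cached_current_stream():
    """Return the current CUDA stream, constructing the wrapper only once.

    Hypothetical simplification: the real vLLM helper also tracks streams
    per device and invalidates the cache when the active stream changes.
    Returns None on CPU-only machines.
    """
    global _cached_stream
    if _cached_stream is None and torch.cuda.is_available():
        _cached_stream = torch.cuda.current_stream()
    return _cached_stream
```

Repeated calls then hand back the same object instead of allocating a new wrapper on every invocation, which matters on hot paths like graph capture and replay.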
```diff
             # `output` is managed by pytorch's cudagraph pool
             output = self.runnable(*args, **kwargs)
             if self.cudagraph_options.weak_ref_output:
```
**Collaborator:** what would be the issue without this change?

**Author:** this benchmark would fail with NCCL 2.28, because `torch.cuda.graph` would otherwise create a new side stream for capture.