diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py index b414efa6e330..7b453fe7b680 100644 --- a/benchmarks/kernels/benchmark_device_communicators.py +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -293,7 +293,7 @@ def benchmark_allreduce_single( graph = torch.cuda.CUDAGraph() graph_pool = torch.cuda.graph_pool_handle() set_graph_pool_id(graph_pool) - with torch.cuda.graph(graph, pool=graph_pool): + with torch.cuda.graph(graph, pool=graph_pool, stream=stream): for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): allreduce_fn(graph_input) diff --git a/tests/utils_/test_torch_utils.py b/tests/utils_/test_torch_utils.py index 0a30b9727f4d..f6a9486a1296 100644 --- a/tests/utils_/test_torch_utils.py +++ b/tests/utils_/test_torch_utils.py @@ -99,30 +99,18 @@ def child_thread_func(): def test_current_stream_multithread(): - from vllm.platforms import current_platform - if not torch.cuda.is_available(): pytest.skip("CUDA not available") - if current_platform.is_rocm(): - main_dedicated_stream = current_stream() - - assert main_dedicated_stream.cuda_stream != 0, ( - "ROCm should create a dedicated stream, not use default stream (0x0)" - ) - - main_stream_again = current_stream() - assert main_stream_again == main_dedicated_stream, ( - "Multiple calls to current_stream should return the same dedicated stream" - ) + main_dedicated_stream = current_stream() - _test_stream_thread(main_dedicated_stream) - else: - main_default_stream = torch.cuda.default_stream() - main_initial_stream = current_stream() + assert main_dedicated_stream.cuda_stream != 0, ( + "ROCm/CUDA should create a dedicated stream, not use default stream (0x0)" + ) - assert main_initial_stream == main_default_stream, ( - "First call to current_stream should return default stream on CUDA" - ) + main_stream_again = current_stream() + assert main_stream_again == main_dedicated_stream, ( + "Multiple calls to current_stream should return the same dedicated stream" ) - _test_stream_thread(main_default_stream) + _test_stream_thread(main_dedicated_stream) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 0748643a5299..08cae27b1276 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -18,7 +18,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch_utils import weak_ref_tensors +from vllm.utils.torch_utils import current_stream, weak_ref_tensors logger = init_logger(__name__) @@ -263,7 +263,11 @@ def __call__(self, *args, **kwargs): else: set_graph_pool_id(current_platform.graph_pool_handle()) # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): + with torch.cuda.graph( + cudagraph, + pool=self.graph_pool, + stream=current_stream(), + ): # `output` is managed by pytorch's cudagraph pool output = self.runnable(*args, **kwargs) if self.cudagraph_options.weak_ref_output: diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index b82e0171b7f7..db596052a04d 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -465,9 +465,13 @@ def current_stream() -> torch.cuda.Stream: # when this function is called before any stream is set, # we return the default stream. # On ROCm using the default 0 stream in combination with RCCL - # is hurting performance. Therefore creating a dedicated stream - # per process + # is hurting performance. + # On CUDA, we capture and replay cudagraph on the same stream, + # so we need to avoid using the default stream as well. The default + # stream cannot be used for cudagraph capture, see + # https://github.com/pytorch/pytorch/blob/42ad9edfb754743fdae3276ade43de000beb4f60/aten/src/ATen/cuda/CUDAGraph.cpp#L77 + # for more details. Therefore, we create a dedicated stream per process. + if current_platform.is_rocm() or current_platform.is_cuda(): # torch.cuda.set_stream here is the alias of _pathed_set_stream torch.cuda.set_stream(torch.cuda.Stream()) elif current_platform.is_cpu():