
Commit 25e86b6

Don't use cupy NCCL for AMD backends (#2855)

On ROCm (HIP) builds, the worker no longer initializes the CuPy NCCL process group. CUDA-graph capture now enters the CuPy NCCL context only when CuPy is actually initialized and the custom all-reduce kernel is not; otherwise it relies on the custom all-reduce kernel or PyTorch NCCL.
1 parent: 4efbac6

3 files changed, +23 −7 lines

vllm/model_executor/parallel_utils/custom_all_reduce.py

4 additions, 0 deletions

```diff
@@ -67,6 +67,10 @@ def get_handle() -> Optional["CustomAllreduce"]:
     return _CA_HANDLE


+def is_initialized() -> bool:
+    return _CA_HANDLE is not None
+
+
 @contextmanager
 def capture():
     try:
```
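The new helper lets callers ask whether the custom all-reduce handle exists without forcing its creation (as `get_handle()` may). A minimal usage sketch, at a hypothetical call site that is not part of this commit:

```python
from vllm.model_executor.parallel_utils import custom_all_reduce

# Hypothetical call site: report which all-reduce path CUDA-graph capture
# would wrap, without triggering handle creation via get_handle().
if custom_all_reduce.is_initialized():
    print("Graph capture will rely on the custom all-reduce kernel.")
else:
    print("Graph capture will fall back to CuPy NCCL (if set up) or PyTorch NCCL.")
```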

vllm/worker/model_runner.py

16 additions, 6 deletions

```diff
@@ -1,3 +1,4 @@
+import contextlib
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
 
@@ -9,9 +10,9 @@
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
 from vllm.model_executor.parallel_utils.parallel_state import (
     with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
@@ -659,7 +660,7 @@ def list_loras(self) -> Set[int]:
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
         # deleted before the CUDA graphs.
-        self.cupy_nccl_backend = get_nccl_backend()
+        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
 
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
@@ -689,15 +690,15 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
         # graph, we use either custom all-reduce kernel or PyTorch NCCL.
         # We always prioritize using custom all-reduce kernel but fall back
         # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
@@ -765,7 +766,7 @@ def capture(
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_cupy_nccl_for_all_reduce():
+        with _maybe_cupy_nccl():
             self.model(
                 input_ids,
                 positions,
@@ -779,7 +780,7 @@ def capture(
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with with_cupy_nccl_for_all_reduce():
+            with _maybe_cupy_nccl():
                 hidden_states = self.model(
                     input_ids,
                     positions,
@@ -830,6 +831,15 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
 
+@contextlib.contextmanager
+def _maybe_cupy_nccl():
+    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
+        with with_cupy_nccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return x + [pad] * (max_len - len(x))
```
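`_maybe_cupy_nccl()` is a conditional context manager: it enters the CuPy NCCL context only when CuPy is initialized and the custom all-reduce kernel is not, and otherwise does nothing. A self-contained sketch of the same pattern using `contextlib.nullcontext`, with stand-in names rather than vLLM's API:

```python
import contextlib
from typing import Callable, ContextManager

def maybe_ctx(factory: Callable[[], ContextManager],
              condition: bool) -> ContextManager:
    """Return the real context manager when the condition holds, else a no-op."""
    return factory() if condition else contextlib.nullcontext()

# Dummy backend context for illustration; in model_runner.py the condition is
# cupy_utils.is_initialized() and not custom_all_reduce.is_initialized().
@contextlib.contextmanager
def fake_backend():
    print("entering special all-reduce backend")
    yield
    print("leaving special all-reduce backend")

with maybe_ctx(fake_backend, condition=False):
    print("forward pass runs here either way")
```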

vllm/worker/worker.py

3 additions, 1 deletion

```diff
@@ -19,6 +19,7 @@
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.lora.request import LoRARequest
+from vllm.utils import is_hip
 
 
 class Worker:
@@ -268,7 +269,8 @@ def init_distributed_environment(
             "cupy.distributed is already initialized but the cupy world "
             "size does not match parallel_config.world_size "
             f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1 and cupy_port is not None:
+    elif (parallel_config.world_size > 1 and cupy_port is not None
+          and not is_hip()):
         # NOTE(woosuk): We don't initialize CuPy process group when world size
         # is 1.
         # TODO(woosuk): Support multi-node connection.
```
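The gate relies on `vllm.utils.is_hip`, whose body is not shown in this diff. A common way to detect a ROCm/HIP build of PyTorch looks like the sketch below; this is an assumption about the helper, not necessarily vLLM's exact implementation:

```python
import torch

def is_hip() -> bool:
    # torch.version.hip is a version string on ROCm builds and None on CUDA
    # builds; a missing attribute is treated as "not HIP".
    return getattr(torch.version, "hip", None) is not None
```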
