+import contextlib
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union

                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
 from vllm.model_executor.parallel_utils.parallel_state import (
     with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
@@ -659,7 +660,7 @@ def list_loras(self) -> Set[int]:
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
         # deleted before the CUDA graphs.
-        self.cupy_nccl_backend = get_nccl_backend()
+        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()

         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
@@ -689,15 +690,15 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]

-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
         # graph, we use either custom all-reduce kernel or PyTorch NCCL.
         # We always prioritize using custom all-reduce kernel but fall back
         # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
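As a rough standalone illustration of the priority described in the NOTE above, a hypothetical helper (not part of this patch; only `custom_all_reduce` is assumed from the module's imports) could pick the all-reduce backend like this:

def _pick_all_reduce_backend(capturing_cuda_graph: bool) -> str:
    # Hypothetical sketch: custom all-reduce kernel first, then CuPy NCCL
    # while capturing/replaying CUDA graphs, otherwise PyTorch NCCL.
    if custom_all_reduce.is_initialized():
        return "custom-all-reduce"
    return "cupy-nccl" if capturing_cuda_graph else "torch-nccl"
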
@@ -765,7 +766,7 @@ def capture(
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_cupy_nccl_for_all_reduce():
+        with _maybe_cupy_nccl():
             self.model(
                 input_ids,
                 positions,
@@ -779,7 +780,7 @@ def capture(
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with with_cupy_nccl_for_all_reduce():
+            with _maybe_cupy_nccl():
                 hidden_states = self.model(
                     input_ids,
                     positions,
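The two hunks above follow the standard PyTorch CUDA-graph recipe: run the model once to warm things up, then capture into a graph that shares a memory pool and is later replayed on inputs updated in place. A minimal standalone sketch of that pattern, with a toy model and tensor names that are illustrative only:

import torch

# Toy stand-ins; `model`, `static_in`, and `static_out` are illustrative
# names, not objects from the patch.
model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(8, 16, device="cuda")

# Warm-up on a side stream so one-time kernel launches (lazy init,
# autotuning) happen before capture, mirroring the "run once" step above.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture: the recorded kernels always read and write the same memory,
# so later calls update `static_in` in place and replay the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=torch.cuda.graph_pool_handle()):
    static_out = model(static_in)

static_in.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()  # `static_out` now holds the output for the new input
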
@@ -830,6 +831,15 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)


+@contextlib.contextmanager
+def _maybe_cupy_nccl():
+    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
+        with with_cupy_nccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return x + [pad] * (max_len - len(x))
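The new `_maybe_cupy_nccl` guard enters the CuPy-NCCL context only when CuPy NCCL is initialized and the custom all-reduce kernel is not in use; otherwise the bare `yield` makes the `with` block a no-op. For comparison, a sketch of an equivalent guard (not part of the patch, and assuming the same module-level imports) can return `contextlib.nullcontext()` instead of branching inside a generator:

def _maybe_cupy_nccl_alt():
    # Hand back the CuPy-NCCL context when it applies, or a nullcontext so
    # `with _maybe_cupy_nccl_alt():` does nothing extra.
    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
        return with_cupy_nccl_for_all_reduce()
    return contextlib.nullcontext()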