+import contextlib
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
 from vllm.model_executor.parallel_utils.parallel_state import (
     with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
@@ -659,7 +660,7 @@ def list_loras(self) -> Set[int]:
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
         # deleted before the CUDA graphs.
-        self.cupy_nccl_backend = get_nccl_backend()
+        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
 
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
@@ -689,15 +690,15 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
         # graph, we use either custom all-reduce kernel or PyTorch NCCL.
         # We always prioritize using custom all-reduce kernel but fall back
         # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
@@ -765,7 +766,7 @@ def capture(
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_cupy_nccl_for_all_reduce():
+        with _maybe_cupy_nccl():
             self.model(
                 input_ids,
                 positions,
@@ -779,7 +780,7 @@ def capture(
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with with_cupy_nccl_for_all_reduce():
+            with _maybe_cupy_nccl():
                 hidden_states = self.model(
                     input_ids,
                     positions,
@@ -830,6 +831,15 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
 
+@contextlib.contextmanager
+def _maybe_cupy_nccl():
+    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
+        with with_cupy_nccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return x + [pad] * (max_len - len(x))
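
For reference, the new _maybe_cupy_nccl() helper added above is a conditional context manager: it enters with_cupy_nccl_for_all_reduce() only when CuPy NCCL is initialized and the custom all-reduce kernel is not, and otherwise yields with no extra context. A minimal, self-contained sketch of that pattern follows; the names (maybe_fallback, _fallback_context) and the prints are illustrative stand-ins, not part of this change.

import contextlib


@contextlib.contextmanager
def _fallback_context():
    # Stand-in for with_cupy_nccl_for_all_reduce() in the real code.
    print("entering fallback all-reduce context")
    try:
        yield
    finally:
        print("leaving fallback all-reduce context")


@contextlib.contextmanager
def maybe_fallback(use_fallback: bool):
    # Same shape as _maybe_cupy_nccl: wrap the body in the fallback context
    # only when the condition holds; otherwise run it with no extra context.
    if use_fallback:
        with _fallback_context():
            yield
    else:
        yield


# Usage: the caller's body is identical in both cases.
with maybe_fallback(use_fallback=True):
    print("running captured forward pass")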