
Commit f3137cd

[Core] Freeze gc during cuda graph capture to speed up init (#21146)
Signed-off-by: Codex <[email protected]>
Signed-off-by: mgoin <[email protected]>
1 parent 82ec66f commit f3137cd

2 files changed, +23 -1 lines

vllm/envs.py

Lines changed: 7 additions & 0 deletions
@@ -140,6 +140,7 @@
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
     VLLM_USE_CUDNN_PREFILL: bool = False
+    VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
 
 
@@ -968,6 +969,12 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_TRTLLM_DECODE_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
 
+    # Controls garbage collection during CUDA graph capture.
+    # If set to 0 (default), enables GC freezing to speed up capture time.
+    # If set to 1, allows GC to run during capture.
+    "VLLM_ENABLE_CUDAGRAPH_GC":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
+
     # Used to force set up loopback IP
     "VLLM_LOOPBACK_IP":
     lambda: os.getenv("VLLM_LOOPBACK_IP", ""),

vllm/v1/worker/gpu_model_runner.py

Lines changed: 16 additions & 1 deletion
@@ -2439,10 +2439,25 @@ def capture_model(self) -> None:
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
+        @contextmanager
+        def freeze_gc():
+            # Optimize garbage collection during CUDA graph capture.
+            # Clean up, then freeze all remaining objects from being included
+            # in future collections.
+            gc.collect()
+            should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
+            if should_freeze:
+                gc.freeze()
+            try:
+                yield
+            finally:
+                if should_freeze:
+                    gc.unfreeze()
+
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with graph_capture(device=self.device):
+        with freeze_gc(), graph_capture(device=self.device):
             full_cg = self.full_cuda_graph
             # Only rank 0 should print progress bar during capture
             compilation_cases = reversed(self.cudagraph_batch_sizes)
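
Outside the diff, a minimal self-contained sketch of the same pattern (capture_one_graph is a hypothetical stand-in for the per-shape capture work): after one full collection, gc.freeze() moves every surviving object into the permanent generation, so CPython's cyclic GC no longer re-scans the large, long-lived model state on every collection triggered by the capture loop's allocations.

import gc
from contextlib import contextmanager


@contextmanager
def frozen_gc(enabled: bool = True):
    # Collect once, then freeze the survivors so later collections skip them.
    gc.collect()
    if enabled:
        gc.freeze()
    try:
        yield
    finally:
        if enabled:
            gc.unfreeze()


def capture_one_graph(batch_size: int) -> None:
    # Hypothetical placeholder for capturing one CUDA graph shape; it only
    # allocates some short-lived objects to exercise the garbage collector.
    _ = [bytearray(64) for _ in range(batch_size)]


with frozen_gc(enabled=True):
    for batch_size in reversed([1, 2, 4, 8, 16, 32, 64]):
        capture_one_graph(batch_size)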
