File tree Expand file tree Collapse file tree 2 files changed +23
-1
lines changed Expand file tree Collapse file tree 2 files changed +23
-1
lines changed Original file line number Diff line number Diff line change 140
140
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB : Optional [int ] = None
141
141
VLLM_NIXL_ABORT_REQUEST_TIMEOUT : int = 120
142
142
VLLM_USE_CUDNN_PREFILL : bool = False
143
+ VLLM_ENABLE_CUDAGRAPH_GC : bool = False
143
144
VLLM_LOOPBACK_IP : str = ""
144
145
145
146
@@ -968,6 +969,12 @@ def get_vllm_port() -> Optional[int]:
968
969
"VLLM_USE_TRTLLM_DECODE_ATTENTION" :
969
970
lambda : os .getenv ("VLLM_USE_TRTLLM_DECODE_ATTENTION" , None ),
970
971
972
+ # Controls garbage collection during CUDA graph capture.
973
+ # If set to 0 (default), enables GC freezing to speed up capture time.
974
+ # If set to 1, allows GC to run during capture.
975
+ "VLLM_ENABLE_CUDAGRAPH_GC" :
976
+ lambda : bool (int (os .getenv ("VLLM_ENABLE_CUDAGRAPH_GC" , "0" ))),
977
+
971
978
# Used to force set up loopback IP
972
979
"VLLM_LOOPBACK_IP" :
973
980
lambda : os .getenv ("VLLM_LOOPBACK_IP" , "" ),
Original file line number Diff line number Diff line change @@ -2439,10 +2439,25 @@ def capture_model(self) -> None:
2439
2439
start_time = time .perf_counter ()
2440
2440
start_free_gpu_memory = torch .cuda .mem_get_info ()[0 ]
2441
2441
2442
@contextmanager
def freeze_gc():
    """Context manager that tunes garbage collection for CUDA graph capture.

    Runs a full collection up front, then (unless the user opted back in
    to normal GC via ``VLLM_ENABLE_CUDAGRAPH_GC``) freezes every surviving
    object so the collector skips them while graphs are being captured,
    which speeds up capture time. Objects are unfrozen on exit either way.
    """
    gc.collect()
    if envs.VLLM_ENABLE_CUDAGRAPH_GC:
        # User explicitly allowed GC during capture: nothing to freeze.
        yield
        return
    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
2442
2457
# Trigger CUDA graph capture for specific shapes.
2443
2458
# Capture the large shapes first so that the smaller shapes
2444
2459
# can reuse the memory pool allocated for the large shapes.
2445
- with graph_capture (device = self .device ):
2460
+ with freeze_gc (), graph_capture (device = self .device ):
2446
2461
full_cg = self .full_cuda_graph
2447
2462
# Only rank 0 should print progress bar during capture
2448
2463
compilation_cases = reversed (self .cudagraph_batch_sizes )
You can’t perform that action at this time.
0 commit comments