Commit 62d54ba

[Model Runner V2] Optimize CUDA graph capture time (vllm-project#29275)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent b004c00 commit 62d54ba

File tree

2 files changed: +5 −1 lines changed

vllm/v1/worker/gpu/cudagraph_utils.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -106,7 +106,10 @@ def capture_graph(
         input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
         input_buffers.query_start_loc.np[batch_size:] = batch_size
         input_buffers.query_start_loc.copy_to_gpu()
-        input_buffers.seq_lens[:batch_size] = self.max_model_len
+        # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
+        # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
+        # seq_lens_np (CPU), which might cause issues in some attention backends.
+        input_buffers.seq_lens[:batch_size] = 1
         input_buffers.seq_lens[batch_size:] = 0
 
         input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
```
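The buffer-padding scheme in this hunk can be illustrated with a minimal, torch-free sketch. `fill_capture_buffers` and `max_batch` are illustrative names, not vLLM API; plain NumPy arrays stand in for the pinned CPU/GPU buffer pairs the real code uses:

```python
import numpy as np

def fill_capture_buffers(query_start_loc, seq_lens, batch_size):
    """Mimic the buffer setup when capturing a CUDA graph for
    `batch_size` dummy requests of one token each.

    - query_start_loc becomes [0, 1, ..., batch_size]; trailing slots
      are padded with batch_size so they point past the last token.
    - seq_lens uses 1 (not max_model_len), per this commit, so the
      warmup attention call touches minimal KV state; unused slots are 0.
    """
    query_start_loc[: batch_size + 1] = np.arange(batch_size + 1)
    query_start_loc[batch_size:] = batch_size
    seq_lens[:batch_size] = 1
    seq_lens[batch_size:] = 0

max_batch = 8
qsl = np.zeros(max_batch + 1, dtype=np.int32)
sl = np.zeros(max_batch, dtype=np.int32)
fill_capture_buffers(qsl, sl, batch_size=4)
# qsl -> [0, 1, 2, 3, 4, 4, 4, 4, 4]
# sl  -> [1, 1, 1, 1, 0, 0, 0, 0]
```

Because graph capture only needs the kernels to launch with the right shapes, not to compute meaningful attention, the small `seq_lens` values are enough; the commit comment is explicit that the resulting GPU/CPU mismatch is a deliberate trade-off that some attention backends may not tolerate.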

vllm/v1/worker/gpu/model_runner.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -313,6 +313,7 @@ def capture_model(self) -> int:
             return 0
 
         start_time = time.perf_counter()
+        gc.collect()
         torch.cuda.empty_cache()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
```
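The `gc.collect()` call matters here because CUDA tensors caught in Python reference cycles are not released when they go out of scope; their memory returns to PyTorch's caching allocator only after the cyclic collector runs, and only then can `torch.cuda.empty_cache()` hand it back and `mem_get_info()` report an accurate free-memory baseline. A torch-free sketch of the underlying CPython behavior (the `Node` class is illustrative):

```python
import gc

class Node:
    def __init__(self):
        self.ref = None

def make_cycle():
    # Two objects referencing each other: refcounts never hit zero,
    # so neither is freed when the function returns.
    a, b = Node(), Node()
    a.ref, b.ref = b, a

gc.disable()  # suppress automatic collection so the effect is visible
make_cycle()
before = sum(1 for o in gc.get_objects() if isinstance(o, Node))
gc.collect()  # the cyclic collector reclaims the unreachable pair
after = sum(1 for o in gc.get_objects() if isinstance(o, Node))
gc.enable()
# before == 2, after == 0
```

In the model runner, a dead tensor kept alive this way would inflate the "memory used by CUDA graphs" figure computed from the before/after `mem_get_info()` readings, so collecting first makes the measurement both smaller and more stable.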
