Skip to content

Commit 6cc1f6e

Browse files
committed
add warmup
Signed-off-by: Siyuan Fu <[email protected]>
1 parent 35a24b3 commit 6cc1f6e

File tree

2 files changed: +9 additions, −5 deletions

vllm/model_executor/warmup/kernel_warmup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from vllm.v1.worker.gpu_worker import Worker
2121

2222

23-
def kernel_warmup(worker: "Worker"):
23+
def kernel_warmup(worker: "Worker", do_autotune: bool = False):
2424
# Deep GEMM warmup
2525
do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
2626
and is_deep_gemm_supported()
@@ -32,10 +32,11 @@ def kernel_warmup(worker: "Worker"):
3232

3333
# FlashInfer autotune for Blackwell (SM 10.0) GPUs
3434
if has_flashinfer() and current_platform.is_device_capability(100):
35-
flashinfer_autotune(worker.model_runner)
35+
flashinfer_autotune(worker.model_runner, do_autotune)
3636

3737

38-
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
38+
def flashinfer_autotune(runner: "GPUModelRunner",
39+
do_autotune: bool = True) -> None:
3940
"""
4041
Autotune FlashInfer operations.
4142
FlashInfer have many implementations for the same operation,
@@ -47,7 +48,7 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
4748
"""
4849
from vllm.utils.flashinfer import autotune
4950

50-
with torch.inference_mode(), autotune():
51+
with torch.inference_mode(), autotune(do_autotune):
5152
# We skip EPLB here since we don't want to record dummy metrics
5253
# When autotuning with number of tokens m, flashinfer will autotune
5354
# operations for all number of tokens up to m.

vllm/v1/worker/gpu_worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def compile_or_warm_up_model(self) -> None:
313313
self.model_runner._dummy_run(size, skip_eplb=True)
314314

315315
# run autotuner before cuda graph capture.
316-
kernel_warmup(self)
316+
kernel_warmup(self, do_autotune=True)
317317

318318
if not self.model_config.enforce_eager:
319319
self.model_runner.capture_model()
@@ -339,6 +339,9 @@ def compile_or_warm_up_model(self) -> None:
339339
self.model_runner._dummy_sampler_run(
340340
hidden_states=last_hidden_states)
341341

342+
# Warmup kernels used during model execution
343+
kernel_warmup(self, do_autotune=False)
344+
342345
# Reset the seed to ensure that the random state is not affected by
343346
# the model initialization and profiling.
344347
set_random_seed(self.model_config.seed)

Commit comments (0)