Skip to content

Commit a4ff82f

Browse files
committed
add warmup
Signed-off-by: Siyuan Fu <[email protected]>
1 parent 00e75ef commit a4ff82f

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

vllm/model_executor/warmup/kernel_warmup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from vllm.v1.worker.gpu_worker import Worker
2121

2222

23-
def kernel_warmup(worker: "Worker"):
23+
def kernel_warmup(worker: "Worker", do_autotune: bool = False):
2424
# Deep GEMM warmup
2525
do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
2626
and is_deep_gemm_supported()
@@ -32,10 +32,11 @@ def kernel_warmup(worker: "Worker"):
3232

3333
# FlashInfer autotune for Blackwell (SM 10.0) GPUs
3434
if has_flashinfer() and current_platform.is_device_capability(100):
35-
flashinfer_autotune(worker.model_runner)
35+
flashinfer_autotune(worker.model_runner, do_autotune)
3636

3737

38-
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
38+
def flashinfer_autotune(runner: "GPUModelRunner",
39+
do_autotune: bool = True) -> None:
3940
"""
4041
Autotune FlashInfer operations.
4142
FlashInfer has many implementations for the same operation,
@@ -47,7 +48,7 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
4748
"""
4849
from vllm.utils.flashinfer import autotune
4950

50-
with torch.inference_mode(), autotune():
51+
with torch.inference_mode(), autotune(do_autotune):
5152
# We skip EPLB here since we don't want to record dummy metrics
5253
# When autotuning with number of tokens m, flashinfer will autotune
5354
# operations for all number of tokens up to m.

vllm/v1/worker/gpu_worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def compile_or_warm_up_model(self) -> None:
312312
self.model_runner._dummy_run(size, skip_eplb=True)
313313

314314
# run autotuner before cuda graph capture.
315-
kernel_warmup(self)
315+
kernel_warmup(self, do_autotune=True)
316316

317317
if not self.model_config.enforce_eager:
318318
self.model_runner.capture_model()
@@ -338,6 +338,9 @@ def compile_or_warm_up_model(self) -> None:
338338
self.model_runner._dummy_sampler_run(
339339
hidden_states=last_hidden_states)
340340

341+
# Warmup kernels used during model execution
342+
kernel_warmup(self, do_autotune=False)
343+
341344
# Reset the seed to ensure that the random state is not affected by
342345
# the model initialization and profiling.
343346
set_random_seed(self.model_config.seed)

0 commit comments

Comments (0)