20
20
from vllm .v1 .worker .gpu_worker import Worker
21
21
22
22
23
- def kernel_warmup (worker : "Worker" ):
23
+ def kernel_warmup (worker : "Worker" , do_autotune : bool = False ):
24
24
# Deep GEMM warmup
25
25
do_deep_gemm_warmup = (envs .VLLM_USE_DEEP_GEMM
26
26
and is_deep_gemm_supported ()
@@ -32,10 +32,11 @@ def kernel_warmup(worker: "Worker"):
32
32
33
33
# FlashInfer autotune for Blackwell (SM 10.0) GPUs
34
34
if has_flashinfer () and current_platform .is_device_capability (100 ):
35
- flashinfer_autotune (worker .model_runner )
35
+ flashinfer_autotune (worker .model_runner , do_autotune )
36
36
37
37
38
- def flashinfer_autotune (runner : "GPUModelRunner" ) -> None :
38
+ def flashinfer_autotune (runner : "GPUModelRunner" ,
39
+ do_autotune : bool = True ) -> None :
39
40
"""
40
41
Autotune FlashInfer operations.
41
42
FlashInfer have many implementations for the same operation,
@@ -47,7 +48,7 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
47
48
"""
48
49
from vllm .utils .flashinfer import autotune
49
50
50
- with torch .inference_mode (), autotune ():
51
+ with torch .inference_mode (), autotune (do_autotune ):
51
52
# We skip EPLB here since we don't want to record dummy metrics
52
53
# When autotuning with number of tokens m, flashinfer will autotune
53
54
# operations for all number of tokens up to m.
0 commit comments