From 35a24b32cd682dd53491df4a6d53d19524f45603 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Tue, 19 Aug 2025 14:45:03 -0700 Subject: [PATCH 1/9] fix after flashinfer autotuner Signed-off-by: Siyuan Fu --- vllm/model_executor/layers/quantization/mxfp4.py | 9 ++++++++- vllm/v1/worker/gpu_worker.py | 6 +++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 6a190ebbc063..07ed8e27a37a 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -113,6 +113,8 @@ def __init__(self, moe: FusedMoEConfig): self.topk_indices_dtype = None self.moe = moe self.use_marlin = self._should_use_marlin() + self.device_support_pdl = current_platform.is_cuda( + ) and current_platform.has_device_capability(90) if current_platform.is_device_capability(100) and not has_flashinfer(): logger.warning_once( @@ -520,7 +522,8 @@ def apply( x_scale = None else: x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 - x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape( + *x.shape[:-1], -1) trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -549,6 +552,10 @@ def apply( self._get_tile_tokens_dim(x, top_k), 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize + self.device_support_pdl, + None, # output + # TODO: use the maximum number in the cudagraph_batch_sizes + 8192, # tune_max_num_tokens. )[0] return trtllm_gen_output else: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d61177d4245d..6fd5b2957947 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -312,6 +312,9 @@ def compile_or_warm_up_model(self) -> None: logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size, skip_eplb=True) + # run autotuner before cuda graph capture. + kernel_warmup(self) + if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -336,9 +339,6 @@ def compile_or_warm_up_model(self) -> None: self.model_runner._dummy_sampler_run( hidden_states=last_hidden_states) - # Warmup kernels used during model execution - kernel_warmup(self) - # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) From 6cc1f6e5b69be0bfe00530d2527192f815f68788 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Tue, 19 Aug 2025 15:41:06 -0700 Subject: [PATCH 2/9] add warmup Signed-off-by: Siyuan Fu --- vllm/model_executor/warmup/kernel_warmup.py | 9 +++++---- vllm/v1/worker/gpu_worker.py | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 761172e4d361..a30c469589a9 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -20,7 +20,7 @@ from vllm.v1.worker.gpu_worker import Worker -def kernel_warmup(worker: "Worker"): +def kernel_warmup(worker: "Worker", do_autotune: bool = False): # Deep GEMM warmup do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM and is_deep_gemm_supported() @@ -32,10 +32,11 @@ def kernel_warmup(worker: "Worker"): # FlashInfer autotune for Blackwell (SM 10.0) GPUs if has_flashinfer() and current_platform.is_device_capability(100): - flashinfer_autotune(worker.model_runner) + flashinfer_autotune(worker.model_runner, do_autotune) -def flashinfer_autotune(runner: "GPUModelRunner") -> None: +def flashinfer_autotune(runner: "GPUModelRunner", + do_autotune: bool = True) -> None: """ Autotune FlashInfer operations. FlashInfer have many implementations for the same operation, @@ -47,7 +48,7 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None: """ from vllm.utils.flashinfer import autotune - with torch.inference_mode(), autotune(): + with torch.inference_mode(), autotune(do_autotune): # We skip EPLB here since we don't want to record dummy metrics # When autotuning with number of tokens m, flashinfer will autotune # operations for all number of tokens up to m. diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6fd5b2957947..440bd8ececea 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -313,7 +313,7 @@ def compile_or_warm_up_model(self) -> None: self.model_runner._dummy_run(size, skip_eplb=True) # run autotuner before cuda graph capture. - kernel_warmup(self) + kernel_warmup(self, do_autotune=True) if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -339,6 +339,9 @@ def compile_or_warm_up_model(self) -> None: self.model_runner._dummy_sampler_run( hidden_states=last_hidden_states) + # Warmup kernels used during model execution + kernel_warmup(self, do_autotune=False) + # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) From 23055b80b4702f3140f47cb87a9364ed3b76a888 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Wed, 20 Aug 2025 09:44:27 -0700 Subject: [PATCH 3/9] address comment Signed-off-by: Siyuan Fu --- vllm/model_executor/warmup/kernel_warmup.py | 9 ++++----- vllm/v1/worker/gpu_worker.py | 8 +++----- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index a30c469589a9..761172e4d361 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -20,7 +20,7 @@ from vllm.v1.worker.gpu_worker import Worker -def kernel_warmup(worker: "Worker", do_autotune: bool = False): +def kernel_warmup(worker: "Worker"): # Deep GEMM warmup do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM and is_deep_gemm_supported() @@ -32,11 +32,10 @@ def kernel_warmup(worker: "Worker", do_autotune: bool = False): # FlashInfer autotune for Blackwell (SM 10.0) GPUs if has_flashinfer() and current_platform.is_device_capability(100): - flashinfer_autotune(worker.model_runner, do_autotune) + flashinfer_autotune(worker.model_runner) -def flashinfer_autotune(runner: "GPUModelRunner", - do_autotune: bool = True) -> None: +def flashinfer_autotune(runner: "GPUModelRunner") -> None: """ Autotune FlashInfer operations. FlashInfer have many implementations for the same operation, @@ -48,7 +47,7 @@ def flashinfer_autotune(runner: "GPUModelRunner", """ from vllm.utils.flashinfer import autotune - with torch.inference_mode(), autotune(do_autotune): + with torch.inference_mode(), autotune(): # We skip EPLB here since we don't want to record dummy metrics # When autotuning with number of tokens m, flashinfer will autotune # operations for all number of tokens up to m. diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 440bd8ececea..e810bee6877e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -312,8 +312,9 @@ def compile_or_warm_up_model(self) -> None: logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size, skip_eplb=True) - # run autotuner before cuda graph capture. - kernel_warmup(self, do_autotune=True) + # Warmup and tune the kernels used during model execution before + # cuda graph capture. + kernel_warmup(self) if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -339,9 +340,6 @@ def compile_or_warm_up_model(self) -> None: self.model_runner._dummy_sampler_run( hidden_states=last_hidden_states) - # Warmup kernels used during model execution - kernel_warmup(self, do_autotune=False) - # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) From 7e1fb28e00b6aefc5563a592160e4782d5cfccac Mon Sep 17 00:00:00 2001 From: siyuanf Date: Wed, 20 Aug 2025 23:39:09 -0700 Subject: [PATCH 4/9] Update flashinfer tag Signed-off-by: siyuanf --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6a3013de7937..7317b60f81a3 100644 --- a/setup.py +++ b/setup.py @@ -685,7 +685,7 @@ def _read_requirements(filename: str) -> list[str]: "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.2.12"], + "flashinfer": ["flashinfer-python==0.2.13"], }, cmdclass=cmdclass, package_data=package_data, From 3d9f10eb41dc3d6af8bc06e0c9a0211d3d18b135 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Thu, 21 Aug 2025 11:24:49 -0700 Subject: [PATCH 5/9] address comment Signed-off-by: Siyuan Fu --- vllm/model_executor/layers/quantization/mxfp4.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 07ed8e27a37a..036127a4b0de 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -552,10 +552,8 @@ def apply( self._get_tile_tokens_dim(x, top_k), 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize - self.device_support_pdl, - None, # output # TODO: use the maximum number in the cudagraph_batch_sizes - 8192, # tune_max_num_tokens. + tune_max_num_tokens=8192, )[0] return trtllm_gen_output else: From 8c78c127977e851526a6c4e06d80d817a201ae57 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Thu, 21 Aug 2025 11:26:42 -0700 Subject: [PATCH 6/9] address comment Signed-off-by: Siyuan Fu --- vllm/model_executor/layers/quantization/mxfp4.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 036127a4b0de..90839ec9ccf4 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -113,8 +113,6 @@ def __init__(self, moe: FusedMoEConfig): self.topk_indices_dtype = None self.moe = moe self.use_marlin = self._should_use_marlin() - self.device_support_pdl = current_platform.is_cuda( - ) and current_platform.has_device_capability(90) if current_platform.is_device_capability(100) and not has_flashinfer(): logger.warning_once( From 7319792eb6a41435ed45f68dea189da5dcb2bb2d Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Thu, 21 Aug 2025 14:29:55 -0700 Subject: [PATCH 7/9] update dockerfile Signed-off-by: Siyuan Fu --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index cfaa59868215..2e037a7bde27 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -373,7 +373,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" # Keep this in sync with "flashinfer" extra in setup.py -ARG FLASHINFER_GIT_REF="v0.2.12" +ARG FLASHINFER_GIT_REF="v0.2.13" # Flag to control whether to compile FlashInfer AOT kernels # Set to "true" to enable AOT compilation: # docker build --build-arg FLASHINFER_AOT_COMPILE=true ... 
From 9988632fd25a2d461f338bede8ec58c41584fc83 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Fri, 22 Aug 2025 09:22:38 -0700 Subject: [PATCH 8/9] address todo Signed-off-by: Siyuan Fu --- vllm/model_executor/layers/quantization/mxfp4.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 90839ec9ccf4..f719af76cc69 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -6,6 +6,7 @@ from torch.nn.parameter import Parameter from vllm import envs +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) @@ -113,6 +114,8 @@ def __init__(self, moe: FusedMoEConfig): self.topk_indices_dtype = None self.moe = moe self.use_marlin = self._should_use_marlin() + self.max_captute_size = get_current_vllm_config( + ).compilation_config.max_capture_size if current_platform.is_device_capability(100) and not has_flashinfer(): logger.warning_once( @@ -551,7 +554,7 @@ def apply( 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize # TODO: use the maximum number in the cudagraph_batch_sizes - tune_max_num_tokens=8192, + tune_max_num_tokens=self.max_captute_size, )[0] return trtllm_gen_output else: From ba9b2ea2281a89b213a6684f9866a5fce7a70f89 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Fri, 22 Aug 2025 09:24:00 -0700 Subject: [PATCH 9/9] address todo Signed-off-by: Siyuan Fu --- vllm/model_executor/layers/quantization/mxfp4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index f719af76cc69..354c715c2bec 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -553,7 +553,6 @@ def apply( self._get_tile_tokens_dim(x, top_k), 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize - # TODO: use the maximum number in the cudagraph_batch_sizes tune_max_num_tokens=self.max_captute_size, )[0] return trtllm_gen_output
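
Taken together, the hunks above enforce one ordering constraint: FlashInfer's autotuner runs over eager dummy batches before CUDA graphs are captured, so the kernel choices it caches are the ones the captured graphs replay, and the trtllm MoE kernel is told the largest token count it will see (the compilation config's max_capture_size) through tune_max_num_tokens. The sketch below restates that ordering outside the worker; warm_up_then_capture is a hypothetical helper written for illustration, while _dummy_run, capture_model, and the autotune context manager are the same calls that appear in the diffs.

# Minimal sketch of the warmup ordering in compile_or_warm_up_model
# (hypothetical wrapper; the real worker reaches this path via kernel_warmup).
import torch
from vllm.utils.flashinfer import autotune

def warm_up_then_capture(model_runner, max_capture_size: int) -> None:
    # 1) Run a dummy batch at the largest capture size under the FlashInfer
    #    autotune context. Tuning at token count m covers all counts up to m,
    #    so a single pass at max_capture_size is sufficient.
    with torch.inference_mode(), autotune():
        model_runner._dummy_run(max_capture_size, skip_eplb=True)

    # 2) Capture CUDA graphs only after tuning, so the graphs bake in the
    #    kernel implementations the autotuner selected.
    model_runner.capture_model()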