Skip to content

Commit 00e75ef

Browse files
committed
Fix after FlashInfer autotuner
Signed-off-by: Siyuan Fu <[email protected]>
1 parent 0e3bb54 commit 00e75ef

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def __init__(self, moe: FusedMoEConfig):
8686
self.topk_indices_dtype = None
8787
self.moe = moe
8888
self.use_marlin = self._should_use_marlin()
89+
self.device_support_pdl = current_platform.is_cuda(
90+
) and current_platform.has_device_capability(90)
8991

9092
def _should_use_marlin(self):
9193
if envs.VLLM_MXFP4_USE_MARLIN is not None:
@@ -488,7 +490,8 @@ def apply(
488490
x_scale = None
489491
else:
490492
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
491-
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
493+
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
494+
*x.shape[:-1], -1)
492495
trtllm_gen_output = trtllm_fp4_block_scale_moe(
493496
router_logits.to(torch.bfloat16),
494497
None, # routing_bias
@@ -517,6 +520,10 @@ def apply(
517520
self._get_tile_tokens_dim(x, top_k),
518521
1 if renormalize else 0, # routing_method_type, renormalize
519522
True, # do finalize
523+
self.device_support_pdl,
524+
None, # output
525+
# TODO: use the maximum number in the cudagraph_batch_sizes
526+
8192, # tune_max_num_tokens.
520527
)[0]
521528
return trtllm_gen_output
522529
else:

vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,9 @@ def compile_or_warm_up_model(self) -> None:
311311
logger.info("Compile and warming up model for size %d", size)
312312
self.model_runner._dummy_run(size, skip_eplb=True)
313313

314+
# run autotuner before cuda graph capture.
315+
kernel_warmup(self)
316+
314317
if not self.model_config.enforce_eager:
315318
self.model_runner.capture_model()
316319

@@ -335,9 +338,6 @@ def compile_or_warm_up_model(self) -> None:
335338
self.model_runner._dummy_sampler_run(
336339
hidden_states=last_hidden_states)
337340

338-
# Warmup kernels used during model execution
339-
kernel_warmup(self)
340-
341341
# Reset the seed to ensure that the random state is not affected by
342342
# the model initialization and profiling.
343343
set_random_seed(self.model_config.seed)

0 commit comments

Comments (0)