Commit 35a24b3

fix after flashinfer autotuner
Signed-off-by: Siyuan Fu <[email protected]>
1 parent: 3128240

File tree

vllm/model_executor/layers/quantization/mxfp4.py
vllm/v1/worker/gpu_worker.py

2 files changed (+11, -4 lines)

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 8 additions & 1 deletion
@@ -113,6 +113,8 @@ def __init__(self, moe: FusedMoEConfig):
         self.topk_indices_dtype = None
         self.moe = moe
         self.use_marlin = self._should_use_marlin()
+        self.device_support_pdl = current_platform.is_cuda(
+        ) and current_platform.has_device_capability(90)
 
         if current_platform.is_device_capability(100) and not has_flashinfer():
             logger.warning_once(
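
The new self.device_support_pdl flag gates programmatic dependent launch (PDL), which requires an SM90 (Hopper) or newer CUDA device. A minimal sketch of how such a capability check resolves, assuming current_platform is backed by torch.cuda (the helper name below is illustrative, not vLLM internals):

import torch

# Illustrative stand-in for current_platform.is_cuda() and
# current_platform.has_device_capability(90); PDL needs SM90+.
def device_supports_pdl() -> bool:
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100
    return major * 10 + minor >= 90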
@@ -520,7 +522,8 @@ def apply(
             x_scale = None
         else:
             x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
-            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
+                *x.shape[:-1], -1)
         trtllm_gen_output = trtllm_fp4_block_scale_moe(
             router_logits.to(torch.bfloat16),
             None,  # routing_bias
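
The reshape fix above stops flattening the mxfp8 block scales to 1-D and instead keeps one row of scales per token, i.e. shape (*x.shape[:-1], -1). A quick shape check with hypothetical sizes (32-element blocks, one e4m3 scale each; the real scale factors come from mxfp8_quantize):

import torch

# Hypothetical sizes: 4 tokens, hidden size 128, 32-element mxfp8 blocks.
num_tokens, hidden, block = 4, 128, 32
x = torch.randn(num_tokens, hidden)
raw_scales = torch.zeros(num_tokens * hidden // block, dtype=torch.uint8)

old = raw_scales.view(torch.float8_e4m3fn).reshape(-1)                 # (16,)
new = raw_scales.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)  # (4, 4)
print(old.shape, new.shape)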
@@ -549,6 +552,10 @@ def apply(
             self._get_tile_tokens_dim(x, top_k),
             1 if renormalize else 0,  # routing_method_type, renormalize
             True,  # do finalize
+            self.device_support_pdl,
+            None,  # output
+            # TODO: use the maximum number in the cudagraph_batch_sizes
+            8192,  # tune_max_num_tokens.
         )[0]
         return trtllm_gen_output
     else:

vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 3 deletions
@@ -312,6 +312,9 @@ def compile_or_warm_up_model(self) -> None:
             logger.info("Compile and warming up model for size %d", size)
             self.model_runner._dummy_run(size, skip_eplb=True)
 
+        # run autotuner before cuda graph capture.
+        kernel_warmup(self)
+
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
 
@@ -336,9 +339,6 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner._dummy_sampler_run(
                 hidden_states=last_hidden_states)
 
-        # Warmup kernels used during model execution
-        kernel_warmup(self)
-
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
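
Relocating kernel_warmup(self) is the substance of this file's change: a CUDA graph replays exactly the kernels that were launched while it was being captured, so the flashinfer autotuner must benchmark and select kernels eagerly, before capture_model() runs. A minimal sketch of the ordering constraint (illustrative helper, not vLLM's API):

import torch

def warmup_then_capture(run_step):
    # Eager pass first: autotuners benchmark candidates and cache the winner.
    run_step()
    torch.cuda.synchronize()
    # Capture second, so the tuned kernel is what gets baked into the
    # graph and replayed on every later launch.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        run_step()
    return graph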
