File tree Expand file tree Collapse file tree 2 files changed +11
-4
lines changed
model_executor/layers/quantization Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -113,6 +113,8 @@ def __init__(self, moe: FusedMoEConfig):
113
113
self .topk_indices_dtype = None
114
114
self .moe = moe
115
115
self .use_marlin = self ._should_use_marlin ()
116
+ self .device_support_pdl = current_platform .is_cuda (
117
+ ) and current_platform .has_device_capability (90 )
116
118
117
119
if current_platform .is_device_capability (100 ) and not has_flashinfer ():
118
120
logger .warning_once (
@@ -520,7 +522,8 @@ def apply(
520
522
x_scale = None
521
523
else :
522
524
x_quant , x_scale = mxfp8_quantize (x , False ) # to mxfp8
523
- x_scale = x_scale .view (torch .float8_e4m3fn ).reshape (- 1 )
525
+ x_scale = x_scale .view (torch .float8_e4m3fn ).reshape (
526
+ * x .shape [:- 1 ], - 1 )
524
527
trtllm_gen_output = trtllm_fp4_block_scale_moe (
525
528
router_logits .to (torch .bfloat16 ),
526
529
None , # routing_bias
@@ -549,6 +552,10 @@ def apply(
549
552
self ._get_tile_tokens_dim (x , top_k ),
550
553
1 if renormalize else 0 , # routing_method_type, renormalize
551
554
True , # do finalize
555
+ self .device_support_pdl ,
556
+ None , # output
557
+ # TODO: use the maximum number in the cudagraph_batch_sizes
558
+ 8192 , # tune_max_num_tokens.
552
559
)[0 ]
553
560
return trtllm_gen_output
554
561
else :
Original file line number Diff line number Diff line change @@ -312,6 +312,9 @@ def compile_or_warm_up_model(self) -> None:
312
312
logger .info ("Compile and warming up model for size %d" , size )
313
313
self .model_runner ._dummy_run (size , skip_eplb = True )
314
314
315
+ # run autotuner before cuda graph capture.
316
+ kernel_warmup (self )
317
+
315
318
if not self .model_config .enforce_eager :
316
319
self .model_runner .capture_model ()
317
320
@@ -336,9 +339,6 @@ def compile_or_warm_up_model(self) -> None:
336
339
self .model_runner ._dummy_sampler_run (
337
340
hidden_states = last_hidden_states )
338
341
339
- # Warmup kernels used during model execution
340
- kernel_warmup (self )
341
-
342
342
# Reset the seed to ensure that the random state is not affected by
343
343
# the model initialization and profiling.
344
344
set_random_seed (self .model_config .seed )
You can’t perform that action at this time.
0 commit comments