Skip to content

Commit 8bb0425

Browse files
xxi-nvusberkeley
authored and committed
[None][fix] fix a bug: deepseek_fp8_block_scales in TRTLLMGEN-MoE use 2D x_sf instead of 1D (NVIDIA#9658)
Signed-off-by: xxi <xxi@nvidia.com>
1 parent 387c1de commit 8bb0425

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,12 @@ def quantize_input(self, x, post_quant_comm: bool = True):
332332
x, False, alignment=self.quant_method.input_hidden_alignment)
333333
x_row, x_col = x.shape[0], x.shape[1]
334334
elif self.has_deepseek_fp8_block_scales:
335-
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
336-
x_row = x.shape[0]
335+
# For SM100+, fp8_quantize_1x128 returns x_sf with shape (blocked_n, num_tokens),
336+
# but moe_a2a_dispatch requires all payloads to have first dim = num_tokens.
337+
# Transpose x_sf before dispatch and transpose back after receive, but this may
338+
# introduce a perf regression. So we don't support post_quant_comm for fp8_block_scales.
339+
# TODO: Consider removing the constraint of the OneSided AlltoAll
340+
pass
337341
elif self.has_w4a16_mxfp4:
338342
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
339343
x = torch.nn.functional.pad(x, (0, pad_size))
@@ -412,6 +416,9 @@ def run_moe(
412416

413417
if self.has_deepseek_fp8_block_scales:
414418
assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False"
419+
# fp8_block_scale_moe_runner needs a 2D shape for x_sf and only supports SM100+
420+
if x_sf is None:
421+
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
415422

416423
final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
417424
router_logits,

0 commit comments

Comments
 (0)