@@ -332,8 +332,12 @@ def quantize_input(self, x, post_quant_comm: bool = True):
332332 x , False , alignment = self .quant_method .input_hidden_alignment )
333333 x_row , x_col = x .shape [0 ], x .shape [1 ]
334334 elif self .has_deepseek_fp8_block_scales :
335- x , x_sf = torch .ops .trtllm .fp8_quantize_1x128 (x )
336- x_row = x .shape [0 ]
335+ # For SM100+, fp8_quantize_1x128 returns x_sf with shape (blocked_n, num_tokens),
336+ # but moe_a2a_dispatch requires all payloads to have first dim = num_tokens.
337+ # Transpose x_sf before dispatch and transpose back after receive, but this may
338+ # introduce a perf regression. So we don't support post_quant_comm for fp8_block_scales.
339+ # TODO: Consider removing the constraint of the OneSided AlltoAll
340+ pass
337341 elif self .has_w4a16_mxfp4 :
338342 pad_size = self .w3_w1_weight .shape [- 1 ] * 2 - x .shape [- 1 ]
339343 x = torch .nn .functional .pad (x , (0 , pad_size ))
@@ -412,6 +416,9 @@ def run_moe(
412416
413417 if self .has_deepseek_fp8_block_scales :
414418 assert do_finalize , "fp8_block_scale_moe_runner does not support do_finalize=False"
419+ # fp8_block_scale_moe_runner needs a 2D shape for x_sf and only supports SM100+
420+ if x_sf is None :
421+ x , x_sf = torch .ops .trtllm .fp8_quantize_1x128 (x )
415422
416423 final_hidden_states = torch .ops .trtllm .fp8_block_scale_moe_runner (
417424 router_logits ,
0 commit comments