
Support NVFP4 masked layout MoE #7994

@fzyzcjy

Description


Currently, for large-scale EP, the MoE GEMM is fp8 via DeepGEMM Blackwell. It would be great to use nvfp4 for the MoE GEMM to speed up the system.

More specifically:

  • We will need an up & gate GEMM, an activation & quant operation, and a down GEMM (see the sketch after this list).
  • The input of the subsystem is the output of DeepEP's Buffer.low_latency_dispatch, i.e. we will have the following
    • recv_x (hidden states): [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden] in NVFP4 format.
    • recv_count: a tensor of shape [num_local_experts] with type torch.int, indicating how many tokens each expert receives. In other words, not all tokens in recv_x are valid.
  • The output of the system will be the input of DeepEP Buffer.low_latency_combine, i.e.
    • x (hidden states): [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden] in bf16 format
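
For concreteness, here is a minimal pure-PyTorch bf16 reference of the subsystem (the fused gate & up weight layout and the SiLU-and-mul activation are assumptions for illustration, not the actual sglang weight layout). In the real path, the two GEMMs become NVFP4 masked GEMMs and the activation is fused with quant-to-fp4; the per-expert loop is exactly what the masked-layout kernels avoid.

```python
import torch

def masked_moe_reference(recv_x, recv_count, w13, w2):
    """Pure-bf16 reference for the masked-layout MoE subsystem.

    recv_x:     [num_local_experts, max_tokens, hidden]   (bf16 here; NVFP4 in the real path)
    recv_count: [num_local_experts] int32; only the first recv_count[e] rows of expert e are valid
    w13:        [num_local_experts, 2 * intermediate, hidden]   (assumed fused gate & up weights)
    w2:         [num_local_experts, hidden, intermediate]
    returns     [num_local_experts, max_tokens, hidden] bf16, the input of Buffer.low_latency_combine
    """
    num_local_experts, max_tokens, hidden = recv_x.shape
    out = torch.zeros_like(recv_x)
    for e in range(num_local_experts):
        n = int(recv_count[e])                    # rows beyond n are padding and must be skipped
        x = recv_x[e, :n]                         # [n, hidden]
        gate_up = x @ w13[e].t()                  # stage 1: up & gate GEMM -> [n, 2 * intermediate]
        gate, up = gate_up.chunk(2, dim=-1)
        h = torch.nn.functional.silu(gate) * up   # stage 2: activation (fused with quant-to-fp4 in the real path)
        out[e, :n] = h @ w2[e].t()                # stage 3: down GEMM -> [n, hidden]
    return out
```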

Breakdown into actionable items (mainly for the NV team):

  • DeepEP fp4 dispatch
  • A masked-layout fp4 gemm with sota speed
    • Implementation: feat: masked layout fp4 gemm using cute-dsl flashinfer-ai/flashinfer#1331 (@yzh119 @yyihuang; update: already done)
    • Integration: [NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked #9199 (potential code change: copy the GEMM's unit test to serve as a weight loader, and do some wiring) @wenscarl
    • where to put it: since there is ongoing refactors, maybe we just put it firstly to somewhere like the existing trtllm-gen fp4 code path.
      • (the codebase has changed a lot, so this may no longer be a useful reference) Support DeepEP communication for nvfp4 moe (+12% e2e) #8376 may serve as a reference for how to use DeepEP together with existing fp4 kernels.
      • New suggestion: maybe follow the standard "DeepEP low latency + fp8 DeepGEMM" code path (--enable-deepep-moe), but replace the execution and weight-loading code of "fp8 DeepGEMM" with the new "fp4 cutedsl gemm".
    • To test the correctness of this part without "DeepEP fp4 dispatch", set use_fp8=False in _DeepEPDispatcherImplLowLatency.dispatch_a so that DeepEP provides bf16 data (instead of the unwanted fp8 data); we can then manually apply a quant-to-fp4 kernel.
    • To test the correctness of this part without the new "act-and-quant kernel", you may use any standard activation kernels and fp4 quant kernels that are unfused and unaware of the masked layout. This is very slow but works for correctness checking (see the sketch after this list).
    • Example command: python3 -m sglang.launch_server --model-path DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 4 --dp 4 --enable-dp-attention --enable-flashinfer-moe --enable-deepep-moe --deepep-mode low_latency (remark: this will not work with today's code because it enables low-latency DeepEP together with flashinfer MoE, which are incompatible; a new flag such as --enable-cutedsl-moe may be needed instead)
  • A masked-layout fp4 activation & quantization kernel @kaixih
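
A rough sketch of how the correctness checks from the two testing bullets above could be wired up, assuming a hypothetical run_fp4_masked_moe wrapper around the new kernels (the real entry point depends on the integration in #9199) and reusing masked_moe_reference from the earlier sketch:

```python
import torch

def test_masked_fp4_moe(run_fp4_masked_moe, recv_x_bf16, recv_count, w13, w2):
    # `run_fp4_masked_moe` is a hypothetical wrapper around the new kernels
    # (fp4 masked GEMMs + fused act&quant); `masked_moe_reference` is the bf16
    # reference from the earlier sketch.
    #
    # With use_fp8=False in _DeepEPDispatcherImplLowLatency.dispatch_a, DeepEP hands
    # us bf16 recv_x; the fp4 path is expected to quantize it before the GEMMs.
    out_test = run_fp4_masked_moe(recv_x_bf16, recv_count, w13, w2)
    out_ref = masked_moe_reference(recv_x_bf16, recv_count, w13, w2)

    # Only the first recv_count[e] rows of each expert are meaningful; padded rows
    # may contain garbage in either implementation, so compare valid rows only.
    for e in range(recv_count.numel()):
        n = int(recv_count[e])
        torch.testing.assert_close(
            out_test[e, :n].float(), out_ref[e, :n].float(),
            atol=2e-1, rtol=2e-1,   # loose tolerances: fp4 quantization error dominates
        )
```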

The points do not have a strict dependency relationship. For example, if the fp4 GEMM has not arrived yet, we can still work on all the other points by temporarily using the trtllm-gen GEMM (or a cutlass GEMM) with a temporary, inefficient layout-permuting glue (see the sketch below).
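
As an illustration, such a temporary glue could look like the host-side sketch below: gather only the valid rows of each expert into a contiguous tensor for a standard (non-masked) grouped GEMM, then scatter the results back into the masked layout. The per-expert Python loop and the extra gather/scatter are exactly the inefficiency mentioned above, so this is for bring-up only.

```python
import torch

def masked_to_packed(x_masked, recv_count):
    """Gather only the valid rows of each expert into one contiguous tensor,
    so a standard (non-masked) grouped GEMM can be used."""
    packed = torch.cat(
        [x_masked[e, : int(recv_count[e])] for e in range(recv_count.numel())], dim=0
    )
    return packed, recv_count.clone()        # group sizes for the grouped GEMM

def packed_to_masked(y_packed, recv_count, max_tokens):
    """Scatter grouped-GEMM output back into the masked layout that
    Buffer.low_latency_combine expects (padding rows stay zero)."""
    num_local_experts = recv_count.numel()
    out = y_packed.new_zeros(num_local_experts, max_tokens, *y_packed.shape[1:])
    offset = 0
    for e in range(num_local_experts):
        n = int(recv_count[e])
        out[e, :n] = y_packed[offset : offset + n]
        offset += n
    return out
```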
