Currently, for large-scale EP, the MoE GEMM is fp8 via DeepGEMM Blackwell. It would be great to use nvfp4 for the MoE GEMM to speed up the system.

More specifically:

- We will need an up&gate GEMM, an activation & quant operation, and a down GEMM.
- The input of the subsystem is the output of DeepEP's `Buffer.low_latency_dispatch`, i.e. we will have the following:
  - `recv_x` (hidden states): `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` in NVFP4 format.
  - `recv_count`: a tensor of shape `[num_local_experts]` with dtype `torch.int`, indicating how many tokens each expert receives. As mentioned before, not all tokens in `recv_x` are valid.
- The output of the subsystem will be the input of DeepEP's `Buffer.low_latency_combine`, i.e. `x` (hidden states): `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` in bf16 format. (A reference sketch of this masked-layout pipeline is given right below.)
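To make the masked-layout semantics concrete, here is a minimal unfused bf16 reference of the three-step pipeline (up&gate GEMM → SiLU-and-mul → down GEMM) over DeepEP's dispatch output. This is only a sketch: the helper and the weight names `w13_weight` / `w2_weight` are hypothetical, fp4 quantization between the GEMMs is omitted, and the gate/up ordering is an assumption; the only point is that just the first `recv_count[e]` rows of each expert are valid.

```python
import torch

def masked_moe_reference(
    recv_x: torch.Tensor,       # [E, M, hidden], bf16 here (real path: NVFP4)
    recv_count: torch.Tensor,   # [E], torch.int, valid tokens per local expert
    w13_weight: torch.Tensor,   # [E, 2 * intermediate, hidden]  (hypothetical name)
    w2_weight: torch.Tensor,    # [E, hidden, intermediate]      (hypothetical name)
) -> torch.Tensor:
    E, M, hidden = recv_x.shape
    out = torch.zeros_like(recv_x)
    for e in range(E):
        n = int(recv_count[e])                          # only the first n rows are valid
        if n == 0:
            continue
        x = recv_x[e, :n]                               # [n, hidden]
        gate_up = x @ w13_weight[e].t()                 # up&gate GEMM -> [n, 2 * intermediate]
        gate, up = gate_up.chunk(2, dim=-1)             # ordering is an assumption
        act = torch.nn.functional.silu(gate) * up       # activation (fp4 quant would go here)
        out[e, :n] = act @ w2_weight[e].t()             # down GEMM -> [n, hidden]
    return out                                          # bf16, fed to low_latency_combine
```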
Breaking this down into actionable items (mainly for the NV team):
- DeepEP fp4 dispatch
  - Implementation: Support nvfp4 low latency mode dispatch deepseek-ai/DeepEP#341 (@shifangx). The PR may not be merged soon, but the feature is reported done; an initial version exists, so we can verify it and see whether there are any issues.
  - Integration (potential code change: just enable one flag)
- A masked-layout fp4 GEMM with SOTA speed
  - Implementation: feat: masked layout fp4 gemm using cute-dsl flashinfer-ai/flashinfer#1331 (@yzh119 @yyihuang; update: already done)
  - Integration: [NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked #9199 (potential code change: copy the GEMM's UT as a weight loader, and do some wiring) @wenscarl
    - Where to put it: since there are ongoing refactors, maybe we first put it somewhere like the existing trtllm-gen fp4 code path.
    - This may also be a reference for how to use DeepEP together with existing fp4 kernels: Support DeepEP communication for nvfp4 moe (+12% e2e) #8376 (the codebase has changed a lot, so this may not be a useful ref).
    - New suggestion: maybe follow the standard "DeepEP low latency + fp8 DeepGEMM" code path (`--enable-deepep-moe`), but change the execution and weight-loading code of "fp8 DeepGEMM" into the new "fp4 cutedsl gemm".
    - To test the correctness of this part without "DeepEP fp4 dispatch", you may want to set `use_fp8=False` in `_DeepEPDispatcherImplLowLatency.dispatch_a`, which makes DeepEP provide bf16 data (instead of the unwanted fp8 data); then we can manually apply a quant-to-fp4 kernel.
    - To test the correctness of this part without the new "act-and-quant kernel", you may use any standard activation kernels and fp4 quant kernels that are not fused and do not understand the masked layout. This is very slow but works for correctness.
    - Example command: `python3 -m sglang.launch_server --model-path DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 4 --dp 4 --enable-dp-attention --enable-flashinfer-moe --enable-deepep-moe --deepep-mode low_latency` (remark: this will not work with today's code because it enables low-latency DeepEP together with flashinfer moe, which are incompatible; maybe you will have e.g. `--enable-cutedsl-moe` instead).
- A masked-layout fp4 activation & quantization kernel @kaixih
  - Implementation (potential code change: look at the existing masked-layout fp8 one (https://github.com/sgl-project/sglang/pull/7601/files#diff-283bfa4a864c00e9e4ff3d29d32d8ec7bf2ebed1afe61047c4c77ceca6247d7e) and the contiguous-layout fp4 one (e.g. the one in flashinfer trtllm-gen moe), and combine the two). An unfused reference of what this kernel computes is sketched right after this list.
  - Integration (potential code change: simply call that stateless function)
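As promised above, here is an unfused reference of what the masked-layout activation & quantization kernel computes for one expert's valid rows. It is a sketch under assumptions: it does fake quantization (returns a dequantized fp32 tensor rather than packed fp4 plus e4m3 block scales), and the block size of 16 with amax/6 scaling follows the usual nvfp4 convention, which may differ in detail from the real kernel.

```python
import torch

# Representable magnitudes of fp4 e2m1.
_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def silu_and_mul_fake_fp4_quant(gate_up: torch.Tensor, block: int = 16) -> torch.Tensor:
    """Unfused reference: SiLU-and-mul followed by per-16-element fake fp4 quantization.

    gate_up: [n_valid_tokens, 2 * intermediate]. Returns a dequantized fp32 tensor of
    shape [n_valid_tokens, intermediate]; a real kernel would emit packed fp4 + scales.
    """
    gate, up = gate_up.float().chunk(2, dim=-1)
    act = torch.nn.functional.silu(gate) * up                  # [n, intermediate]

    n, inter = act.shape
    assert inter % block == 0
    blocks = act.view(n, inter // block, block)
    scale = blocks.abs().amax(dim=-1, keepdim=True) / 6.0      # 6.0 = max |e2m1| value
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    scaled = blocks / scale
    # Round each element to the nearest representable e2m1 magnitude.
    table = _E2M1_VALUES.to(scaled.device)
    idx = (scaled.abs().unsqueeze(-1) - table).abs().argmin(dim=-1)
    quant = table[idx] * scaled.sign()

    return (quant * scale).view(n, inter)                      # dequantized reference
```

For the masked layout, this would be applied per expert on only the valid rows, i.e. on `gate_up[e, :recv_count[e]]` of the up&gate GEMM output.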
The items above do not have a strict dependency relationship. For example, if the fp4 GEMM does not arrive, we can still do all the other items by temporarily using the trtllm-gen GEMM (or a cutlass GEMM) with a temporary, inefficient layout-permuting glue (sketched below).
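One possible (inefficient but simple) shape for that temporary glue is to pack only the valid rows of each expert into a contiguous buffer, run any contiguous-layout GEMM on it, and scatter the result back into the masked layout. A minimal sketch, with all helper names hypothetical:

```python
import torch

def pack_masked_to_contiguous(recv_x: torch.Tensor, recv_count: torch.Tensor):
    """Gather only the valid rows of each expert into a contiguous buffer of shape
    [sum(recv_count), hidden], plus per-expert offsets for scattering results back."""
    E, M, hidden = recv_x.shape
    counts = recv_count.tolist()
    packed = torch.cat([recv_x[e, :counts[e]] for e in range(E)], dim=0)
    offsets = torch.tensor([0] + counts, device=recv_x.device).cumsum(0)  # [E + 1]
    return packed, offsets

def scatter_contiguous_to_masked(y_packed: torch.Tensor, offsets: torch.Tensor,
                                 recv_count: torch.Tensor, M: int) -> torch.Tensor:
    """Inverse of pack_masked_to_contiguous, applied to the GEMM output."""
    E = recv_count.numel()
    hidden = y_packed.shape[-1]
    out = torch.zeros(E, M, hidden, dtype=y_packed.dtype, device=y_packed.device)
    for e in range(E):
        n = int(recv_count[e])
        out[e, :n] = y_packed[offsets[e]:offsets[e] + n]
    return out
```

The extra gather/scatter (and any layout conversion of the fp4 scales) makes this slow, but it keeps the rest of the integration unblocked until the masked-layout fp4 GEMM lands.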