Description
Hi team,
We are currently running comprehensive benchmarks but have been unable to run DeepEP low-latency mode successfully on P5 (we are still troubleshooting and would appreciate any suggestions from the team). The issue appears to be a race condition that causes vllm serve to hang during initialization, typically before CUDA Graph capture. We consider the low-latency mode particularly important because it supports CUDA Graph, which our profiling shows delivers significant decode performance improvements (~10× latency reduction).
Additionally, we believe the numbers reported in the vLLM Serving Benchmark Results may be incorrect. It appears that allgather_reducescatter is not using the OFI-NCCL plugin correctly, causing it to fall back to socket transport instead of EFA. The reported numbers are consistent with what we observed in earlier experiments when OFI-NCCL was accidentally linked to the wrong path. This can be verified by enabling NCCL_DEBUG=INFO.
```
# wrong
NCCL INFO Using network socket
# correct
NCCL INFO Using network Libfabric
```
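As a quick check, the transport NCCL selected can be grepped out of a captured server log. A minimal sketch, assuming the server output (launched with NCCL_DEBUG=INFO) has been redirected to a log file; the log path and sample line below are illustrative, not from a real run:

```shell
# Write a sample log line for illustration; in practice, redirect the
# vllm serve output (launched with NCCL_DEBUG=INFO) into this file.
log=/tmp/nccl_debug.log
printf 'node0:1:1 [0] NCCL INFO Using network Libfabric\n' > "$log"

# EFA is in use only if the Libfabric transport was selected.
if grep -q 'NCCL INFO Using network socket' "$log"; then
  echo 'WARN: NCCL fell back to sockets; check the OFI-NCCL plugin path'
else
  grep 'NCCL INFO Using network' "$log"
fi
```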
The following sections describe our experiments and results.
Experiment
Scripts are here: https://github.com/crazyguitar/pysheeet/tree/master/src/llm/vllm
```shell
# launch a vLLM server on p5.48xlarge
salloc -N 4 bash run.sbatch "deepseek-ai/DeepSeek-V3-0324" \
  --image ${PWD}/images/vllm.tar.gz \
  --all2all-backend allgather_reducescatter \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --gpu-memory-utilization 0.8

# launch a client to benchmark
salloc -N 1 bash bench.sh -H <HEAD_NODE_IP> -- \
  --dataset-name random \
  --num-prompts 1000 \
  --random-input-len 1024 \
  --random-output-len 256 \
  --request-rate 10 \
  --max-concurrency 256 \
  --ignore-eos
```
Benchmark Results
| Metric | allgather_reducescatter | deepep_high_throughput |
|---|---|---|
| Request throughput (req/s) | 7.15 | 2.81 |
| Output token throughput (tok/s) | 1,829.55 | 719.02 |
| Total token throughput (tok/s) | 9,140.63 | 3,592.30 |
| Mean TTFT (ms) | 4,917.28 | 6,683.66 |
| Mean TPOT (ms) | 106.61 | 307.85 |
| P99 ITL (ms) | 1,525.67 | 4,111.62 |
Benchmark Analysis
- CUDA Graph incompatibility is the primary bottleneck
DeepEP's high-throughput MoE kernels cannot be captured by CUDA Graph. When CUDA Graph is enabled, most MoE layers in the DeepEP backend remain uncaptured and fall back to eager mode, forfeiting the ~10× decode latency reduction that CUDA Graph provides (e.g., 85ms → ~8ms per decode step).
The allgather_reducescatter backend fully supports CUDA Graph capture, which is why it dominates in end-to-end performance despite having comparable per-layer dispatch/combine latencies.
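To isolate CUDA Graph's contribution, the allgather_reducescatter configuration can be rerun with capture disabled via vLLM's `--enforce-eager` flag and the two runs compared. A sketch reusing the launch wrapper from the Experiment section (run.sbatch and the image path are specific to our setup):

```shell
# Same server launch as above, but with CUDA Graph disabled (eager mode)
# for an A/B comparison against the default captured run.
salloc -N 4 bash run.sbatch "deepseek-ai/DeepSeek-V3-0324" \
  --image ${PWD}/images/vllm.tar.gz \
  --all2all-backend allgather_reducescatter \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --gpu-memory-utilization 0.8 \
  --enforce-eager
```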
- Dispatch/combine latencies are NOT the bottleneck
Profiling Qwen2-57B-A14B in eager mode shows both backends have similar total forward times (~2ms vs ~1.9ms per MoE layer):
- deepep_high_throughput: dispatch ~242µs, combine ~80µs
- allgather_reducescatter: dispatch ~200µs, combine ~161µs
- Naive All2All backend confirms EFA small-write weakness
We also benchmarked vLLM's naive All2All backend (all backends in eager mode for fair comparison):
| Backend | Nodes | Req/s | Output tok/s | TTFT (ms) | ITL (ms) |
|---|---|---|---|---|---|
| deepep_high_throughput | 4 | 2.77 | 1,418 | 19,032 | 292 |
| deepep_high_throughput | 8 | 2.76 | 1,412 | 16,572 | 310 |
| allgather_reducescatter | 4 | 2.29 | 1,175 | 18,533 | 341 |
| allgather_reducescatter | 8 | 3.28 | 1,681 | 27,781 | 233 |
| naive (all2all) | 4 | 1.17 | 598 | 46,509 | 727 |
| naive (all2all) | 8 | 0.61 | 310 | 75,831 | 1,443 |
The naive backend is ~2–5× slower because it issues many small NCCL broadcast calls, triggering EFA's known poor performance with high volumes of small writes. The allgather_reducescatter backend avoids this by consolidating communication into single large NCCL operations.
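The difference in communication pattern can be sketched with back-of-the-envelope message counts. The rank, token, and hidden-size figures below are illustrative assumptions (not profiled values) chosen to match a 4-node run of DeepSeek-V3:

```shell
# Illustrative assumptions: 32 ranks (4 nodes x 8 GPUs), 256 tokens per
# rank, hidden size 7168 (DeepSeek-V3), fp16 payloads (2 bytes/element).
ranks=32; tokens=256; hidden=7168; bytes=2
per_msg=$((tokens * hidden * bytes))

# naive all2all: one small NCCL broadcast per rank, per MoE layer
echo "naive:     ${ranks} NCCL calls x ${per_msg} bytes each"
# allgather_reducescatter: a single large NCCL op per phase
echo "allgather: 1 NCCL call x $((ranks * per_msg)) bytes"
```

The total bytes moved are comparable; what differs is the number of NCCL calls issued, which is exactly the dimension on which EFA performs poorly.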
Conclusion
- An effective MoE All2All kernel should support CUDA Graph capture and minimize small write operations.
- EFA performs poorly under high volumes of small write operations.
- allgather_reducescatter remains the best-performing backend on EFA clusters due to its CUDA Graph compatibility and consolidated communication pattern.
Environment
- Cluster: p5 (2–4 nodes, 8× H100 per node)
- Network: EFA
- Model: DeepSeek-V3-0324, Qwen2-57B-A14B
- Framework: vLLM