- Overview
- Hardware Requirements
- NVFP4 Quantization
- Why SGLang Only (vLLM Does Not Work)
- BF16 KV Cache Mandatory
- NCCL Environment Variables
- Docker Images
- SGLang Launch Commands
- MTP / Speculative Decoding
- FlashInfer CUTLASS Race Condition Fix
- Power Consumption
- Benchmark Results
- Memory Usage
- TP/PP Configurations
- SM120 Architecture Limitations
- All Errors and Fixes
- Related PRs
| Parameter | Value |
|---|---|
| Model | zai-org/GLM-5 |
| Total parameters | 744B |
| Active parameters | 40B |
| Architecture | MoE with DeepSeek Sparse Attention (DSA), MLA-based |
| Experts | 256 total, 8 activated per token |
| MTP layer | Layer 78 (~19 GB in BF16 precision) |
| SWE-bench Verified | 77.8 (vs Qwen 72.0) |
| Inference engine | SGLang only (vLLM does not work on SM120) |
| Minimum GPUs | 8x RTX PRO 6000 (768 GB VRAM) |
GLM-5 is a 744B MoE model with DeepSeek Sparse Attention. On SM120 (RTX PRO 6000 Blackwell), SGLang bypasses all DSA backends and runs GLM-5 as if it were a DeepSeek V3.1 model -- using MLA kernels that ignore the sparsity mask. This is "backwards compatible" since the training-time indexer would have masked out irrelevant tokens, so computing full attention is slightly wasteful but not accuracy-degrading.
- NVFP4 weights: ~440 GB (57.06 GB per GPU across 8 GPUs)
- Cannot fit on 4x RTX PRO 6000 (only 384 GB total VRAM)
- Minimum viable: 6x RTX PRO 6000 using `--tp 2 --pp 3`
| Component | Details |
|---|---|
| GPUs | 8x NVIDIA RTX PRO 6000 Blackwell 96 GB (SM120) |
| Total VRAM | 768 GB |
| RAM | 1.5 TB recommended |
| CPU topology | 2x NUMA nodes: GPU0-3 on NUMA0, GPU4-7 on NUMA1 |
| Tested CPUs | Genoa (EPYC 9004) and Turin (EPYC 9005) |
| Driver | 590.48.01 (CUDA 13.1) |
| Checkpoint | MTP | Disk Size | Notes |
|---|---|---|---|
| `lukealonso/GLM-5-NVFP4` | No | ~410 GB | Original quant, no MTP weights |
| `festr2/GLM-5-NVFP4-MTP` | Yes (BF16) | ~410 GB + 19 GB | MTP layer 78 restored from the BF16 checkpoint |
| `QuantTrio/GLM-5-AWQ` | -- | ~420 GB | Fails with OOM during weight loading; NVFP4 is superior |
- 4-bit blockwise with FP8 scales via NVIDIA Model Optimizer
- SGLang flag: `--quantization modelopt_fp4`
- VRAM for weights: ~57 GB per GPU on TP8
- MMLU accuracy: 0.873 (official BF16 benchmark: 0.877, a gap of only 0.004)
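The ~440 GB on-disk footprint follows directly from the format. A rough back-of-the-envelope sketch, assuming NVFP4's standard 16-element quantization blocks with one FP8 (1-byte) scale per block (the layers kept in higher precision account for the remainder):

```python
# Back-of-the-envelope NVFP4 footprint for a 744B-parameter model.
# Assumption: 16-element quantization blocks, one FP8 (1-byte) scale each.
params = 744e9

weights_gb = params * 0.5 / 1e9    # 4 bits/param      -> 372.0 GB
scales_gb = params / 16 / 1e9      # 1 byte / 16 params ->  46.5 GB
total_gb = weights_gb + scales_gb  # ~418.5 GB; unquantized layers
                                   # (embeddings, norms, ...) push this to ~440 GB

per_gpu_gb = 440 / 8               # 55 GB/GPU on TP8, in the same ballpark
                                   # as the observed 57.06 GB
```

The small gap between 55 GB and the observed 57.06 GB per GPU is consistent with the non-quantized tensors not sharding evenly across TP ranks.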
Note: accuracy drops of ~10% were observed at 100K+ context lengths in MMLU testing.
GLM-5 does NOT work on vLLM for SM120 as of 2026-03-08.
The error:
ValueError: No valid attention backend found for cuda with
AttentionSelectorConfig(head_size=576, dtype=torch.bfloat16, kv_cache_dtype=auto,
use_mla=True, use_sparse=True, ...)
Root causes:
- No vLLM attention backend supports MLA + sparse attention + SM120 simultaneously
- GLM-5 uses `qk_nope_head_dim == 192` (FlashInfer MLA requires 128)
- NVFP4 support keeps breaking in vLLM
SGLang works by bypassing all DSA (DeepSeek Sparse Attention) backends entirely and running GLM-5 in non-DSA mode using FlashInfer FA2 MLA kernels that are SM120-compatible.
Grimulkan has a plan to port GLM-5 to vLLM:
- Pull FlashInfer FA2 bf16 and XQA fp8 MLA kernels from SGLang into vLLM
- Wire GLM-5 in non-DSA mode
- Fix NVFP4 MoE GEMM + DCP compatibility
- Use normal FA2 for prefill
- Enable MTP head (already exists in vLLM)
FP8 KV cache (`--kv-cache-dtype fp8_e4m3`) does NOT work on SM120: it produces garbled output, or emits one token and stops.

The root cause is that luke's build carried a local patch for KV scales in the FlashInfer backend (passing FP8 dequantization scales in the ragged+paged split path). Without those scales, the cached KV prefix is read back without undoing the scale division.

Always use `--kv-cache-dtype bf16`. This limits practical context to ~200K tokens (vs potentially more with FP8), but it is the only working option.
```
export NCCL_IB_DISABLE=1    # No InfiniBand
export NCCL_P2P_LEVEL=SYS   # or PHB for same-NUMA only
export NCCL_ALLOC_P2P_NET_LL_BUFFERS=1
export NCCL_MIN_NCHANNELS=8

wget https://www.voipmonitor.org/nccl_graph_opt.xml -O /mnt/nccl_graph_opt.xml
export NCCL_GRAPH_FILE=/mnt/nccl_graph_opt.xml
```

This tricks NCCL into using the low-latency (LL) protocol for small messages across NUMA nodes. Measured +11% throughput improvement on Genoa with 2 NUMA nodes and 4 GPUs per node.

Alternative (simpler but less optimal): `export NCCL_PROTO=LL`
```
export OMP_NUM_THREADS=8
export SAFETENSORS_FAST_GPU=1
export NVIDIA_TF32_OVERRIDE=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export FLASHINFER_DISABLE_VERSION_CHECK=1
export NCCL_CUMEM_HOST_ENABLE=0

# Critical for GLM-5:
export SGLANG_ENABLE_JIT_DEEPGEMM=0   # DeepGemm not supported on SM120
export SGLANG_ENABLE_DEEP_GEMM=0      # Fully disable DeepGemm fallback
export SGLANG_ENABLE_SPEC_V2=True     # MANDATORY for MTP (see MTP section)
```

If NCCL P2P hangs occur, add these kernel boot parameters:

```
iommu=pt
amd_iommu=pt   # on AMD platforms
```
```
docker pull voipmonitor/llm-pytorch-blackwell:nightly
```

Custom image by Festr, containing:
- SGLang compiled from source with SM120 patches
- PyTorch 2.12, latest FlashInfer, CUTLASS 4.4.1, cuDNN 91901
- SM_120f compilation target enabled
- Pre-generated Triton MoE kernel configs for RTX PRO 6000 Blackwell Server Edition
```
docker run -it --rm \
  --entrypoint /bin/bash \
  --gpus all \
  --ipc=host \
  --shm-size=8g \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  --network host \
  --cpuset-cpus "0-63" \
  -v /root/.cache/huggingface:/root/.cache/huggingface \
  -v /mnt:/mnt \
  -v vllm-nightly-jit:/cache/jit \
  voipmonitor/llm-pytorch-blackwell:nightly
```

| Image | Notes |
|---|---|
| `lmsysorg/sglang:dev-cu13` | Official SGLang nightly, CUDA 13.0. Needs `pip install --upgrade transformers` inside. |
| `lmsysorg/sglang:glm5-blackwell` | Official GLM-5-specific image. Built for SM90/SM100 -- broken on SM120; use voipmonitor or dev-cu13 instead. |
| `voipmonitor/llm-pytorch-blackwell:nightly-fp4-prezero` | Experimental build with FlashInfer pre-zero fix. |
Dockerfile used to patch the dev-cu13 image:

```
# sglang dev-cu13 nightly pulled 2026-03-04
FROM lmsysorg/sglang@sha256:426d1fa4b10722688678b99d817c2caa92a89eed4a8ee2927ab44a848bbe77df

RUN pip install --no-cache-dir transformers==5.2.0

# Fix DeepGemm scale format detection for NVFP4 models on Blackwell (SM120):
# NVFP4 uses float8_e4m3fn scales, not ue8m0 -- hardcoded True causes NaN
RUN sed -i "s/DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL/DEEPGEMM_SCALE_UE8M0 = False/" \
    /sgl-workspace/sglang/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py
```

Launch command (`festr2/GLM-5-NVFP4-MTP` checkpoint, with MTP):

```
SGLANG_ENABLE_SPEC_V2=True \
SGLANG_ENABLE_JIT_DEEPGEMM=0 \
SGLANG_ENABLE_DEEP_GEMM=0 \
NCCL_GRAPH_FILE=/mnt/nccl_graph_opt.xml \
NCCL_IB_DISABLE=1 \
NCCL_P2P_LEVEL=SYS \
NCCL_ALLOC_P2P_NET_LL_BUFFERS=1 \
NCCL_MIN_NCHANNELS=8 \
OMP_NUM_THREADS=8 \
SAFETENSORS_FAST_GPU=1 \
python3 -m sglang.launch_server \
  --model-path /mnt/GLM-5-NVFP4-MTP \
  --tp 8 \
  --trust-remote-code \
  --attention-backend flashinfer \
  --moe-runner-backend cutlass \
  --kv-cache-dtype bf16 \
  --tool-call-parser glm47 \
  --reasoning-parser glm45 \
  --quantization modelopt_fp4 \
  --disable-custom-all-reduce \
  --enable-flashinfer-allreduce-fusion \
  --mem-fraction-static 0.85 \
  --cuda-graph-max-bs 32 \
  --host 0.0.0.0 \
  --port 5000 \
  --served-model-name glm-5 \
  --max-running-requests 64 \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' \
  --speculative-algorithm NEXTN \
  --speculative-num-steps 3 \
  --speculative-num-draft-tokens 4 \
  --speculative-eagle-topk 1 \
  --enable-metrics
```

Launch command (`lukealonso/GLM-5-NVFP4` checkpoint, no MTP):

```
NCCL_GRAPH_FILE=/mnt/nccl_graph_opt.xml \
NCCL_IB_DISABLE=1 \
NCCL_P2P_LEVEL=SYS \
NCCL_ALLOC_P2P_NET_LL_BUFFERS=1 \
NCCL_MIN_NCHANNELS=8 \
OMP_NUM_THREADS=8 \
SAFETENSORS_FAST_GPU=1 \
python3 -m sglang.launch_server \
  --model-path lukealonso/GLM-5-NVFP4 \
  --tp 8 \
  --trust-remote-code \
  --attention-backend flashinfer \
  --moe-runner-backend flashinfer_cutlass \
  --kv-cache-dtype bf16 \
  --tool-call-parser glm47 \
  --reasoning-parser glm45 \
  --quantization modelopt_fp4 \
  --disable-custom-all-reduce \
  --enable-flashinfer-allreduce-fusion \
  --mem-fraction-static 0.9 \
  --cuda-graph-max-bs 8 \
  --host 0.0.0.0 \
  --port 5000 \
  --served-model-name glm-5 \
  --max-running-requests 8 \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}'
```

Docker Compose variant:

services:
sglang-glm5:
build: .
image: sglang-glm5:latest
container_name: sglang-glm5-nightly
runtime: nvidia
environment:
- NVIDIA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
- CUDA_DEVICE_ORDER=PCI_BUS_ID
- NCCL_IB_DISABLE=1
- NCCL_P2P_LEVEL=SYS
- NCCL_ALLOC_P2P_NET_LL_BUFFERS=1
- NCCL_MIN_NCHANNELS=8
- OMP_NUM_THREADS=8
- SAFETENSORS_FAST_GPU=1
- NCCL_CUMEM_HOST_ENABLE=0
- FLASHINFER_DISABLE_VERSION_CHECK=1
- PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
volumes:
- /mnt/raid0/models:/models:ro
- huggingface-cache:/root/.cache/huggingface
ports:
- "8003:5000"
command:
- python3
- -m
- sglang.launch_server
- --model-path=/models/festr2/GLM-5-NVFP4-MTP
- --served-model-name=glm-5
- --reasoning-parser=glm45
- --tool-call-parser=glm47
- --trust-remote-code
- --tp=8
- --mem-fraction-static=0.9
- --max-running-requests=64
- --kv-cache-dtype=bf16
- --quantization=modelopt_fp4
- --attention-backend=flashinfer
- --moe-runner-backend=deep_gemm
- --disable-custom-all-reduce
- --cuda-graph-max-bs=32
- --host=0.0.0.0
- --port=5000
- '--model-loader-extra-config={"enable_multithread_load": true, "num_threads": 8}'
- --speculative-algorithm=EAGLE
- --speculative-num-steps=3
- --speculative-eagle-topk=1
- --speculative-num-draft-tokens=4
cpuset: "0-63"
ipc: host
shm_size: "8g"
ulimits:
memlock: -1
stack: 67108864

| Parameter | Reason |
|---|---|
| `--quantization modelopt_fp4` | Required for the NVFP4 checkpoint |
| `--kv-cache-dtype bf16` | Mandatory on SM120 -- fp8_e4m3 produces garbled output |
| `--tp 8` | All 8 GPUs required; weights are 57 GB/GPU before KV cache |
| `--attention-backend flashinfer` | Architecture-independent; flashmla/trtllm are SM90/SM100 only |
| `--moe-runner-backend cutlass` | Fastest for MTP speculative decoding |
| `--disable-custom-all-reduce` | Custom allreduce is optimized for NVLink; RTX PRO is PCIe only |
| `--enable-flashinfer-allreduce-fusion` | Fuses allreduce with attention -- measurable throughput gain |
| `--mem-fraction-static 0.85-0.92` | Leaves 7-15 GB per GPU for CUDA workspace |
| `SGLANG_ENABLE_JIT_DEEPGEMM=0` | DeepGemm not supported on SM120 |
| `SGLANG_ENABLE_DEEP_GEMM=0` | Fully disables the DeepGemm fallback path |
| `SGLANG_ENABLE_SPEC_V2=True` | Critical for MTP -- without it, NEXTN falls back to EAGLE and loads the model twice (OOM) |
```
# Environment variable (MANDATORY):
SGLANG_ENABLE_SPEC_V2=True

# Launch flags:
--speculative-algorithm NEXTN
--speculative-num-steps 3
--speculative-num-draft-tokens 4
--speculative-eagle-topk 1
```

WARNING: `SGLANG_ENABLE_SPEC_V2=True` is mandatory. Without it, SGLang silently converts NEXTN to EAGLE and loads the full model a second time as a draft model -- instant OOM (57 GB x 2 = 114 GB per GPU on a 96 GB card).
- Use `festr2/GLM-5-NVFP4-MTP` (HuggingFace)
- Created by Festr by restoring the MTP heads from the BF16 checkpoint onto the NVFP4 quant
- MTP layer is layer 78, kept in BF16 precision (~19 GB)
- FP8 MTP is possible but not recommended (it decreases the accept rate)
- The original `lukealonso/GLM-5-NVFP4` does not include MTP weights
MTP roughly doubles throughput over the non-MTP baseline:
- Accept rate: 0.55-0.94 (varies by context)
- Accept length: 2.19-2.80 tokens
- Without MTP: 35-50 tok/s
- With MTP: 70-105 tok/s
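The observed accept lengths line up with the reported accept rates. Under the usual independence assumption (each draft token is accepted with probability p, acceptance stops at the first rejection), the expected tokens per verification step is 1 + p + p² + ... + pᵏ for k draft tokens. A quick sketch (the specific rates 0.60 and 0.78 are illustrative, not measurements from this doc):

```python
def expected_accept_length(p, num_draft=3):
    """Expected tokens per verification step: one guaranteed target-model
    token plus each draft token accepted while all earlier ones were
    (independence assumption -- real accept events are correlated)."""
    return sum(p ** i for i in range(num_draft + 1))

# Accept rates in the 0.6-0.78 range land inside the observed 2.19-2.80 band:
low = expected_accept_length(0.60)   # ~2.18
high = expected_accept_length(0.78)  # ~2.86
```

This also explains the near-2x throughput: each forward pass of the full model now yields ~2.2-2.8 tokens instead of 1, minus draft-model overhead.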
| Backend | Performance | Notes |
|---|---|---|
| `--moe-runner-backend cutlass` | Fastest | Best for MTP speculative decoding |
| `--moe-runner-backend flashinfer_cutlass` | Slightly slower | Default fallback |
| `--moe-runner-backend deep_gemm` | Falls back to cutlass | DeepGemm not supported on SM120; log output is misleading |
A race condition in the FlashInfer CUTLASS FP4 GEMM kernel produces NaN values, causing crashes.
/pytorch/aten/src/ATen/native/cuda/TensorCompare.cu:112: _assert_async_cuda_kernel:
Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.
Or: CUDA device-side assert triggered in eagle_worker_v2.py:510 _zero_fill_draft_kv_for_cached_prefix
Root cause: the FlashInfer CUTLASS FP4 GEMM kernel race condition. Fixed by flashinfer-ai/flashinfer#2716.
- Upgrade to CUTLASS 4.4.1 and rebuild the FlashInfer JIT cache (`rm -rf /cache/jit/*`). Use `voipmonitor/llm-pytorch-blackwell:nightly`, which includes this fix.
- Use `--fp4-gemm-backend flashinfer_cudnn` instead of `flashinfer_cutlass`
- Use `--enable-nan-detection` (prevents the crash but may produce garbage tokens)
- Apply luke's sampler patch (validates/fixes probabilities before multinomial sampling)
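The sampler patch itself isn't reproduced here, but the idea it describes (validate/fix probabilities before multinomial sampling) can be sketched in plain Python. The function name and the uniform fallback are illustrative assumptions, not luke's actual code:

```python
import math

def sanitize_probs(probs):
    """Zero out NaN/inf/negative entries and renormalize, so a downstream
    multinomial sampler never sees an invalid distribution (the condition
    that trips the CUDA assert). Falls back to uniform if nothing survives."""
    cleaned = [p if math.isfinite(p) and p > 0.0 else 0.0 for p in probs]
    total = sum(cleaned)
    if total == 0.0:
        return [1.0 / len(probs)] * len(probs)
    return [p / total for p in cleaned]

sanitize_probs([0.5, float("nan"), 0.5])  # -> [0.5, 0.0, 0.5]
```

Note this masks the symptom at the sampler; the NaNs still originate in the GEMM kernel, so the CUTLASS upgrade remains the real fix.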
Important: when upgrading Docker images, the old JIT kernel cache must be wiped for the fix to take effect:

```
rm -rf /cache/jit/*
```

GLM-5 draws significantly more power than other models:
| Phase | Power per Card | Notes |
|---|---|---|
| Decode | ~300W | Sustained |
| Prefill | 400-600W | Peaking at 640W observed |
| Prefill (all 8 cards) | 600W each | All cards hit 600W simultaneously |
Plan cooling and PSU capacity accordingly. An 8-GPU setup draws up to 4,800W from GPUs alone during prefill.
| Configuration (tok/s) | 0 Context | 15K Context | 100K Context | 200K Context |
|---|---|---|---|---|
| NVFP4, no MTP (luke, early) | ~50 | -- | -- | -- |
| NVFP4, no MTP (Festr/JTazz) | 35-44 | 30 | -- | -- |
| NVFP4 + MTP (EAGLE) | 70-105 | -- | 60-80 | -- |
| NVFP4 + MTP (latest, Festr) | ~100 | -- | 60-80 | ~50 |
| NVFP4 + MTP (orangezed) | 97.2 | -- | -- | -- |
- 3 running requests: 133-135 tok/s generation throughput (accept rate 0.55-0.70)
- Accept length: 2.19-2.80 tokens
- Single batch prefill: ~4,000 tok/s
| Phase | Duration |
|---|---|
| Model load (multithread, 8-16 threads) | ~36 seconds |
| CUDA graph capture | ~208 seconds |
| Total startup | ~7-8 minutes |
FP8 KV cache: 20 tok/s (vs 90 tok/s with bf16 KV cache) -- confirmed broken on SM120.
| Component | Size |
|---|---|
| Weights (NVFP4) | 57.06 GB per GPU |
| KV Cache (bf16) | 29.32 GB per GPU |
| Total allocated | ~86.38 GB per GPU |
| Available after allocation | 7.43-7.53 GB per GPU |
| mem-fraction-static | Total KV Tokens | Max Context |
|---|---|---|
| 0.92 | 314,304 | ~202,752 |
| 0.85 | Slightly less | ~190,000 |
BF16 KV cache limits practical context to ~200K tokens. FP8 KV cache would allow more but is broken on SM120.
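The ~200K figure can be sanity-checked from the MLA cache geometry. Assuming 78 main-model layers and the 576-dim compressed KV head seen in the vLLM error above (both assumptions, not read from SGLang's logs), and noting that the MLA KV cache is replicated per GPU under TP:

```python
# Rough per-token MLA KV cache cost in bf16 (2 bytes/value).
# Assumptions: 78 layers, 576-dim compressed KV per layer per token.
layers, head_size, dtype_bytes = 78, 576, 2

bytes_per_token = layers * head_size * dtype_bytes   # 89,856 B ~= 88 KB
kv_pool_bytes = 29.32e9                              # per-GPU KV allocation

tokens = kv_pool_bytes / bytes_per_token             # ~326K, in the same
# ballpark as the 314,304 KV tokens SGLang reports at mem-fraction 0.92
```

The small difference from 314,304 is consistent with pool bookkeeping overhead (page metadata, draft-model KV for MTP).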
| GPUs | Configuration | Status |
|---|---|---|
| 8x | `--tp 8` | Primary configuration, well tested |
| 6x | `--tp 2 --pp 3` | Reported viable, less tested |
| 4x | N/A | Model too large -- NVFP4 weights alone are ~440 GB, exceeding the 384 GB of total VRAM |
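The GPU-count cutoff is simple division. A sketch using the ~440 GB NVFP4 weight footprint (KV cache and CUDA workspace still have to fit on top, which is why 6 GPUs is tight):

```python
# Does the ~440 GB NVFP4 weight footprint fit per GPU?
weights_gb, vram_gb = 440, 96

results = {}
for gpus in (8, 6, 4):
    per_gpu = weights_gb / gpus
    results[gpus] = per_gpu < vram_gb
# 8 GPUs: 55.0 GB/GPU (fits), 6 GPUs: ~73.3 GB (fits, tight),
# 4 GPUs: 110.0 GB (exceeds the 96 GB card)
```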
- No TMEM (Tensor Memory)
- No TCGEN05 instructions
- No WGMMA instructions
- Smaller shared memory / register file
- Cannot run DeepGemm (requires WGMMA for SM90, TCGEN05 for SM100)
- Cannot run FlashAttention 3+ (based on TMEM/TCGEN05)
- Cannot run FlashMLA Sparse natively
- Limited to FlashAttention 2 via SM89 kernels
SGLang bypasses all DSA backends and runs GLM-5 as a DeepSeek V3.1 model:
- Uses MLA kernels ignoring sparsity (FlashInfer FA2 variant)
- DSA indexer is not invoked
- Computes attention on all tokens (including those DSA would have masked)
- This is backwards compatible -- slightly wasteful but not accuracy-degrading
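A toy softmax illustrates why computing full attention over tokens the DSA indexer would have masked is wasteful but mostly harmless: if the model learned to score irrelevant tokens far below relevant ones, they receive negligible weight even without the mask. (The scores below are made up for illustration.)

```python
import math

def softmax(xs):
    m = max(xs)  # subtract max for numerical stability
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

# 3 relevant tokens score high; 100 "irrelevant" ones score much lower.
scores = [8.0, 7.5, 7.0] + [0.0] * 100

full = softmax(scores)   # full attention (what SGLang computes on SM120)
leaked = sum(full[3:])   # total weight on tokens DSA would have masked
# leaked ~= 0.017: the unmasked tokens barely influence the output
```

The FLOPs spent on those 100 extra tokens are the "slightly wasteful" part; the ~1.7% leaked weight is why it is not accuracy-degrading in practice.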
- FlashInfer FA-based BF16 MLA kernel (SM120 specific)
- XQA FP8 MLA kernel (SM120 specific)
Neither is available in vLLM as of 2026-03-08.
AttributeError: 'ImportError' object has no attribute 'get_num_sms'
Fix: set `SGLANG_ENABLE_JIT_DEEPGEMM=0` and `SGLANG_ENABLE_DEEP_GEMM=0`.
Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.
Fix: See FlashInfer CUTLASS Race Condition Fix section above.
eagle_worker_v2.py:510 _zero_fill_draft_kv_for_cached_prefix
torch.AcceleratorError: CUDA error: device-side assert triggered
Fix: SGLang PR sgl-project/sglang#19897. Root cause is the FlashInfer CUTLASS race condition.
RuntimeError: Assertion error (attention.hpp:159): Unsupported architecture
Fix: Override to FlashInfer backend. Set nsa_prefill_backend = "flashinfer" and nsa_decode_backend = "flashinfer" in server_args.py, or use voipmonitor/llm-pytorch-blackwell:nightly which includes this fix.
ValueError: No valid attention backend found for cuda with ... use_mla=True, use_sparse=True
Fix: None. GLM-5 does not run on vLLM for SM120. Use SGLang.
Garbled output, or generation stopping after one token, with FP8 KV cache. Fix: use `--kv-cache-dtype bf16`. FP8 KV is broken on SM120.
ValueError: MTP speculative decoding layer 78 weights missing from checkpoint.
Fix: Use festr2/GLM-5-NVFP4-MTP which includes the MTP layer.
Config file not found at .../E=257,N=256,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition.json
Fix: Generate configs using https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton or use the voipmonitor/llm-pytorch-blackwell:nightly Docker which includes pre-generated configs.
| PR | Description |
|---|---|
| SGLang #19897 | Fix for radix cache + speculative decoding crash |
| SGLang #19948 | DeepGemm SCALE_UE8M0 fix for NVFP4 on SM120 |
| SGLang #19951 | Fix for broken latest SGLang |
| SGLang #19963 | Compilation fixes |
| SGLang #19428 | Performance improvement for GLM-5 |
| SGLang #20043 | Bug report: NaN crash with speculative decoding |
| FlashInfer #2708 | FlashInfer FP4 CUTLASS race condition |
| FlashInfer #2716 | FlashInfer fix for the race condition |

