Skip to content

Commit 36ab1af

Browse files
committed
feat: add all 7 missing DeepSeek features
- Absorbed MLA attention (no KV decompression at inference)
- Group-limited MoE routing (n_expert_groups, n_limited_groups)
- FP8 training framework (fp8_utils.py + --fp8 CLI flag)
- Knowledge distillation pipeline (distill.py + CLI)
- MTP speculative decoding (MTPSpeculativeGenerator)
- FlashMLA optional CUDA integration (flash_mla.py)
- DeepEP expert parallelism integration (expert_parallel.py)
1 parent 547b60c commit 36ab1af

File tree

7 files changed

+1081
-799
lines changed

7 files changed

+1081
-799
lines changed

supergpt/core/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class GPTConfig:
5454
aux_loss_free: bool = False # Use bias-based routing (DeepSeek V3)
5555
bias_update_speed: float = 0.001 # γ for bias adjustment
5656
moe_aux_loss_weight: float = 0.01 # Aux loss weight (when not aux_loss_free)
57+
n_expert_groups: int = 1 # Number of expert groups for group-limited routing
58+
n_limited_groups: int = 1 # Top groups to select from (DeepSeek V3)
5759

5860
# ── Multi-Token Prediction (MTP) — DeepSeek V3 ──────────────────────
5961
n_predict_tokens: int = 1 # Tokens to predict (1=standard, >1=MTP)

supergpt/core/flash_mla.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
FlashMLA Integration — Optional CUDA Kernel Acceleration
3+
==========================================================
4+
Provides a transparent interface to DeepSeek's FlashMLA CUDA kernels
5+
for MLA attention. Falls back to PyTorch implementation if not available.
6+
7+
FlashMLA achieves 660 TFLOPS on H800 SXM5 (vs ~50 TFLOPS with PyTorch).
8+
Requires: SM90+ GPU (H100/H800/B200), CUDA 12.8+, flash_mla package.
9+
10+
Install FlashMLA:
11+
git clone https://github.com/deepseek-ai/FlashMLA.git
12+
cd FlashMLA && pip install -v .
13+
14+
Usage in superGPT:
15+
Automatic — if flash_mla is installed, MLA attention auto-routes to CUDA.
16+
"""
17+
18+
import torch
19+
from typing import Optional, Tuple
20+
21+
# Try to import FlashMLA CUDA kernels
22+
FLASH_MLA_AVAILABLE = False
23+
try:
24+
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
25+
FLASH_MLA_AVAILABLE = True
26+
except ImportError:
27+
pass
28+
29+
30+
def get_mla_backend() -> str:
31+
"""Return the active MLA backend name."""
32+
if FLASH_MLA_AVAILABLE:
33+
return "flash_mla (CUDA)"
34+
return "pytorch (naive/absorbed)"
35+
36+
37+
def flash_mla_decode(
38+
q: torch.Tensor, # (B, n_heads, 1, qk_head_dim)
39+
kv_cache: torch.Tensor, # (B, max_seq_len, kv_lora_rank + qk_rope_dim)
40+
cache_seqlens: torch.Tensor, # (B,) — actual sequence lengths
41+
block_table: torch.Tensor, # Block table for paged attention
42+
softmax_scale: float,
43+
kv_lora_rank: int,
44+
qk_rope_head_dim: int,
45+
v_head_dim: int,
46+
) -> torch.Tensor:
47+
"""Run FlashMLA decode kernel if available.
48+
49+
This handles the case where we're generating one token at a time
50+
with a KV-cache. FlashMLA is optimized for this decode path.
51+
52+
Returns:
53+
output: (B, n_heads, 1, v_head_dim) attention output
54+
55+
Falls back to None if FlashMLA is not available (caller should
56+
use the PyTorch absorbed/naive path instead).
57+
"""
58+
if not FLASH_MLA_AVAILABLE:
59+
return None
60+
61+
try:
62+
# FlashMLA expects specific tensor layouts
63+
# Reshape q for FlashMLA: (B, n_heads, 1, head_dim) → (B, n_heads, head_dim)
64+
B, n_heads, _, head_dim = q.shape
65+
q_squeezed = q.squeeze(2) # (B, n_heads, head_dim)
66+
67+
# Get metadata for block-sparse attention
68+
tile_scheduler_metadata, num_splits = get_mla_metadata(
69+
cache_seqlens,
70+
block_table.shape[-1], # num_blocks_per_seq
71+
)
72+
73+
# Run FlashMLA kernel
74+
output, _ = flash_mla_with_kvcache(
75+
q_squeezed,
76+
kv_cache,
77+
block_table,
78+
cache_seqlens,
79+
kv_lora_rank,
80+
tile_scheduler_metadata,
81+
num_splits,
82+
softmax_scale=softmax_scale,
83+
)
84+
85+
return output.unsqueeze(2) # (B, n_heads, 1, v_head_dim)
86+
87+
except Exception as e:
88+
# Any error: fall back to PyTorch path
89+
print(f" FlashMLA error (falling back to PyTorch): {e}")
90+
return None
91+
92+
93+
def flash_mla_prefill(
94+
q: torch.Tensor, # (B, n_heads, T, qk_head_dim)
95+
kv_latent: torch.Tensor, # (B, T, kv_lora_rank)
96+
k_rope: torch.Tensor, # (B, 1, T, qk_rope_head_dim)
97+
softmax_scale: float,
98+
causal: bool = True,
99+
) -> Optional[torch.Tensor]:
100+
"""Run FlashMLA prefill kernel if available.
101+
102+
For the prefill (training) path with full sequence.
103+
Falls back to None if not available.
104+
"""
105+
if not FLASH_MLA_AVAILABLE:
106+
return None
107+
108+
# FlashMLA prefill support depends on version
109+
# For now, return None to use PyTorch naive path for prefill
110+
# (FlashMLA prefill is optimized for SM100+ / B200)
111+
return None
112+
113+
114+
def print_flash_mla_info():
115+
"""Print FlashMLA availability and info."""
116+
if FLASH_MLA_AVAILABLE:
117+
print(" FlashMLA: ✅ Available (CUDA kernels)")
118+
print(" Expected: 660 TFLOPS (H800), 1450 TFLOPS (B200)")
119+
else:
120+
print(" FlashMLA: ❌ Not installed (using PyTorch)")
121+
print(" Install: git clone https://github.com/deepseek-ai/FlashMLA.git")
122+
print(" cd FlashMLA && pip install -v .")

0 commit comments

Comments
 (0)