Commit 3e34e19

yubofredwang, timmy-feng, sleepcoo, FrankLeeeee, and uygnef authored
Add flash attn backend, duplicates #314 (#400)
* added flash_attn backend
* fix pre commit
* replace requirements.txt with pyproject.toml for flash attn compilation
* fix pre commit
* fix deps
* lint
* bump flash-attn
* update ci image
* test fa3
* fix bug
* fix bug
* fix bug
* Update Docker image version in test workflow
* Update pip install command in test workflow: add --no-build-isolation flag to pip install command
* Update pyproject.toml
* Add setuptools installation to workflow
* Update test.yaml
* Update test.yaml
* Refactor test workflow to eliminate redundancy: removed duplicate test run commands and unnecessary ls statements
* Update pyproject.toml
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* fix position id
* clean up
* polish

---------

Co-authored-by: timmy-feng <timothy@modal.com>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: Shenggui Li <somerlee.9@gmail.com>
Co-authored-by: Yu Feng <admin@fengyu.org>
1 parent 23bf5e4 commit 3e34e19

File tree

9 files changed: +435 -35 lines changed

.github/workflows/test.yaml

Lines changed: 2 additions & 1 deletion

@@ -46,7 +46,8 @@ jobs:
             uv venv sf -p 3.11
           fi
           source sf/bin/activate
-          uv pip install -v . --prerelease=allow
+          uv pip install setuptools
+          MAX_JOBS=8 uv pip install -v ".[fa]" --prerelease=allow --no-build-isolation

       - name: Run test
         timeout-minutes: 30

pyproject.toml

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "specforge"
+dynamic = ["version", "description"]
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "pre-commit",
+    "torch==2.9.1",
+    "torchaudio==2.9.1",
+    "torchvision==0.24.1",
+    "transformers==4.57.1",
+    "qwen-vl-utils==0.0.11",
+    "datasets",
+    "setuptools",
+    "tqdm",
+    "wandb",
+    "psutil",
+    "numpy",
+    "accelerate",
+    "pydantic",
+    "sglang==0.5.6",
+    "openai-harmony",
+    "ninja",
+    "packaging",
+    "yunchang",
+]
+
+[tool.setuptools]
+packages = ["specforge"]
+
+[project.optional-dependencies]
+dev = [
+    "pre-commit",
+    "unittest"
+]
+fa = ["flash-attn"]
+
+[tool.setuptools.dynamic]
+version = {file = "version.txt"}
+description = {file = "README.md"}
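
The new fa extra only declares flash-attn; nothing is imported until the backend is actually selected (see the try/except added in specforge/modeling/draft/llama3_eagle.py below). As a small illustration that is not part of this commit, one way to check from Python whether the optional dependency made it into the environment, assuming the importable module name flash_attn:

import importlib.util

# Hypothetical availability check; the repository itself guards the import with try/except instead.
if importlib.util.find_spec("flash_attn") is None:
    print('flash-attn is missing; install the extra, e.g. MAX_JOBS=8 pip install -v ".[fa]" --no-build-isolation')
else:
    print("flash-attn is available, so the fa attention backend can be used")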

requirements.txt

Lines changed: 0 additions & 17 deletions
This file was deleted.

scripts/regenerate_train_data.py

Lines changed: 5 additions & 3 deletions

@@ -292,9 +292,11 @@ def main():
     error_samples = 0

     # Create progress bar
-    with open(args.input_file_path, "r") as input_file, open(
-        args.output_file_path, "w"
-    ) as output_file_handle, open(error_file_path, "w") as error_file_handle:
+    with (
+        open(args.input_file_path, "r") as input_file,
+        open(args.output_file_path, "w") as output_file_handle,
+        open(error_file_path, "w") as error_file_handle,
+    ):
         executor = ThreadPoolExecutor(
             max_workers=args.concurrency * len(valid_server_addresses)
         )
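
The rewritten with statement uses parenthesized context managers, a form that needs Python 3.10 or newer and is covered by the requires-python = ">=3.11" constraint introduced above. A minimal sketch of the same pattern, with hypothetical file names rather than the script's arguments:

from pathlib import Path

# Create a tiny sample input so the sketch runs end to end.
Path("input.jsonl").write_text('{"prompt": "hi"}\n')

# Parenthesized context managers: every file is closed automatically, even on error.
with (
    open("input.jsonl", "r") as input_file,
    open("output.jsonl", "w") as output_file,
    open("errors.jsonl", "w") as error_file,
):
    for line in input_file:
        output_file.write(line)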

setup.py

Lines changed: 11 additions & 7 deletions

@@ -1,10 +1,7 @@
-from setuptools import find_packages, setup
-
+import tomllib
+from pathlib import Path

-def read_requirements():
-    with open(f"requirements.txt", "r") as f:
-        lines = (line.strip() for line in f)
-        return [line for line in lines if line and not line.startswith(("#", "--"))]
+from setuptools import find_packages, setup


 def read_readme():
@@ -17,11 +14,18 @@ def read_version():
         return f.read().strip()


+def read_dependencies():
+    pyproject_path = Path(__file__).parent / "pyproject.toml"
+    with open(pyproject_path, "rb") as f:
+        pyproject = tomllib.load(f)
+    return pyproject.get("project", {}).get("dependencies", [])
+
+
 setup(
     name="specforge",
     packages=find_packages(exclude=["configs", "scripts", "tests"]),
     version=read_version(),
-    install_requires=read_requirements(),
+    install_requires=read_dependencies(),
     long_description=read_readme(),
     long_description_content_type="text/markdown",
     author="SGLang Team",

specforge/core/eagle3.py

Lines changed: 6 additions & 2 deletions

@@ -165,7 +165,7 @@ def forward(
         acces = []
         # for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift
         global_input_ids = input_ids
-        if self.attention_backend == "sdpa":
+        if self.attention_backend in ["sdpa", "fa"]:
             cache_hidden = [[], []]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
@@ -175,6 +175,8 @@ def forward(
             cache_hidden = [[], []]
             past_key_values = None
             hidden_states = self.prepare_usp_input(hidden_states)
+        else:
+            raise ValueError(f"Unknown attention backend: {self.attention_backend}")

         for idx in range(self.length):
             target_p = target_p_padded[:, idx : idx + seq_length, :]
@@ -464,12 +466,14 @@ def forward(
         plosses = []
         vlosses = []
         acces = []
-        if self.attention_backend == "sdpa":
+        if self.attention_backend in ["sdpa", "fa"]:
             cache_hidden = [[], []]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
             past_key_values = DynamicCache()
+        else:
+            raise ValueError(f"Unknown attention backend: {self.attention_backend}")

         for idx in range(self.length):
             target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous()

specforge/data/utils.py

Lines changed: 2 additions & 2 deletions

@@ -237,7 +237,7 @@ def prepare_dp_dataloaders(
     shuffle: Optional[bool] = False,
     is_vlm: Optional[bool] = False,
     prefetch_factor: Optional[int] = 2,
-    **dataloader_kwargs
+    **dataloader_kwargs,
 ) -> DataLoader:
     """
     Prepare dataloader for distributed data parallel training.
@@ -277,7 +277,7 @@ def prepare_dp_dataloaders(
         prefetch_factor=prefetch_factor,
         collate_fn=datacollator_cls(),
         drop_last=True,
-        **dataloader_kwargs
+        **dataloader_kwargs,
     )
     return dataloader

specforge/modeling/draft/llama3_eagle.py

Lines changed: 128 additions & 3 deletions

@@ -1,4 +1,5 @@
 import math
+import warnings
 from typing import List, Optional, Tuple

 import torch
@@ -23,6 +24,14 @@
 from ...distributed import get_sp_ring_group, get_sp_ulysses_group
 from .base import Eagle3DraftModel

+try:
+    from flash_attn import flash_attn_func
+except:
+    warnings.warn(
+        "flash_attn is not found, please install flash_attn if you want to use the flash attention backend"
+    )
+    flash_attn_func = None
+

 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
@@ -94,12 +103,12 @@ def rotate_half(x):


 @torch.compile(dynamic=True)
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
@@ -1170,6 +1179,120 @@ def forward(
         return attn_output


+class LlamaFlashAttention(LlamaAttention):
+    """
+    Attention layer implemented with flash attention. We keep the parameters consistent with LlamaAttention.
+    The used parameters are:
+    - hidden_states: input hidden states
+    - position_ids: position ids
+    - cache_hidden: manual cache used for storing past key and value states
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_hidden: Optional[List[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
+
+        lck = 0 if cache_hidden is None else len(cache_hidden[0])
+        if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
+            cos, sin = self.rotary_emb(query_states, position_ids + lck)
+            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+            query_states, key_states = apply_multimodal_rotary_pos_emb(
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
+            )
+        else:
+            cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
+            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+            query_states, key_states = apply_rotary_pos_emb(
+                query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2
+            )
+
+        if cache_hidden is not None:
+            cache_hidden[0] = cache_hidden[0] + [key_states]
+            cache_hidden[1] = cache_hidden[1] + [value_states]
+
+            cache_k = cache_hidden[0]
+            cache_v = cache_hidden[1]
+        else:
+            cache_k = [key_states]
+            cache_v = [value_states]
+
+        k0 = cache_k[0]
+        v0 = cache_v[0]
+
+        assert (
+            flash_attn_func is not None
+        ), "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend"
+        attn_output, lse, _ = flash_attn_func(
+            query_states,
+            k0,
+            v0,
+            dropout_p=0.0,
+            softmax_scale=1.0 / math.sqrt(self.head_dim),
+            causal=True,
+            return_attn_probs=True,
+        )
+        lse = lse.transpose(1, 2)
+
+        lck = len(cache_k)
+        if lck > 1:
+            q_shape_expanded = (
+                bsz,
+                q_len,
+                self.num_key_value_heads,
+                self.num_key_value_groups,
+                self.head_dim,
+            )
+            attn_outputs = [attn_output.view(q_shape_expanded)]
+            lses = [lse.view(q_shape_expanded[:-1])]
+
+            for i in range(1, lck):
+                ki = cache_k[i].unsqueeze(-2)
+                qi = query_states.view(q_shape_expanded)
+                vi = cache_v[i].unsqueeze(-2)
+
+                attn_outputs.append(vi)
+                lses.append((qi * ki).sum(-1) / math.sqrt(self.head_dim))
+
+            lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1)
+            attn_output = sum(
+                attn_outputi * torch.exp(lsei - lse).unsqueeze(-1)
+                for attn_outputi, lsei in zip(attn_outputs, lses)
+            )
+            # lse is fp32, downcast attn_output back
+            attn_output = attn_output.to(self.o_proj.weight.dtype)
+
+        attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads)

+        attn_output = self.o_proj(attn_output)
+
+        return attn_output
+
+
 class LlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -1245,6 +1368,8 @@ def __init__(self, config, attention_backend: str = "sdpa"):
         elif attention_backend == "flex_attention":
             print_with_rank("Using flex attention on draft model training!")
             self.self_attn = LlamaFlexAttention(config=config)
+        elif attention_backend == "fa":
+            self.self_attn = LlamaFlashAttention(config=config)
         else:
             raise ValueError(f"Unknown attention backend {attention_backend}")
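
The least obvious part of LlamaFlashAttention is the merge after the flash_attn_func call: the prefix output and every later cached key/value step each carry a log-sum-exp (lse) of their scaled scores, and the final output re-weights each partial result by exp(lse_i - logsumexp(all lse)). Below is a minimal, self-contained sketch of that identity in plain PyTorch, a generic two-block case with made-up shapes rather than the repository's code, showing that block-wise attention merged this way equals softmax attention over the concatenated keys:

import torch

torch.manual_seed(0)
d = 16
q = torch.randn(1, d)                          # a single query vector
k1, v1 = torch.randn(5, d), torch.randn(5, d)  # first key/value block
k2, v2 = torch.randn(3, d), torch.randn(3, d)  # second key/value block

def block_attn(q, k, v):
    # Attention restricted to one block, plus the block's log-sum-exp of scaled scores.
    scores = (q @ k.T) / d**0.5                # [1, len_k]
    lse = torch.logsumexp(scores, dim=-1)      # [1]
    out = torch.softmax(scores, dim=-1) @ v    # [1, d]
    return out, lse

out1, lse1 = block_attn(q, k1, v1)
out2, lse2 = block_attn(q, k2, v2)

# Merge the partial outputs with log-sum-exp weights, mirroring the exp(lse_i - lse) reweighting above.
lse = torch.logsumexp(torch.stack([lse1, lse2], dim=-1), dim=-1)
merged = out1 * torch.exp(lse1 - lse).unsqueeze(-1) + out2 * torch.exp(lse2 - lse).unsqueeze(-1)

# Reference: ordinary softmax attention over the concatenated keys and values.
ref, _ = block_attn(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
print(torch.allclose(merged, ref, atol=1e-5))  # True

In the diff itself each later cache entry holds a single key/value per query position, so its per-block softmax degenerates to the raw scaled score and the value itself; that is why the loop appends vi and (qi * ki).sum(-1) / math.sqrt(self.head_dim) directly instead of calling attention again.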

0 commit comments
