Commit 2e9bdff

clean up

1 parent 638cc7a commit 2e9bdff

5 files changed: +75 -25 lines changed

scripts/train_eagle3.py

Lines changed: 1 addition & 3 deletions
@@ -434,9 +434,7 @@ def build_dataloaders(
         num_workers=args.dataloader_num_workers,
         shuffle=True,
         process_group=(
-            get_draft_dp_group()
-            if args.attention_backend == "usp"
-            else get_dp_group()
+            get_draft_dp_group() if args.attention_backend == "usp" else get_dp_group()
         ),
         is_vlm=args.is_vlm,
     )
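
Note: this hunk only reflows the conditional; the behavior — shard the dataloader over get_draft_dp_group() when the attention backend is "usp", otherwise over get_dp_group() — is unchanged. For readers unfamiliar with the pattern, here is a minimal sketch of sharding a sampler over a data-parallel subgroup, presumably so that ranks splitting the same sequence receive identical batches (an illustration only: build_sampler and dp_group are made-up names, not this script's API):

# Minimal sketch (assumption, not the repository's implementation): shard a sampler
# over a data-parallel subgroup rather than the global world.
import torch.distributed as dist
from torch.utils.data import DistributedSampler

def build_sampler(dataset, dp_group):
    # Rank and replica count come from the DP group, not the global world size,
    # so sequence-parallel peers inside one DP rank get identical shards.
    return DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(group=dp_group),
        rank=dist.get_rank(group=dp_group),
        shuffle=True,
    )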

specforge/layers/ring/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -4,4 +4,3 @@
     ring_flash_attn_kvpacked_func,
     ring_flash_attn_qkvpacked_func,
 )
-

specforge/layers/ring/ring_flash_attn.py

Lines changed: 22 additions & 3 deletions
@@ -1,6 +1,8 @@
 import torch
+from yunchang.kernels import AttnType, select_flash_attn_impl
+
 from .utils import RingComm, update_out_and_lse
-from yunchang.kernels import select_flash_attn_impl, AttnType
+
 
 def ring_flash_attn_forward(
     process_group,
@@ -31,7 +33,9 @@ def ring_flash_attn_forward(
         comm.commit()
 
         if not causal or step <= comm.rank:
-            fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor)
+            fn = select_flash_attn_impl(
+                attn_type, stage="fwd-only", attn_processor=attn_processor
+            )
             block_out, block_lse = fn(
                 q,
                 k,
@@ -219,7 +223,22 @@ def backward(ctx, dout, *args):
             deterministic=ctx.deterministic,
             attn_type=ctx.attn_type,
         )
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
+        return (
+            dq,
+            dk,
+            dv,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 def ring_flash_attn_qkvpacked_func(
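
Note: the reflowed return statement makes the autograd contract easier to check. torch.autograd.Function.backward returns one value per forward() input, in order, so the three gradients dq, dk, dv are followed by eleven Nones for the non-differentiable inputs (process group, dropout probability, flags, and so on). A toy illustration of that contract (not this repository's class):

# Toy example: backward returns one entry per forward input; non-tensor
# or non-differentiable inputs get None.
import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, some_flag):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # forward took (x, factor, some_flag): only x is differentiable,
        # so the other two positions return None.
        return grad_out * ctx.factor, None, None

y = ScaleBy.apply(torch.randn(3, requires_grad=True), 2.0, True)
y.sum().backward()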

specforge/layers/ring/utils.py

Lines changed: 3 additions & 2 deletions
@@ -6,14 +6,15 @@
 
 __all__ = ["update_out_and_lse", "RingComm"]
 
+
 @torch.jit.script
 def _update_out_and_lse(
     out: torch.Tensor,
     lse: torch.Tensor,
     block_out: torch.Tensor,
     block_lse: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
+
     block_out = block_out.to(torch.float32)
     block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
 
@@ -115,4 +116,4 @@ def wait(self):
         for req in self._reqs:
             req.wait()
         self._reqs = None
-        self._ops = []
+        self._ops = []
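
Note: _update_out_and_lse, the @torch.jit.script helper touched here, folds a new attention block (block_out, block_lse) into the running (out, lse). A standalone sketch of the underlying log-sum-exp merge rule, equivalent in spirit though not this module's exact implementation (merge_partial_attn is an illustrative name; shapes are assumed to be [B, S, H, D] for outputs and [B, S, H] for LSE):

# Minimal sketch (illustrative): merging two partial softmax-attention results
# via their log-sum-exp (LSE) statistics, as ring attention does per step.
import torch

def merge_partial_attn(out_a, lse_a, out_b, lse_b):
    # Combined normalizer: log(exp(lse_a) + exp(lse_b)), computed stably.
    new_lse = torch.logaddexp(lse_a, lse_b)
    # Re-weight each partial output by its share of the combined normalizer.
    merged = (
        out_a * torch.exp(lse_a - new_lse).unsqueeze(-1)
        + out_b * torch.exp(lse_b - new_lse).unsqueeze(-1)
    )
    return merged, new_lse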

specforge/modeling/draft/llama3_eagle.py

Lines changed: 49 additions & 16 deletions
@@ -11,7 +11,6 @@
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from yunchang import EXTRACT_FUNC_DICT
 from yunchang.comm import SeqAllToAll4D
 
 from specforge.modeling.draft.flex_attention import (
@@ -20,10 +19,10 @@
     generate_eagle3_mask,
 )
 from specforge.utils import print_with_rank
-from .base import Eagle3DraftModel
+
 from ...distributed import get_sp_ring_group, get_sp_ulysses_group
 from ...layers.ring import ring_flash_attn_func
-
+from .base import Eagle3DraftModel
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -961,6 +960,7 @@ class LlamaUSPFlashAttention(LlamaAttention):
     """
    LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
     """
+
     def __init__(self, config):
         super().__init__(config)
         assert (
@@ -996,19 +996,35 @@ def forward(
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         query_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, query_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            query_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )
 
         key_states = self.k_proj(hidden_states)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         key_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, key_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            key_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )
 
         value_states = self.v_proj(hidden_states)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         value_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, value_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            value_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )
 
         current_q_len = query_states.shape[1]
@@ -1022,17 +1038,26 @@ def forward(
         # =============================================================
         if self.sp_ring_degree > 1:
             if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
+                    self.ring_rank
+                ].clone()
             else:
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
+                    self.ring_rank
+                ].clone()
 
         lck = 0 if cache_hidden is None else len(cache_hidden[0])
 
         if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
             cos, sin = self.rotary_emb(query_states, position_ids + lck)
             cos, sin = cos.to(query_states.device), sin.to(query_states.device)
             query_states, key_states = apply_multimodal_rotary_pos_emb(
-                query_states, key_states, cos, sin, self.config.rope_scaling["mrope_section"], unsqueeze_dim=2
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
             )
         else:
             cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck)
@@ -1075,8 +1100,9 @@ def forward(
         else:
             acc_lse = lse_ring
 
-        assert acc_lse.shape[1] == current_q_len, \
-            f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
+        assert (
+            acc_lse.shape[1] == current_q_len
+        ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
 
         acc_out = out_ring
 
@@ -1085,7 +1111,13 @@ def forward(
         num_kv_heads_local = cache_k[0].shape[2]
         local_groups = local_num_heads // num_kv_heads_local
 
-        q_shape_expanded = (bsz, current_q_len, num_kv_heads_local, local_groups, self.head_dim)
+        q_shape_expanded = (
+            bsz,
+            current_q_len,
+            num_kv_heads_local,
+            local_groups,
+            self.head_dim,
+        )
         qi_reshaped = query_states.view(q_shape_expanded)  # [B, S, KV, G, D]
 
         for i in range(1, len(cache_k)):
@@ -1106,8 +1138,9 @@ def forward(
             # Online Softmax Update
             new_lse = torch.logaddexp(acc_lse, step_lse)
 
-            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(-1) + \
-                step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
+            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(
+                -1
+            ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
 
             acc_lse = new_lse
 
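
Note: one detail worth spelling out is the q_shape_expanded view reformatted above. With grouped-query attention, the local query heads are viewed as [B, S, KV, G, D] so that each KV head lines up with its group of G query heads before the per-block attention and the online-softmax merge. A shape-only sketch (the sizes below are made-up examples, not values from the model config):

# Shape-only sketch (example sizes are assumptions): grouping query heads per
# KV head, matching the [B, S, KV, G, D] layout noted in the diff above.
import torch

bsz, q_len, head_dim = 2, 8, 128
num_heads, num_kv_heads = 32, 8         # grouped-query attention
groups = num_heads // num_kv_heads      # 4 query heads share each KV head

q = torch.randn(bsz, q_len, num_heads, head_dim)
q_grouped = q.view(bsz, q_len, num_kv_heads, groups, head_dim)  # [B, S, KV, G, D]
assert q_grouped.shape == (bsz, q_len, num_kv_heads, groups, head_dim)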
