Commit bf7f659

committed
clean up
1 parent 272799c commit bf7f659

File tree: 4 files changed, +86 −33 lines

scripts/train_eagle3.py

Lines changed: 12 additions & 10 deletions
@@ -35,8 +35,9 @@
     destroy_distributed,
     get_dp_group,
     get_draft_dp_group,
+    get_draft_sp_group,
     get_tp_group,
-    init_distributed, get_draft_sp_group,
+    init_distributed,
 )
 from specforge.modeling.target import (
     Eagle3TargetModel,
@@ -631,11 +632,6 @@ def record_metrcs(
         tracker.log(logdict, step=global_step)
 
 
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-
-
 def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
     """
     Process: TP split -> Pad to Max Len -> SP gather.
@@ -657,7 +653,9 @@ def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
     local_seq_len = local_tp_shard.size(sp_dim)
 
     # Find global max sequence length in SP group
-    len_tensor = torch.tensor([local_seq_len], device=local_tp_shard.device, dtype=torch.long)
+    len_tensor = torch.tensor(
+        [local_seq_len], device=local_tp_shard.device, dtype=torch.long
+    )
     dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group)
     max_seq_len = len_tensor.item()
 
@@ -674,12 +672,16 @@ def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
         pad_config[pad_idx] = pad_size
 
         # Pad value: 0 is standard, ensure it matches your pad_token_id logic if needed
-        local_tp_shard_padded = F.pad(local_tp_shard, pad_config, value=0)
+        local_tp_shard_padded = nn.F.pad(local_tp_shard, pad_config, value=0)
     else:
         local_tp_shard_padded = local_tp_shard
 
-    gathered_shards = [torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)]
-    dist.all_gather(gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group)
+    gathered_shards = [
+        torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)
+    ]
+    dist.all_gather(
+        gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group
+    )
 
     return torch.cat(gathered_shards, dim=sp_dim)

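Note: the reformatted get_dp_data_shard_from_tp above follows a pad-to-group-max-then-all_gather pattern: each rank reports its sequence length, the group agrees on the max via all_reduce, every shard is padded to that length, and the equal-sized shards are all_gather'ed and concatenated. Below is a minimal, self-contained sketch of that pattern, not the training script's helper; the 2-rank gloo setup, CPU tensors, shapes, and port are illustrative assumptions.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F


def gather_variable_length(rank: int, world_size: int) -> None:
    # Assumed local setup for the sketch: gloo backend, CPU tensors, free port.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank holds a shard with a different sequence length along dim 1.
    local = torch.full((1, 3 + rank, 4), float(rank))

    # 1) Agree on the max sequence length across the group.
    seq_len = torch.tensor([local.size(1)], dtype=torch.long)
    dist.all_reduce(seq_len, op=dist.ReduceOp.MAX)
    max_len = int(seq_len.item())

    # 2) Right-pad dim 1 up to max_len (F.pad counts from the last dim backwards).
    pad = max_len - local.size(1)
    if pad > 0:
        local = F.pad(local, (0, 0, 0, pad), value=0)

    # 3) all_gather the now equal-sized shards and concatenate along the sequence dim.
    shards = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(shards, local.contiguous())
    full = torch.cat(shards, dim=1)

    if rank == 0:
        print(full.shape)  # torch.Size([1, 8, 4])

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(gather_variable_length, args=(2,), nprocs=2)
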
specforge/layers/ring/ring_flash_attn.py

Lines changed: 22 additions & 3 deletions
@@ -1,6 +1,8 @@
 import torch
+from yunchang.kernels import AttnType, select_flash_attn_impl
+
 from .utils import RingComm, update_out_and_lse
-from yunchang.kernels import select_flash_attn_impl, AttnType
+
 
 def ring_flash_attn_forward(
     process_group,
@@ -31,7 +33,9 @@ def ring_flash_attn_forward(
         comm.commit()
 
         if not causal or step <= comm.rank:
-            fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor)
+            fn = select_flash_attn_impl(
+                attn_type, stage="fwd-only", attn_processor=attn_processor
+            )
             block_out, block_lse = fn(
                 q,
                 k,
@@ -219,7 +223,22 @@ def backward(ctx, dout, *args):
            deterministic=ctx.deterministic,
            attn_type=ctx.attn_type,
        )
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
+        return (
+            dq,
+            dk,
+            dv,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 def ring_flash_attn_qkvpacked_func(

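Note: the expanded return statement in backward above makes torch.autograd.Function's contract visible: backward() must return exactly one value per argument that forward() received, with None in the slots of non-differentiable inputs (process group, causal flag, dropout probability, and so on). A toy sketch of that contract with a hypothetical ScaleByConstant function, unrelated to the ring kernel:

import torch


class ScaleByConstant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float, verbose: bool):
        ctx.scale = scale
        if verbose:
            print("scaling by", scale)
        return x * scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One slot per forward argument: a gradient for x, None for the
        # non-tensor arguments scale and verbose.
        return grad_output * ctx.scale, None, None


x = torch.randn(3, requires_grad=True)
ScaleByConstant.apply(x, 2.0, False).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
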
specforge/layers/ring/utils.py

Lines changed: 3 additions & 2 deletions
@@ -6,14 +6,15 @@
 
 __all__ = ["update_out_and_lse", "RingComm"]
 
+
 @torch.jit.script
 def _update_out_and_lse(
     out: torch.Tensor,
     lse: torch.Tensor,
     block_out: torch.Tensor,
     block_lse: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
+
     block_out = block_out.to(torch.float32)
     block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
 
@@ -115,4 +116,4 @@ def wait(self):
         for req in self._reqs:
             req.wait()
         self._reqs = None
-        self._ops = []
+        self._ops = []

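Note: helpers like _update_out_and_lse rest on the log-sum-exp merge identity for block attention: attention computed over the concatenation of two key/value blocks equals the per-block outputs recombined with weights exp(lse_i - lse_new). The check below uses the textbook formula with illustrative shapes; it is not a copy of the repo's jitted kernel.

import torch


def block_attn(q, k, v):
    # Plain softmax attention over one block; returns (out, lse).
    scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    out = torch.exp(scores - lse) @ v
    return out, lse


torch.manual_seed(0)
q = torch.randn(1, 4, 8)
k = torch.randn(1, 16, 8)
v = torch.randn(1, 16, 8)

out_full, _ = block_attn(q, k, v)

out1, lse1 = block_attn(q, k[:, :8], v[:, :8])
out2, lse2 = block_attn(q, k[:, 8:], v[:, 8:])

# Online-softmax merge: rescale each block's output by exp(lse_i - lse_new).
lse_new = torch.logaddexp(lse1, lse2)
merged = out1 * torch.exp(lse1 - lse_new) + out2 * torch.exp(lse2 - lse_new)

print(torch.allclose(merged, out_full, atol=1e-5))  # True
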
specforge/modeling/draft/llama3_eagle.py

Lines changed: 49 additions & 18 deletions
@@ -7,24 +7,21 @@
 import torch.nn.functional as F
 from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-from transformers import LlamaConfig
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from yunchang import EXTRACT_FUNC_DICT
 from yunchang.comm import SeqAllToAll4D
-from flash_attn import flash_attn_func
 
 from specforge.modeling.draft.flex_attention import (
     compile_friendly_create_block_mask,
     compile_friendly_flex_attention,
     generate_eagle3_mask,
 )
 from specforge.utils import print_with_rank
-from .base import Eagle3DraftModel
+
 from ...distributed import get_sp_ring_group, get_sp_ulysses_group
 from ...layers.ring import ring_flash_attn_func
-
+from .base import Eagle3DraftModel
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -962,6 +959,7 @@ class LlamaUSPFlashAttention(LlamaAttention):
     """
     LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
     """
+
     def __init__(self, config):
         super().__init__(config)
         assert (
@@ -997,19 +995,35 @@ def forward(
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         query_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, query_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            query_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )
 
         key_states = self.k_proj(hidden_states)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         key_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, key_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            key_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )
 
         value_states = self.v_proj(hidden_states)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         value_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, value_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            value_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )
 
         current_q_len = query_states.shape[1]
@@ -1023,17 +1037,26 @@ def forward(
         # =============================================================
         if self.sp_ring_degree > 1:
             if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
+                    self.ring_rank
+                ].clone()
             else:
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
+                    self.ring_rank
+                ].clone()
 
         lck = 0 if cache_hidden is None else len(cache_hidden[0])
 
         if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
             cos, sin = self.rotary_emb(query_states, position_ids + lck)
             cos, sin = cos.to(query_states.device), sin.to(query_states.device)
             query_states, key_states = apply_multimodal_rotary_pos_emb(
-                query_states, key_states, cos, sin, self.config.rope_scaling["mrope_section"], unsqueeze_dim=2
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
             )
         else:
             cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck)
@@ -1076,8 +1099,9 @@ def forward(
         else:
             acc_lse = lse_ring
 
-        assert acc_lse.shape[1] == current_q_len, \
-            f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
+        assert (
+            acc_lse.shape[1] == current_q_len
+        ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
 
         acc_out = out_ring
 
@@ -1086,7 +1110,13 @@ def forward(
         num_kv_heads_local = cache_k[0].shape[2]
         local_groups = local_num_heads // num_kv_heads_local
 
-        q_shape_expanded = (bsz, current_q_len, num_kv_heads_local, local_groups, self.head_dim)
+        q_shape_expanded = (
+            bsz,
+            current_q_len,
+            num_kv_heads_local,
+            local_groups,
+            self.head_dim,
+        )
         qi_reshaped = query_states.view(q_shape_expanded)  # [B, S, KV, G, D]
 
         for i in range(1, len(cache_k)):
@@ -1107,8 +1137,9 @@ def forward(
             # Online Softmax Update
            new_lse = torch.logaddexp(acc_lse, step_lse)
 
-            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(-1) + \
-                step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
+            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(
+                -1
+            ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
 
            acc_lse = new_lse

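Note: the q_shape_expanded reshape in the last file above, [B, S, KV, G, D], is the grouped-query-attention view in which num_heads query heads are regrouped under num_kv_heads shared key/value heads (G = num_heads // num_kv_heads). A standalone illustration with assumed shapes, not the module's forward pass:

import torch

bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
groups = num_heads // num_kv_heads  # 4 query heads share each KV head

q = torch.randn(bsz, seq_len, num_heads, head_dim)
k = torch.randn(bsz, seq_len, num_kv_heads, head_dim)

# [B, S, H, D] -> [B, S, KV, G, D]
q_grouped = q.view(bsz, seq_len, num_kv_heads, groups, head_dim)

# Scores per (kv head, query group): one K head broadcasts across its G query heads.
scores = torch.einsum("bskgd,btkd->bkgst", q_grouped, k) / head_dim**0.5
print(scores.shape)  # torch.Size([2, 2, 4, 5, 5])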