
Commit bba2f9d

clean up
1 parent 76d3cf3 commit bba2f9d

4 files changed: +86, -32 lines


scripts/train_eagle3.py

Lines changed: 12 additions & 10 deletions
@@ -35,8 +35,9 @@
     destroy_distributed,
     get_dp_group,
     get_draft_dp_group,
+    get_draft_sp_group,
     get_tp_group,
-    init_distributed, get_draft_sp_group,
+    init_distributed,
 )
 from specforge.modeling.target import (
     Eagle3TargetModel,
@@ -631,11 +632,6 @@ def record_metrcs(
         tracker.log(logdict, step=global_step)


-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-
-
 def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
     """
     Process: TP split -> Pad to Max Len -> SP gather.
@@ -657,7 +653,9 @@ def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Te
     local_seq_len = local_tp_shard.size(sp_dim)

     # Find global max sequence length in SP group
-    len_tensor = torch.tensor([local_seq_len], device=local_tp_shard.device, dtype=torch.long)
+    len_tensor = torch.tensor(
+        [local_seq_len], device=local_tp_shard.device, dtype=torch.long
+    )
     dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group)
     max_seq_len = len_tensor.item()

@@ -674,12 +672,16 @@
         pad_config[pad_idx] = pad_size

         # Pad value: 0 is standard, ensure it matches your pad_token_id logic if needed
-        local_tp_shard_padded = F.pad(local_tp_shard, pad_config, value=0)
+        local_tp_shard_padded = nn.F.pad(local_tp_shard, pad_config, value=0)
     else:
         local_tp_shard_padded = local_tp_shard

-    gathered_shards = [torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)]
-    dist.all_gather(gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group)
+    gathered_shards = [
+        torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)
+    ]
+    dist.all_gather(
+        gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group
+    )

     return torch.cat(gathered_shards, dim=sp_dim)
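The helper in this diff, `get_dp_data_shard_from_tp`, follows a pad-then-gather pattern: every rank reports its local sequence length, an all-reduce with MAX establishes the group-wide maximum, each shard is right-padded to that length, and an all-gather concatenates the shards along the sequence dimension. A minimal standalone sketch of the same pattern (hypothetical name `gather_padded_along_seq`, not the script's API; assumes torch.distributed is already initialized):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def gather_padded_along_seq(x: torch.Tensor, group=None, seq_dim: int = 1) -> torch.Tensor:
    """Pad a local shard to the group-wide max length, then all_gather along seq_dim."""
    world_size = dist.get_world_size(group)

    # Agree on the maximum sequence length across the group.
    max_len = torch.tensor([x.size(seq_dim)], device=x.device, dtype=torch.long)
    dist.all_reduce(max_len, op=dist.ReduceOp.MAX, group=group)
    pad = int(max_len.item()) - x.size(seq_dim)

    if pad > 0:
        # F.pad's config runs from the last dim backwards: (0, 0) pairs for the
        # trailing dims, then (0, pad) for seq_dim (right padding only).
        pad_config = [0, 0] * (x.dim() - 1 - seq_dim) + [0, pad]
        x = F.pad(x, pad_config, value=0)

    shards = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(shards, x.contiguous(), group=group)
    return torch.cat(shards, dim=seq_dim)
```

As the diff's own comment notes, padding with 0 is only safe if the consumer masks the padded positions or treats 0 as the pad id; otherwise the pad value should follow the tokenizer's pad_token_id.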

specforge/layers/ring/ring_flash_attn.py

Lines changed: 22 additions & 3 deletions
@@ -1,6 +1,8 @@
 import torch
+from yunchang.kernels import AttnType, select_flash_attn_impl
+
 from .utils import RingComm, update_out_and_lse
-from yunchang.kernels import select_flash_attn_impl, AttnType
+

 def ring_flash_attn_forward(
     process_group,
@@ -31,7 +33,9 @@ def ring_flash_attn_forward(
         comm.commit()

         if not causal or step <= comm.rank:
-            fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor)
+            fn = select_flash_attn_impl(
+                attn_type, stage="fwd-only", attn_processor=attn_processor
+            )
             block_out, block_lse = fn(
                 q,
                 k,
@@ -219,7 +223,22 @@ def backward(ctx, dout, *args):
             deterministic=ctx.deterministic,
             attn_type=ctx.attn_type,
         )
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
+        return (
+            dq,
+            dk,
+            dv,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 def ring_flash_attn_qkvpacked_func(
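The ring forward above folds one KV block into the running result per ring step; this only works because a partial attention output can be merged exactly with another as long as each carries its log-sum-exp (LSE). A single-process sketch of that blockwise accumulation in plain PyTorch (hypothetical names, no flash-attn kernels, non-causal only), checked against full attention:

```python
import torch


def merge_blocks(out, lse, block_out, block_lse):
    # Numerically stable merge of two partial attention results via their LSEs.
    new_lse = torch.logaddexp(lse, block_lse)
    new_out = out * torch.exp(lse - new_lse).unsqueeze(-1) + block_out * torch.exp(
        block_lse - new_lse
    ).unsqueeze(-1)
    return new_out, new_lse


def chunked_attention(q, k, v, num_chunks=4):
    # Attention computed one KV chunk at a time, mimicking what each ring step adds.
    scale = q.shape[-1] ** -0.5
    acc_out, acc_lse = None, None
    for kc, vc in zip(k.chunk(num_chunks, dim=1), v.chunk(num_chunks, dim=1)):
        scores = torch.einsum("bqhd,bkhd->bhqk", q, kc) * scale
        block_lse = torch.logsumexp(scores, dim=-1).transpose(1, 2)  # [B, Q, H]
        block_out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(-1), vc)
        if acc_out is None:
            acc_out, acc_lse = block_out, block_lse
        else:
            acc_out, acc_lse = merge_blocks(acc_out, acc_lse, block_out, block_lse)
    return acc_out


# Sanity check against attention over the full KV at once.
B, S, H, D = 2, 32, 4, 16
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))
full_scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * D**-0.5
full = torch.einsum("bhqk,bkhd->bqhd", full_scores.softmax(-1), v)
assert torch.allclose(chunked_attention(q, k, v), full, atol=1e-4)
```

The `if not causal or step <= comm.rank` guard in the forward is the causal version of the same bookkeeping: a rank only folds in KV blocks its queries are allowed to attend to.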

specforge/layers/ring/utils.py

Lines changed: 3 additions & 2 deletions
@@ -6,14 +6,15 @@

 __all__ = ["update_out_and_lse", "RingComm"]

+
 @torch.jit.script
 def _update_out_and_lse(
     out: torch.Tensor,
     lse: torch.Tensor,
     block_out: torch.Tensor,
     block_lse: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
+
     block_out = block_out.to(torch.float32)
     block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

@@ -115,4 +116,4 @@ def wait(self):
         for req in self._reqs:
             req.wait()
         self._reqs = None
-        self._ops = []
+        self._ops = []
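`RingComm`, whose `wait()` is touched above, wraps torch.distributed's batched point-to-point API: queued send/recv ops are launched together and `wait()` blocks on the returned requests before clearing `self._ops`. A hedged sketch of one such neighbour exchange as a standalone function (not the class in this file; assumes the default process group is initialized and every rank calls it):

```python
import torch
import torch.distributed as dist


def ring_exchange(block: torch.Tensor) -> torch.Tensor:
    # One ring step: send our block to rank + 1, receive rank - 1's block.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    recv_buf = torch.empty_like(block)
    ops = [
        dist.P2POp(dist.isend, block.contiguous(), (rank + 1) % world_size),
        dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world_size),
    ]
    reqs = dist.batch_isend_irecv(ops)  # roughly what a commit() step does
    for req in reqs:                    # roughly what wait() does
        req.wait()
    return recv_buf
```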

specforge/modeling/draft/llama3_eagle.py

Lines changed: 49 additions & 17 deletions
@@ -8,11 +8,9 @@
 import torch.nn.functional as F
 from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-from transformers import LlamaConfig
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from yunchang import EXTRACT_FUNC_DICT
 from yunchang.comm import SeqAllToAll4D

 from specforge.modeling.draft.flex_attention import (
@@ -21,10 +19,10 @@
     generate_eagle3_mask,
 )
 from specforge.utils import print_with_rank
-from .base import Eagle3DraftModel
+
 from ...distributed import get_sp_ring_group, get_sp_ulysses_group
 from ...layers.ring import ring_flash_attn_func
-
+from .base import Eagle3DraftModel

 try:
     from flash_attn import flash_attn_func
@@ -973,6 +971,7 @@ class LlamaUSPFlashAttention(LlamaAttention):
     """
     LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
     """
+
     def __init__(self, config):
         super().__init__(config)
         assert (
@@ -1008,19 +1007,35 @@ def forward(
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         query_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, query_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            query_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         key_states = self.k_proj(hidden_states)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         key_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, key_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            key_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         value_states = self.v_proj(hidden_states)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         value_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, value_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            value_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         current_q_len = query_states.shape[1]
@@ -1034,17 +1049,26 @@ def forward(
        # =============================================================
        if self.sp_ring_degree > 1:
            if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
+                    self.ring_rank
+                ].clone()
            else:
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
+                    self.ring_rank
+                ].clone()

        lck = 0 if cache_hidden is None else len(cache_hidden[0])

        if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
            cos, sin = self.rotary_emb(query_states, position_ids + lck)
            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
            query_states, key_states = apply_multimodal_rotary_pos_emb(
-                query_states, key_states, cos, sin, self.config.rope_scaling["mrope_section"], unsqueeze_dim=2
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
            )
        else:
            cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck)
@@ -1087,8 +1111,9 @@ def forward(
        else:
            acc_lse = lse_ring

-        assert acc_lse.shape[1] == current_q_len, \
-            f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
+        assert (
+            acc_lse.shape[1] == current_q_len
+        ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"

        acc_out = out_ring

@@ -1097,7 +1122,13 @@ def forward(
        num_kv_heads_local = cache_k[0].shape[2]
        local_groups = local_num_heads // num_kv_heads_local

-        q_shape_expanded = (bsz, current_q_len, num_kv_heads_local, local_groups, self.head_dim)
+        q_shape_expanded = (
+            bsz,
+            current_q_len,
+            num_kv_heads_local,
+            local_groups,
+            self.head_dim,
+        )
        qi_reshaped = query_states.view(q_shape_expanded)  # [B, S, KV, G, D]

        for i in range(1, len(cache_k)):
@@ -1118,8 +1149,9 @@ def forward(
            # Online Softmax Update
            new_lse = torch.logaddexp(acc_lse, step_lse)

-            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(-1) + \
-                step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
+            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(
+                -1
+            ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)

            acc_lse = new_lse
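In the Ulysses branch above, `SeqAllToAll4D` converts q/k/v from a sequence-sharded layout [batch, seq/P, heads, head_dim] to a head-sharded layout [batch, seq, heads/P, head_dim], so each rank attends over the full sequence for its slice of heads. A minimal sketch of that redistribution with `dist.all_to_all` (hypothetical helper name, equal shard sizes assumed; not yunchang's implementation):

```python
import torch
import torch.distributed as dist


def scatter_heads_gather_seq(x: torch.Tensor, group=None) -> torch.Tensor:
    # x: [B, S/P, H, D] -- local sequence shard, all heads.
    # returns: [B, S, H/P, D] -- full sequence, local head shard.
    world_size = dist.get_world_size(group)

    # Chunk the head dimension; chunk i is sent to rank i.
    send = [c.contiguous() for c in x.chunk(world_size, dim=2)]
    recv = [torch.empty_like(send[0]) for _ in range(world_size)]
    dist.all_to_all(recv, send, group=group)

    # The chunk received from rank j is rank j's sequence shard of our heads,
    # so concatenating along the sequence dim restores the full sequence.
    return torch.cat(recv, dim=1)
```

The inverse direction (gather heads, re-scatter the sequence) is the same call with the chunk and concat dimensions swapped, which is roughly what the attention output needs before the output projection.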
