
Commit 11a2e13: clean up

Signed-off-by: Yu Feng <fengyufengyu@didiglobal.com>
1 parent: 19dd2a8

4 files changed: +86, -31 lines

scripts/train_eagle3.py

Lines changed: 12 additions & 10 deletions

@@ -35,8 +35,9 @@
     destroy_distributed,
     get_dp_group,
     get_draft_dp_group,
+    get_draft_sp_group,
     get_tp_group,
-    init_distributed, get_draft_sp_group,
+    init_distributed,
 )
 from specforge.modeling.target import (
     Eagle3TargetModel,
@@ -622,11 +623,6 @@ def record_metrcs(
     tracker.log(logdict, step=global_step)


-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-
-
 def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
     """
     Process: TP split -> Pad to Max Len -> SP gather.
@@ -648,7 +644,9 @@ def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Te
     local_seq_len = local_tp_shard.size(sp_dim)

     # Find global max sequence length in SP group
-    len_tensor = torch.tensor([local_seq_len], device=local_tp_shard.device, dtype=torch.long)
+    len_tensor = torch.tensor(
+        [local_seq_len], device=local_tp_shard.device, dtype=torch.long
+    )
     dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group)
     max_seq_len = len_tensor.item()

@@ -665,12 +663,16 @@ def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Te
         pad_config[pad_idx] = pad_size

         # Pad value: 0 is standard, ensure it matches your pad_token_id logic if needed
-        local_tp_shard_padded = F.pad(local_tp_shard, pad_config, value=0)
+        local_tp_shard_padded = nn.F.pad(local_tp_shard, pad_config, value=0)
     else:
         local_tp_shard_padded = local_tp_shard

-    gathered_shards = [torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)]
-    dist.all_gather(gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group)
+    gathered_shards = [
+        torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)
+    ]
+    dist.all_gather(
+        gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group
+    )

     return torch.cat(gathered_shards, dim=sp_dim)
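
For orientation, the function being reformatted here follows the pattern its docstring names: TP split -> Pad to Max Len -> SP gather. A minimal standalone sketch of that pad-then-gather pattern, assuming an already initialized torch.distributed process group (demo_pad_and_gather and its arguments are illustrative, not code from this repository):

import torch
import torch.distributed as dist
import torch.nn.functional as F


def demo_pad_and_gather(shard: torch.Tensor, group, sp_dim: int = 1) -> torch.Tensor:
    # Pad the local shard to the group's max length along sp_dim, then all_gather.
    world_size = dist.get_world_size(group)
    seq_len = torch.tensor([shard.size(sp_dim)], device=shard.device, dtype=torch.long)
    dist.all_reduce(seq_len, op=dist.ReduceOp.MAX, group=group)
    pad = int(seq_len.item()) - shard.size(sp_dim)
    if pad > 0:
        # F.pad orders its config from the last dimension backwards;
        # only the trailing side of sp_dim is padded.
        pad_config = [0] * (2 * shard.dim())
        pad_config[2 * (shard.dim() - 1 - sp_dim) + 1] = pad
        shard = F.pad(shard, pad_config, value=0)
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(gathered, shard.contiguous(), group=group)
    return torch.cat(gathered, dim=sp_dim)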

specforge/layers/ring/ring_flash_attn.py

Lines changed: 22 additions & 3 deletions

@@ -1,6 +1,8 @@
 import torch
+from yunchang.kernels import AttnType, select_flash_attn_impl
+
 from .utils import RingComm, update_out_and_lse
-from yunchang.kernels import select_flash_attn_impl, AttnType
+

 def ring_flash_attn_forward(
     process_group,
@@ -31,7 +33,9 @@ def ring_flash_attn_forward(
         comm.commit()

         if not causal or step <= comm.rank:
-            fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor)
+            fn = select_flash_attn_impl(
+                attn_type, stage="fwd-only", attn_processor=attn_processor
+            )
             block_out, block_lse = fn(
                 q,
                 k,
@@ -219,7 +223,22 @@ def backward(ctx, dout, *args):
            deterministic=ctx.deterministic,
            attn_type=ctx.attn_type,
        )
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
+        return (
+            dq,
+            dk,
+            dv,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 def ring_flash_attn_qkvpacked_func(
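
Background on the long tuple of Nones being reformatted in backward: a custom torch.autograd.Function must return one gradient per argument of forward, and non-tensor or non-differentiable arguments receive None. A minimal self-contained illustration of that rule (ScaleFn is a made-up example, not code from this repository):

import torch


class ScaleFn(torch.autograd.Function):
    # forward takes one tensor plus one non-tensor argument, so backward
    # must return exactly two values: a gradient for x and None for scale.
    @staticmethod
    def forward(ctx, x, scale: float):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None


x = torch.randn(4, requires_grad=True)
ScaleFn.apply(x, 2.0).sum().backward()
print(x.grad)  # every entry equals 2.0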

specforge/layers/ring/utils.py

Lines changed: 3 additions & 2 deletions

@@ -6,14 +6,15 @@

 __all__ = ["update_out_and_lse", "RingComm"]

+
 @torch.jit.script
 def _update_out_and_lse(
     out: torch.Tensor,
     lse: torch.Tensor,
     block_out: torch.Tensor,
     block_lse: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
+
     block_out = block_out.to(torch.float32)
     block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

@@ -115,4 +116,4 @@ def wait(self):
         for req in self._reqs:
             req.wait()
         self._reqs = None
-        self._ops = []
+        self._ops = []
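
For context, helpers like _update_out_and_lse combine partial attention outputs through their log-sum-exp (LSE) statistics. A simplified sketch of that identity with shapes kept generic (merge_partial_attn is illustrative, not the repository's exact function; it assumes out/block_out carry one trailing feature dimension beyond lse/block_lse):

import torch


def merge_partial_attn(out, lse, block_out, block_lse):
    # new_lse = log(exp(lse) + exp(block_lse)); each partial output is then
    # reweighted by its share of the combined softmax mass.
    new_lse = torch.logaddexp(lse, block_lse)
    merged = out * torch.exp(lse - new_lse).unsqueeze(-1) + block_out * torch.exp(
        block_lse - new_lse
    ).unsqueeze(-1)
    return merged, new_lse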

specforge/modeling/draft/llama3_eagle.py

Lines changed: 49 additions & 16 deletions

@@ -11,7 +11,6 @@
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from yunchang import EXTRACT_FUNC_DICT
 from yunchang.comm import SeqAllToAll4D

 from specforge.modeling.draft.flex_attention import (
@@ -20,10 +19,10 @@
     generate_eagle3_mask,
 )
 from specforge.utils import print_with_rank
-from .base import Eagle3DraftModel
+
 from ...distributed import get_sp_ring_group, get_sp_ulysses_group
 from ...layers.ring import ring_flash_attn_func
-
+from .base import Eagle3DraftModel


 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -961,6 +960,7 @@ class LlamaUSPFlashAttention(LlamaAttention):
     """
     LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
     """
+
     def __init__(self, config):
         super().__init__(config)
         assert (
@@ -996,19 +996,35 @@ def forward(
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         query_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, query_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            query_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         key_states = self.k_proj(hidden_states)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         key_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, key_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            key_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
        )

         value_states = self.v_proj(hidden_states)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         value_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, value_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            value_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         current_q_len = query_states.shape[1]
@@ -1022,17 +1038,26 @@ def forward(
         # =============================================================
         if self.sp_ring_degree > 1:
             if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
+                    self.ring_rank
+                ].clone()
             else:
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
+                    self.ring_rank
+                ].clone()

         lck = 0 if cache_hidden is None else len(cache_hidden[0])

         if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
             cos, sin = self.rotary_emb(query_states, position_ids + lck)
             cos, sin = cos.to(query_states.device), sin.to(query_states.device)
             query_states, key_states = apply_multimodal_rotary_pos_emb(
-                query_states, key_states, cos, sin, self.config.rope_scaling["mrope_section"], unsqueeze_dim=2
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
             )
         else:
             cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck)
@@ -1075,8 +1100,9 @@ def forward(
         else:
             acc_lse = lse_ring

-        assert acc_lse.shape[1] == current_q_len, \
-            f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
+        assert (
+            acc_lse.shape[1] == current_q_len
+        ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"

         acc_out = out_ring

@@ -1085,7 +1111,13 @@ def forward(
         num_kv_heads_local = cache_k[0].shape[2]
         local_groups = local_num_heads // num_kv_heads_local

-        q_shape_expanded = (bsz, current_q_len, num_kv_heads_local, local_groups, self.head_dim)
+        q_shape_expanded = (
+            bsz,
+            current_q_len,
+            num_kv_heads_local,
+            local_groups,
+            self.head_dim,
+        )
         qi_reshaped = query_states.view(q_shape_expanded)  # [B, S, KV, G, D]

         for i in range(1, len(cache_k)):
@@ -1106,8 +1138,9 @@ def forward(
             # Online Softmax Update
             new_lse = torch.logaddexp(acc_lse, step_lse)

-            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(-1) + \
-                step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
+            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(
+                -1
+            ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)

             acc_lse = new_lse
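
One detail worth calling out in the llama3_eagle.py changes: the query tensor is viewed as [B, S, KV, G, D] before each per-step attention against the cached K/V, i.e. a grouped-query layout in which a group of query heads shares a single KV head. A toy shape check of that reshape (all sizes are made up for the demo):

import torch

bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 16, 8, 2, 64
groups = num_heads // num_kv_heads  # query heads per KV head

q = torch.randn(bsz, seq_len, num_heads, head_dim)
q_grouped = q.view(bsz, seq_len, num_kv_heads, groups, head_dim)  # [B, S, KV, G, D]
print(q_grouped.shape)  # torch.Size([2, 16, 2, 4, 64])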