 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from yunchang import EXTRACT_FUNC_DICT
 from yunchang.comm import SeqAllToAll4D

 from specforge.modeling.draft.flex_attention import (
     generate_eagle3_mask,
 )
 from specforge.utils import print_with_rank
-from .base import Eagle3DraftModel
+
 from ...distributed import get_sp_ring_group, get_sp_ulysses_group
 from ...layers.ring import ring_flash_attn_func
-
+from .base import Eagle3DraftModel


 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -961,6 +960,7 @@ class LlamaUSPFlashAttention(LlamaAttention):
     """
     LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
     """
+
     def __init__(self, config):
         super().__init__(config)
         assert (
@@ -996,19 +996,35 @@ def forward(
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         query_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, query_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            query_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         key_states = self.k_proj(hidden_states)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         key_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, key_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            key_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         value_states = self.v_proj(hidden_states)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         value_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, value_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            value_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         current_q_len = query_states.shape[1]
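Note on the SeqAllToAll4D calls above: in USP attention the Ulysses stage trades a sequence shard for a head shard before flash attention runs. The single-process sketch below only illustrates that shape change, not the yunchang kernel itself; the sizes (B, S, H, D, P) and the assumption that scatter/gather indices map to the head and sequence dims are made up for illustration.

# Single-process sketch of the layout change SeqAllToAll4D performs (shapes
# only; the real exchange is a distributed all-to-all over the Ulysses group).
# All sizes here are invented for illustration.
import torch

B, S, H, D = 2, 16, 8, 4   # batch, Ulysses-group sequence, heads, head_dim
P = 4                      # assumed sp_ulysses_degree

full = torch.randn(B, S, H, D)

# Before the all-to-all: each rank holds a sequence shard with every head.
seq_shard = full.chunk(P, dim=1)[0]
assert seq_shard.shape == (B, S // P, H, D)

# After the all-to-all (scatter heads, gather sequence): each rank holds the
# whole Ulysses-group sequence but only its slice of the heads.
head_shard = full.chunk(P, dim=2)[0]
assert head_shard.shape == (B, S, H // P, D)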
@@ -1022,17 +1038,26 @@ def forward(
         # =============================================================
         if self.sp_ring_degree > 1:
             if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
+                    self.ring_rank
+                ].clone()
             else:
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
+                    self.ring_rank
+                ].clone()

         lck = 0 if cache_hidden is None else len(cache_hidden[0])

         if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
             cos, sin = self.rotary_emb(query_states, position_ids + lck)
             cos, sin = cos.to(query_states.device), sin.to(query_states.device)
             query_states, key_states = apply_multimodal_rotary_pos_emb(
-                query_states, key_states, cos, sin, self.config.rope_scaling["mrope_section"], unsqueeze_dim=2
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
             )
         else:
             cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck)
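Note on the position_ids slicing above: with ring attention each rank applies rotary embeddings only to the sequence slice it owns, so the global position ids are chunked along the sequence dimension and indexed by ring_rank (dim=2 in the mrope branch, presumably because those ids carry an extra leading section dimension). A toy sketch with made-up sizes for the non-mrope branch:

# Each ring rank rotates only the positions it owns; sizes are invented.
import torch

sp_ring_degree, ring_rank = 4, 2          # assumed ring size and this rank's index
B, S = 1, 16

position_ids = torch.arange(S).unsqueeze(0).expand(B, S)              # [B, S]
local_ids = position_ids.chunk(sp_ring_degree, dim=1)[ring_rank].clone()
print(local_ids)   # tensor([[ 8,  9, 10, 11]])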
@@ -1075,8 +1100,9 @@ def forward(
             else:
                 acc_lse = lse_ring

-            assert acc_lse.shape[1] == current_q_len, \
-                f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
+            assert (
+                acc_lse.shape[1] == current_q_len
+            ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"

             acc_out = out_ring

@@ -1085,7 +1111,13 @@ def forward(
             num_kv_heads_local = cache_k[0].shape[2]
             local_groups = local_num_heads // num_kv_heads_local

-            q_shape_expanded = (bsz, current_q_len, num_kv_heads_local, local_groups, self.head_dim)
+            q_shape_expanded = (
+                bsz,
+                current_q_len,
+                num_kv_heads_local,
+                local_groups,
+                self.head_dim,
+            )
             qi_reshaped = query_states.view(q_shape_expanded)  # [B, S, KV, G, D]

             for i in range(1, len(cache_k)):
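Note on q_shape_expanded above: viewing the query as [B, S, KV, G, D] lines up each block of G consecutive query heads with the KV head they share, so the cached K/V can be broadcast against the group dimension instead of being repeated per query head. A toy sketch with invented sizes (not the module's actual score computation):

# Group query heads by their shared KV head and broadcast K over the group dim.
import torch

B, S, H, KV, D = 1, 8, 8, 2, 16   # 8 query heads sharing 2 KV heads (assumed)
G = H // KV                        # group size = 4

q = torch.randn(B, S, H, D)
k = torch.randn(B, S, KV, D)

q_grouped = q.view(B, S, KV, G, D)                          # [B, S, KV, G, D]
# scores[b, i, kv, g, j]: query head kv*G + g at position i vs. key position j.
scores = torch.einsum("bqkgd,bskd->bqkgs", q_grouped, k)
assert scores.shape == (B, S, KV, G, S)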
@@ -1106,8 +1138,9 @@ def forward(
                 # Online Softmax Update
                 new_lse = torch.logaddexp(acc_lse, step_lse)

-                acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(-1) + \
-                    step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
+                acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(
+                    -1
+                ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)

                 acc_lse = new_lse

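Note on the online-softmax update above: combining the accumulated output with a new per-block output, weighted by exp of each log-sum-exp minus their logaddexp, is exactly softmax attention over the union of the key/value blocks. A small self-contained check with made-up tensors:

# Verify the log-sum-exp merge against attention over the concatenated blocks.
import torch

torch.manual_seed(0)
q = torch.randn(1, 4)                            # one query, head_dim 4 (assumed)
k1, v1 = torch.randn(3, 4), torch.randn(3, 4)    # first KV block
k2, v2 = torch.randn(5, 4), torch.randn(5, 4)    # second KV block

def partial(q, k, v):
    s = q @ k.T                                  # raw scores [1, n]
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

out1, lse1 = partial(q, k1, v1)
out2, lse2 = partial(q, k2, v2)

new_lse = torch.logaddexp(lse1, lse2)
merged = out1 * torch.exp(lse1 - new_lse).unsqueeze(-1) \
       + out2 * torch.exp(lse2 - new_lse).unsqueeze(-1)

ref, _ = partial(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
assert torch.allclose(merged, ref, atol=1e-5)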