 import torch.nn.functional as F
 from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-from transformers import LlamaConfig
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from yunchang import EXTRACT_FUNC_DICT
 from yunchang.comm import SeqAllToAll4D
-from flash_attn import flash_attn_func

 from specforge.modeling.draft.flex_attention import (
     compile_friendly_create_block_mask,
     compile_friendly_flex_attention,
     generate_eagle3_mask,
 )
 from specforge.utils import print_with_rank
-from .base import Eagle3DraftModel
+
 from ...distributed import get_sp_ring_group, get_sp_ulysses_group
 from ...layers.ring import ring_flash_attn_func
-
+from .base import Eagle3DraftModel


 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -962,6 +959,7 @@ class LlamaUSPFlashAttention(LlamaAttention):
     """
     LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
     """
+
     def __init__(self, config):
         super().__init__(config)
         assert (
@@ -997,19 +995,35 @@ def forward(
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         query_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, query_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            query_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         key_states = self.k_proj(hidden_states)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         key_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, key_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            key_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         value_states = self.v_proj(hidden_states)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         value_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, value_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            value_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         current_q_len = query_states.shape[1]
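Note: the SeqAllToAll4D calls above perform the Ulysses half of USP. Each sequence-parallel rank trades its sequence shard of all heads for the full sequence of a head shard. Below is a minimal single-process sketch of that exchange; the function name and the simulated per-rank list are illustrative only, not yunchang's implementation.

# Illustrative sketch of the Ulysses all-to-all semantics (not yunchang's code).
# Before: each of the P ranks holds [B, S/P, H, D] (its sequence shard, all heads).
# After:  each rank holds [B, S, H/P, D] (the full sequence, its head shard).
import torch

def ulysses_all_to_all_sim(shards, scatter_idx=2, gather_idx=1):
    """Simulate the head-scatter / sequence-gather exchange on one process."""
    p = len(shards)
    heads = shards[0].shape[scatter_idx]
    assert heads % p == 0, "num heads must be divisible by the Ulysses degree"
    out = []
    for r in range(p):
        # Rank r receives head chunk r from every rank's sequence shard,
        # concatenated along the sequence dimension.
        pieces = [s.chunk(p, dim=scatter_idx)[r] for s in shards]
        out.append(torch.cat(pieces, dim=gather_idx))
    return out

if __name__ == "__main__":
    bsz, seq, heads, head_dim, degree = 2, 8, 4, 16, 2
    full = torch.randn(bsz, seq, heads, head_dim)
    shards = list(full.chunk(degree, dim=1))       # [B, S/P, H, D] per rank
    gathered = ulysses_all_to_all_sim(shards)      # [B, S, H/P, D] per rank
    # Re-assembling the head shards along dim=2 recovers the original tensor.
    assert torch.equal(torch.cat(gathered, dim=2), full)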
@@ -1023,17 +1037,26 @@ def forward(
         # =============================================================
         if self.sp_ring_degree > 1:
             if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
+                    self.ring_rank
+                ].clone()
             else:
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
+                    self.ring_rank
+                ].clone()

         lck = 0 if cache_hidden is None else len(cache_hidden[0])

         if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
             cos, sin = self.rotary_emb(query_states, position_ids + lck)
             cos, sin = cos.to(query_states.device), sin.to(query_states.device)
             query_states, key_states = apply_multimodal_rotary_pos_emb(
-                query_states, key_states, cos, sin, self.config.rope_scaling["mrope_section"], unsqueeze_dim=2
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
             )
         else:
             cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck)
@@ -1076,8 +1099,9 @@ def forward(
         else:
             acc_lse = lse_ring

-        assert acc_lse.shape[1] == current_q_len, \
-            f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
+        assert (
+            acc_lse.shape[1] == current_q_len
+        ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"

         acc_out = out_ring

@@ -1086,7 +1110,13 @@ def forward(
         num_kv_heads_local = cache_k[0].shape[2]
         local_groups = local_num_heads // num_kv_heads_local

-        q_shape_expanded = (bsz, current_q_len, num_kv_heads_local, local_groups, self.head_dim)
+        q_shape_expanded = (
+            bsz,
+            current_q_len,
+            num_kv_heads_local,
+            local_groups,
+            self.head_dim,
+        )
         qi_reshaped = query_states.view(q_shape_expanded)  # [B, S, KV, G, D]

         for i in range(1, len(cache_k)):
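Note: the [B, S, KV, G, D] view above groups the G query heads that share each cached KV head, so grouped-query attention can broadcast against cache_k / cache_v without repeating them to the full head count. A minimal shape-only sketch of that grouping (toy sizes, the in-loop attention itself is elided in this diff):

# Grouped-query attention maps query head h to KV head h // G, so viewing the
# query as [B, S, KV, G, D] puts each KV head's G query heads side by side.
import torch

bsz, seq, num_heads, num_kv_heads, head_dim = 1, 4, 8, 2, 16
groups = num_heads // num_kv_heads                             # G = 4

q = torch.randn(bsz, seq, num_heads, head_dim)                 # [B, S, H, D]
q_grouped = q.view(bsz, seq, num_kv_heads, groups, head_dim)   # [B, S, KV, G, D]

# Query head h lands at group position (h // G, h % G) of its KV head.
for h in range(num_heads):
    assert torch.equal(q_grouped[:, :, h // groups, h % groups], q[:, :, h])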
@@ -1107,8 +1137,9 @@ def forward(
             # Online Softmax Update
             new_lse = torch.logaddexp(acc_lse, step_lse)

-            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(-1) + \
-                step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
+            acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(
+                -1
+            ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)

             acc_lse = new_lse

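Note: the update above is the standard log-sum-exp merge of two partial attention results. Given each branch's normalized output and its per-query LSE, re-weighting by exp(lse_i - logaddexp(lse_1, lse_2)) reproduces the softmax over the union of keys, which is what lets the ring step and the cached-KV step be combined without re-running attention over everything. A minimal single-head sketch that checks the identity (toy shapes and names, not the ring-attention code path):

import torch

def partial_attn(q, k, v):
    """Attention over one key block, returning the block output and its
    log-sum-exp (similar to the LSE that flash-attention kernels track)."""
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5   # [B, Sq, Sk]
    lse = torch.logsumexp(scores, dim=-1)                   # [B, Sq]
    out = torch.softmax(scores, dim=-1) @ v                 # [B, Sq, D]
    return out, lse

torch.manual_seed(0)
B, Sq, Sk, D = 1, 3, 8, 16
q = torch.randn(B, Sq, D)
k, v = torch.randn(B, Sk, D), torch.randn(B, Sk, D)

# Attend to two key blocks separately.
out1, lse1 = partial_attn(q, k[:, :5], v[:, :5])
out2, lse2 = partial_attn(q, k[:, 5:], v[:, 5:])

# Merge with the same online-softmax update used in the diff.
new_lse = torch.logaddexp(lse1, lse2)
merged = out1 * torch.exp(lse1 - new_lse).unsqueeze(-1) + \
         out2 * torch.exp(lse2 - new_lse).unsqueeze(-1)

# The merged result matches attention over all keys at once.
full, _ = partial_attn(q, k, v)
assert torch.allclose(merged, full, atol=1e-6)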