 import torch.nn.functional as F
 from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-from transformers import LlamaConfig
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from yunchang import EXTRACT_FUNC_DICT
 from yunchang.comm import SeqAllToAll4D

 from specforge.modeling.draft.flex_attention import (
     generate_eagle3_mask,
 )
 from specforge.utils import print_with_rank
-from .base import Eagle3DraftModel
+
 from ...distributed import get_sp_ring_group, get_sp_ulysses_group
 from ...layers.ring import ring_flash_attn_func
-
+from .base import Eagle3DraftModel

 try:
     from flash_attn import flash_attn_func
@@ -973,6 +971,7 @@ class LlamaUSPFlashAttention(LlamaAttention):
     """
     LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
     """
+
     def __init__(self, config):
         super().__init__(config)
         assert (
@@ -1008,19 +1007,35 @@ def forward(
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         query_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, query_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            query_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         key_states = self.k_proj(hidden_states)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         key_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, key_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            key_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         value_states = self.v_proj(hidden_states)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
         value_states = SeqAllToAll4D.apply(
-            self.ulysses_pg, value_states, self.scatter_idx, self.gather_idx, self.use_sync
+            self.ulysses_pg,
+            value_states,
+            self.scatter_idx,
+            self.gather_idx,
+            self.use_sync,
         )

         current_q_len = query_states.shape[1]
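
# --- Editor's sketch (not part of the commit) ------------------------------
# The hunk above reshapes q/k/v to [bsz, q_len, heads, head_dim] and routes each
# through yunchang's SeqAllToAll4D. Below is a minimal, single-process sketch of
# the shape contract, assuming the usual Ulysses setting (scatter_idx=2 for heads,
# gather_idx=1 for sequence); the real op performs a torch.distributed all-to-all
# over `ulysses_pg`, which is only simulated here with slicing and concatenation.
import torch


def simulated_seq_all_to_all(x_per_rank, sp_degree):
    """List of [B, S/P, H, D] shards -> list of [B, S, H/P, D] shards."""
    bsz, s_local, num_heads, head_dim = x_per_rank[0].shape
    assert num_heads % sp_degree == 0
    h_local = num_heads // sp_degree
    out = []
    for dst in range(sp_degree):
        # destination rank `dst` receives its slice of heads from every source
        # rank, concatenated along the sequence dimension
        pieces = [
            x_per_rank[src][:, :, dst * h_local : (dst + 1) * h_local, :]
            for src in range(sp_degree)
        ]
        out.append(torch.cat(pieces, dim=1))
    return out


if __name__ == "__main__":
    P, B, S, H, D = 4, 2, 32, 8, 16
    full = torch.randn(B, S, H, D)
    shards = list(full.chunk(P, dim=1))              # sequence-sharded input
    regrouped = simulated_seq_all_to_all(shards, P)  # full sequence, head-sharded
    assert torch.equal(regrouped[0], full[:, :, : H // P, :])
# ----------------------------------------------------------------------------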
@@ -1034,17 +1049,26 @@ def forward(
         # =============================================================
         if self.sp_ring_degree > 1:
             if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
+                    self.ring_rank
+                ].clone()
             else:
-                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[self.ring_rank].clone()
+                position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
+                    self.ring_rank
+                ].clone()

         lck = 0 if cache_hidden is None else len(cache_hidden[0])

         if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
             cos, sin = self.rotary_emb(query_states, position_ids + lck)
             cos, sin = cos.to(query_states.device), sin.to(query_states.device)
             query_states, key_states = apply_multimodal_rotary_pos_emb(
-                query_states, key_states, cos, sin, self.config.rope_scaling["mrope_section"], unsqueeze_dim=2
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
             )
         else:
             cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck)
@@ -1087,8 +1111,9 @@ def forward(
         else:
             acc_lse = lse_ring

-        assert acc_lse.shape[1] == current_q_len, \
-            f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"
+        assert (
+            acc_lse.shape[1] == current_q_len
+        ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}"

         acc_out = out_ring

@@ -1097,7 +1122,13 @@ def forward(
             num_kv_heads_local = cache_k[0].shape[2]
             local_groups = local_num_heads // num_kv_heads_local

-            q_shape_expanded = (bsz, current_q_len, num_kv_heads_local, local_groups, self.head_dim)
+            q_shape_expanded = (
+                bsz,
+                current_q_len,
+                num_kv_heads_local,
+                local_groups,
+                self.head_dim,
+            )
             qi_reshaped = query_states.view(q_shape_expanded)  # [B, S, KV, G, D]

             for i in range(1, len(cache_k)):
@@ -1118,8 +1149,9 @@ def forward(
                 # Online Softmax Update
                 new_lse = torch.logaddexp(acc_lse, step_lse)

-                acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(-1) + \
-                    step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)
+                acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze(
+                    -1
+                ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1)

                 acc_lse = new_lse

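
# --- Editor's sketch (not part of the commit) ------------------------------
# The final hunk merges the ring-attention result with each cached-KV branch via
# an online-softmax (log-sum-exp) update, i.e. the acc_out/acc_lse recurrence
# shown above. The self-contained check below illustrates why that recurrence is
# exact: two unmasked attention branches over disjoint key/value chunks, merged
# with logaddexp, reproduce attention over the full key/value sequence. Shapes
# and helper names here are illustrative, not the repository's API.
import torch


def attn_with_lse(q, k, v):
    """Naive attention returning the output and the log-sum-exp of the scores.
    q: [B, H, S, D]; k, v: [B, H, S_kv, D]."""
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)  # [B, H, S]
    out = torch.einsum("bhqk,bhkd->bhqd", scores.softmax(-1), v)
    return out, lse


if __name__ == "__main__":
    B, H, S, D = 1, 4, 8, 16
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, 2 * S, D)
    v = torch.randn(B, H, 2 * S, D)

    ref, _ = attn_with_lse(q, k, v)  # attention over the full KV sequence

    # two branches computed independently, e.g. ring output plus one cached step
    out_a, lse_a = attn_with_lse(q, k[:, :, :S], v[:, :, :S])
    out_b, lse_b = attn_with_lse(q, k[:, :, S:], v[:, :, S:])

    # online-softmax update, mirroring the acc_out/acc_lse recurrence in the diff
    new_lse = torch.logaddexp(lse_a, lse_b)
    merged = out_a * torch.exp(lse_a - new_lse).unsqueeze(-1) + out_b * torch.exp(
        lse_b - new_lse
    ).unsqueeze(-1)

    assert torch.allclose(merged, ref, atol=1e-5)
# ----------------------------------------------------------------------------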