@@ -882,8 +882,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
882
882
k_states = k_states [:, :, - self .sliding_window :, :]
883
883
v_states = v_states [:, :, - self .sliding_window :, :]
884
884
885
- if self .layer_idx in attn_params .block_diag_layers :
885
+ if self .layer_idx in attn_params .block_diag_layers or causal :
886
886
attn_mask_lr = attn_params .get_block_diag_mask (q_states .device )
887
+ elif not causal :
888
+ attn_mask_lr = None
887
889
elif attn_params .is_causal ():
888
890
attn_mask_lr = causal_lower_right (q_len , k_states .shape [2 ])
889
891
else :
@@ -892,7 +894,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
892
894
q_states ,
893
895
k_states ,
894
896
v_states ,
895
- attn_mask_lr if causal else None ,
897
+ attn_mask_lr ,
896
898
scale = self .scaling
897
899
)
898
900
@@ -910,10 +912,12 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
910
912
attn_mask = attn_params .get_block_diag_mask (attn_weights .device )
911
913
elif causal :
912
914
attn_mask = attn_params .get_attn_mask (attn_weights .device )
915
+ else :
916
+ attn_mask = None
913
917
914
918
if cfg .attn_logit_softcapping :
915
919
ext_c .softcap_ (attn_weights , cfg .attn_logit_softcapping )
916
- if causal and attn_mask is not None :
920
+ if attn_mask is not None :
917
921
attn_weights = attn_weights + attn_mask
918
922
if self .sliding_window and k_states .shape [- 1 ] >= self .sliding_window :
919
923
attn_weights = attn_weights [:, :, :, - self .sliding_window :]
@@ -1109,6 +1113,12 @@ def forward(
1109
1113
offset = attn_params .rope_offsets .cpu ().item ()
1110
1114
pass_past_len_1 += offset
1111
1115
1116
+ sc = attn_params .get_alt_rope_embed (self .device_idx )
1117
+ if not sc :
1118
+ sin , cos = constants .sin , constants .cos
1119
+ else :
1120
+ sin , cos = sc
1121
+
1112
1122
ext_c .q_attn_forward_1 (
1113
1123
self .q_handle ,
1114
1124
hidden_states ,
@@ -1119,8 +1129,8 @@ def forward(
1119
1129
q_states ,
1120
1130
k_states ,
1121
1131
v_states ,
1122
- constants . sin ,
1123
- constants . cos ,
1132
+ sin ,
1133
+ cos ,
1124
1134
pass_loras ,
1125
1135
pass_lora_temp
1126
1136
)
0 commit comments