
Commit 1df7b04

Allow non-causal attn with SDPA
1 parent 0d9adf9

File tree

1 file changed: +5 -2

exllamav2/attn.py

Lines changed: 5 additions & 2 deletions
@@ -840,7 +840,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
 
         # SDPA
 
-        if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping:
+        if has_lower_right_sdpa and not cfg.no_sdpa and not cfg.attn_logit_softcapping:
 
             k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
             v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
@@ -849,7 +849,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
                 k_states = k_states[:, :, -self.sliding_window:, :]
                 v_states = v_states[:, :, -self.sliding_window:, :]
 
-            attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
+            if attn_params.is_causal():
+                attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
+            else:
+                attn_mask_lr = attn_params.get_attn_mask(q_states.device)
             attn_output = F.scaled_dot_product_attention(
                 q_states,
                 k_states,
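For reference, a minimal standalone sketch (not exllamav2's code) of the two mask paths the patch now selects between, assuming PyTorch >= 2.2 for torch.nn.attention.bias.causal_lower_right. The tensor shapes and the padding-style mask below are invented for illustration, standing in for whatever attn_params.get_attn_mask() would actually return.

import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right  # PyTorch >= 2.2

# Illustrative shapes only.
bsz, heads, q_len, kv_len, head_dim = 1, 4, 3, 8, 16
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)

# Causal path: a lower-right-aligned causal bias, which masks correctly
# even when q_len < kv_len (queries attend over a cached prefix).
causal_out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=causal_lower_right(q_len, kv_len))

# Non-causal path: an explicit additive float mask, broadcastable to
# (bsz, heads, q_len, kv_len); 0.0 = attend, -inf = blocked. Here the
# last two cache positions are blocked, as a stand-in for the mask
# that attn_params.get_attn_mask() supplies in exllamav2.
mask = torch.zeros(bsz, 1, q_len, kv_len)
mask[..., -2:] = float("-inf")
noncausal_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)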
