Commit 2e630ae

Fix alt pos embeddings and block diagonal mask when flash-attn is disabled
1 parent 6e4a84a commit 2e630ae

2 files changed: +16 −6 lines

exllamav2/attn.py

Lines changed: 15 additions & 5 deletions
@@ -882,8 +882,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
             k_states = k_states[:, :, -self.sliding_window:, :]
             v_states = v_states[:, :, -self.sliding_window:, :]

-        if self.layer_idx in attn_params.block_diag_layers:
+        if self.layer_idx in attn_params.block_diag_layers or causal:
             attn_mask_lr = attn_params.get_block_diag_mask(q_states.device)
+        elif not causal:
+            attn_mask_lr = None
         elif attn_params.is_causal():
             attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
         else:
@@ -892,7 +894,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
             q_states,
             k_states,
             v_states,
-            attn_mask_lr if causal else None,
+            attn_mask_lr,
             scale = self.scaling
         )

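With the two hunks above, F.scaled_dot_product_attention now receives whatever mask the branch selected (block-diagonal, lower-right causal, or None) instead of dropping it at the call site whenever `causal` is false. A minimal standalone sketch of that pattern, not the library code, assuming a PyTorch version (2.1+) whose SDPA accepts the `scale` keyword:

    import torch
    import torch.nn.functional as F

    def sdpa_with_optional_mask(q, k, v, attn_mask = None, scale = None):
        # attn_mask may be None, a boolean mask, or an additive float mask;
        # F.scaled_dot_product_attention accepts any of these directly.
        return F.scaled_dot_product_attention(q, k, v, attn_mask, scale = scale)

    # (batch, heads, seq, head_dim)
    q = torch.randn(1, 4, 8, 16)
    k = torch.randn(1, 4, 8, 16)
    v = torch.randn(1, 4, 8, 16)

    # Additive block-diagonal mask for two packed sequences of length 4 each
    labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
    block_mask = torch.where(labels[None, :] == labels[:, None], 0.0, float("-inf"))

    out_block = sdpa_with_optional_mask(q, k, v, block_mask)   # packed-sequence masking
    out_plain = sdpa_with_optional_mask(q, k, v, None)         # no mask at all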
@@ -910,10 +912,12 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
             attn_mask = attn_params.get_block_diag_mask(attn_weights.device)
         elif causal:
             attn_mask = attn_params.get_attn_mask(attn_weights.device)
+        else:
+            attn_mask = None

         if cfg.attn_logit_softcapping:
             ext_c.softcap_(attn_weights, cfg.attn_logit_softcapping)
-        if causal and attn_mask is not None:
+        if attn_mask is not None:
             attn_weights = attn_weights + attn_mask
         if self.sliding_window and k_states.shape[-1] >= self.sliding_window:
             attn_weights = attn_weights[:, :, :, -self.sliding_window:]
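The same idea applies to the explicit matmul fallback above: the additive mask is added whenever one exists, and the new else branch makes the no-mask case explicit. A small illustrative sketch; `apply_optional_mask` and the example tensors are assumptions, not exllamav2 code:

    import torch

    def apply_optional_mask(attn_weights, attn_mask = None, sliding_window = 0):
        # The mask is applied whenever one was selected; None now simply skips it,
        # where previously the causal flag gated the addition as well.
        if attn_mask is not None:
            attn_weights = attn_weights + attn_mask
        if sliding_window and attn_weights.shape[-1] >= sliding_window:
            attn_weights = attn_weights[:, :, :, -sliding_window:]
        return attn_weights

    scores = torch.randn(1, 4, 8, 8).half()                    # (batch, heads, q_len, k_len)
    causal_bias = torch.full((8, 8), -65504.0).triu(1).half()  # additive mask, lowest normal fp16
    print(apply_optional_mask(scores, causal_bias).shape)      # torch.Size([1, 4, 8, 8])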
@@ -1109,6 +1113,12 @@ def forward(
             offset = attn_params.rope_offsets.cpu().item()
             pass_past_len_1 += offset

+        sc = attn_params.get_alt_rope_embed(self.device_idx)
+        if not sc:
+            sin, cos = constants.sin, constants.cos
+        else:
+            sin, cos = sc
+
         ext_c.q_attn_forward_1(
             self.q_handle,
             hidden_states,
@@ -1119,8 +1129,8 @@ def forward(
             q_states,
             k_states,
             v_states,
-            constants.sin,
-            constants.cos,
+            sin,
+            cos,
             pass_loras,
             pass_lora_temp
         )
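These last two hunks make the non-flash-attn path pick up alternate positional (RoPE) embeddings when attn_params provides them, falling back to the module's precomputed constants.sin / constants.cos otherwise. A self-contained sketch of that fallback; `_Constants`, `_DummyParams`, and `select_rope_tables` are illustrative stand-ins, while `get_alt_rope_embed` is the accessor used in the diff:

    import torch

    class _Constants:
        # stand-ins for the module's precomputed RoPE tables
        sin = torch.zeros(1, 1, 8, 16)
        cos = torch.ones(1, 1, 8, 16)

    class _DummyParams:
        def __init__(self, alt = None):
            self.alt = alt or {}
        def get_alt_rope_embed(self, device_idx):
            # empty when no alternate embeddings exist for this device,
            # otherwise a (sin, cos) pair, mirroring the "if not sc" test in the diff
            return self.alt.get(device_idx, ())

    def select_rope_tables(attn_params, device_idx, constants):
        sc = attn_params.get_alt_rope_embed(device_idx)
        if not sc:
            return constants.sin, constants.cos
        return sc

    sin, cos = select_rope_tables(_DummyParams(), 0, _Constants)          # falls back to constants
    alt = (torch.zeros(1, 1, 8, 16), torch.ones(1, 1, 8, 16))
    sin, cos = select_rope_tables(_DummyParams({0: alt}), 0, _Constants)  # uses the alternate tables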

exllamav2/attn_params.py

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ def get_block_diag_mask(self, device: int) -> torch.Tensor | None:
             return None
         positions = torch.arange(csl[-1], device = csl.device)
         labels = torch.searchsorted(csl[1:], positions, right = True)
-        self.block_diag_mask = labels.unsqueeze(0) == labels.unsqueeze(1).repeat(self.batch_size)
+        self.block_diag_mask = torch.where(labels.unsqueeze(0) == labels.unsqueeze(1).repeat(1, self.batch_size), 0, -65504.0).half()
         if self.block_diag_mask.device.index != device:
             self.block_diag_mask = safe_move_tensor(self.block_diag_mask, device, non_blocking = True)
         return self.block_diag_mask
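This change turns the block-diagonal mask from a boolean equality matrix into an additive float16 mask (0 where attention is allowed, -65504 where it is not), which is the form the torch attention path above adds to the logits. A standalone sketch of the construction for a single packed batch, omitting the repeat over batch_size from the changed line:

    import torch

    def block_diag_mask(cumulative_seqlens: torch.Tensor) -> torch.Tensor:
        # cumulative_seqlens = [0, len0, len0 + len1, ...]
        positions = torch.arange(int(cumulative_seqlens[-1]))
        # label each position with the index of the packed sequence it belongs to
        labels = torch.searchsorted(cumulative_seqlens[1:], positions, right = True)
        same_seq = labels.unsqueeze(0) == labels.unsqueeze(1)
        # 0 where attention is allowed, -65504 (lowest normal fp16) where it is not
        return torch.where(same_seq, 0.0, -65504.0).half()

    csl = torch.tensor([0, 3, 5])   # two packed sequences of lengths 3 and 2
    print(block_diag_mask(csl))
    # (5, 5) tensor: zeros inside the 3x3 and 2x2 diagonal blocks, -65504 elsewhere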
