
Commit ea27954

TP: Add (slower) SDPA fallback mode when flash-attn is unavailable

1 parent: c050aec

File tree

1 file changed: +54 -21 lines

exllamav2/attn.py

Lines changed: 54 additions & 21 deletions
@@ -831,6 +831,9 @@ def forward_paged_tp_old(

     def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):

+        num_attn_heads = q_states.shape[2]
+        head_dim = q_states.shape[3]
+
         q_states = q_states.transpose(1, 2)
         k_states = k_states.transpose(1, 2)
         v_states = v_states.transpose(1, 2)
@@ -881,7 +884,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
             attn_output = torch.matmul(attn_weights, v_states)

         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
+        attn_output = attn_output.reshape((batch_size, q_len, num_attn_heads * head_dim))
         return attn_output

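The two hunks above make _attn_torch size its output from the tensors it is actually given instead of from the model config. The reason matters for tensor parallelism: each device holds only a slice of the attention heads, so cfg.num_attention_heads * cfg.head_dim describes the full model, not the per-device slice. A minimal sketch of the shape flow (standalone tensors and made-up sizes, not exllamav2 code):

# Why the reshape must use the local head count: in TP mode this device only
# sees a slice of the heads, so the sizes are read off q_states itself.
import torch

batch_size, q_len = 2, 16
total_heads, head_dim = 32, 128        # full model (what cfg would report)
local_heads = total_heads // 4         # this device's slice of the heads

q_states = torch.randn(batch_size, q_len, local_heads, head_dim)

num_attn_heads = q_states.shape[2]     # 8, not 32
hd = q_states.shape[3]

attn_output = q_states.transpose(1, 2)  # (bsz, heads, q_len, head_dim)
# ... attention over k/v would happen here ...
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape((batch_size, q_len, num_attn_heads * hd))
assert attn_output.shape == (2, 16, 8 * 128)  # reshaping with total_heads would fail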
@@ -955,8 +958,10 @@ def forward(self,
                 loras: list[ExLlamaV2Lora] | None = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:

+        cfg = self.model.config
         global has_flash_attn
         global has_xformers
+        use_flash_attn = has_flash_attn and not cfg.no_flash_attn

         if isinstance(attn_params, ExLlamaV2Attention.PagedParams):
             return self.forward_paged(
@@ -968,7 +973,7 @@ def forward(self,
             )

         if self.is_tp:
-            if cache is not None:
+            if cache is not None and use_flash_attn:
                 return self.forward_tp(
                     hidden_states,
                     cache,
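Taken together with the previous hunk, the effect on routing is (as read from the surrounding diff, so treat the exact control flow as an assumption): with tensor parallelism enabled, the newer forward_tp path is only used when a cache is present and flash-attn is usable; otherwise the call falls through to forward_tp_old, which now carries the SDPA fallback. A toy sketch of that decision:

# Illustrative routing only; names and structure are simplified, not the
# verbatim exllamav2 control flow.
def route_tp_forward(is_tp: bool, cache, use_flash_attn: bool) -> str:
    if not is_tp:
        return "regular forward"
    if cache is not None and use_flash_attn:
        return "forward_tp"        # flash-attn fast path
    return "forward_tp_old"        # the path that now has the SDPA fallback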
@@ -1002,7 +1007,6 @@ def forward(self,
                     **kwargs
                 )

-        cfg = self.model.config
         constants = self.model.get_device_context(self.device_idx)

         batch_size, q_len, _ = hidden_states.shape
@@ -1193,7 +1197,10 @@ def forward_tp_old(

         assert self.q_handle is not None
         use_flash_attn = has_flash_attn and not cfg.no_flash_attn
-        assert use_flash_attn, "Tensor parallel inference requires flash-attn"
+        if not use_flash_attn:
+            assert has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping, \
+                "TP attention without flash-attn must use Torch SDPA with lower-right attention mask " \
+                "(use PyTorch 2.4.0+) and does not support logit softcapping."

         hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = cfg.head_dim)

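The relaxed assertion above ties the fallback to Torch SDPA with a lower-right causal mask. The alignment matters because during incremental decoding q_len is smaller than the key/value length, and the causal mask has to be anchored to the end of the key sequence so each new token can attend to the whole cached prefix plus itself. A hedged illustration using torch.nn.attention.bias.causal_lower_right (an assumption about how such a mask can be built in recent PyTorch, not necessarily the exact call attn.py makes):

# Lower-right causal masking for the decode case where q_len < kv_len.
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right  # recent PyTorch; the assert above points at 2.4.0+

bsz, heads, head_dim = 1, 8, 64
q_len, kv_len = 1, 17                  # one new token, 16 cached + itself

q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)

# Query row i is aligned to key position kv_len - q_len + i, so the new token
# sees every cached position. Plain is_causal=True uses upper-left alignment,
# which for q_len < kv_len would hide most of the cache.
mask = causal_lower_right(q_len, kv_len)
out = F.scaled_dot_product_attention(q, k, v, mask)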
@@ -1236,24 +1243,50 @@ def forward_tp_old(
             torch.cuda.set_stream(context.stream)

             if k_cache is not None:
-                attn_output = flash_attn_with_kvcache(
-                    q = q[idx],
-                    k = k[idx],
-                    v = v[idx],
-                    k_cache = k_cache[idx],
-                    v_cache = v_cache[idx],
-                    causal = True,
-                    softmax_scale = self.scaling,
-                    cache_seqlens = attn_params.past_len_tp[idx]
-                )
+                if use_flash_attn:
+                    attn_output = flash_attn_with_kvcache(
+                        q = q[idx],
+                        k = k[idx],
+                        v = v[idx],
+                        k_cache = k_cache[idx],
+                        v_cache = v_cache[idx],
+                        causal = True,
+                        softmax_scale = self.scaling,
+                        cache_seqlens = attn_params.past_len_tp[idx]
+                    )
+                else:
+                    cache_a = attn_params.past_len
+                    cache_b = attn_params.past_len + q_len
+                    k_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(k[idx])
+                    v_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(v[idx])
+                    attn_output = self._attn_torch(
+                        batch_size,
+                        q_len,
+                        q[idx],
+                        k_cache[idx][:batch_size, :cache_b, :, :],
+                        v_cache[idx][:batch_size, :cache_b, :, :],
+                        attn_params,
+                        cfg
+                    )
             else:
-                attn_output = flash_attn_func(
-                    q[idx],
-                    k[idx],
-                    v[idx],
-                    causal = True,
-                    softmax_scale=self.scaling,
-                )
+                if use_flash_attn:
+                    attn_output = flash_attn_func(
+                        q[idx],
+                        k[idx],
+                        v[idx],
+                        causal = True,
+                        softmax_scale = self.scaling,
+                    )
+                else:
+                    attn_output = self._attn_torch(
+                        batch_size,
+                        q_len,
+                        q[idx],
+                        k[idx],
+                        v[idx],
+                        attn_params,
+                        cfg
+                    )

             attn_output = attn_output.view(batch_size * q_len, (b - a) * cfg.head_dim * cfg.num_key_value_groups)
             attn_outputs.append(attn_output)
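In the no-flash branch above, the new keys and values are copied into the preallocated cache at [past_len, past_len + q_len), and attention then runs over the whole prefix through _attn_torch. A standalone sketch of that pattern with made-up shapes (illustrative only, not lifted from attn.py):

# Emulate flash_attn_with_kvcache with an in-place cache write plus SDPA.
import torch
import torch.nn.functional as F

bsz, q_len, past_len = 1, 1, 16
heads, head_dim, max_seq_len = 8, 64, 2048

k_cache = torch.zeros(bsz, max_seq_len, heads, head_dim)
v_cache = torch.zeros(bsz, max_seq_len, heads, head_dim)

q = torch.randn(bsz, q_len, heads, head_dim)
k = torch.randn(bsz, q_len, heads, head_dim)
v = torch.randn(bsz, q_len, heads, head_dim)

cache_a, cache_b = past_len, past_len + q_len
k_cache[:bsz, cache_a:cache_b, :, :].copy_(k)   # append new keys in place
v_cache[:bsz, cache_a:cache_b, :, :].copy_(v)   # append new values in place

# Attend over the full prefix [0, cache_b); SDPA expects (bsz, heads, seq, dim).
q_t = q.transpose(1, 2)
k_t = k_cache[:bsz, :cache_b].transpose(1, 2)
v_t = v_cache[:bsz, :cache_b].transpose(1, 2)
out = F.scaled_dot_product_attention(q_t, k_t, v_t)  # no mask needed: q_len == 1
out = out.transpose(1, 2).reshape(bsz, q_len, heads * head_dim)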
