@@ -831,6 +831,9 @@ def forward_paged_tp_old(
 
     def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
 
+        num_attn_heads = q_states.shape[2]
+        head_dim = q_states.shape[3]
+
         q_states = q_states.transpose(1, 2)
         k_states = k_states.transpose(1, 2)
         v_states = v_states.transpose(1, 2)
@@ -881,7 +884,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
         attn_output = torch.matmul(attn_weights, v_states)
 
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
+        attn_output = attn_output.reshape((batch_size, q_len, num_attn_heads * head_dim))
         return attn_output
 
 
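
Note: with this change `_attn_torch` reads the head count and head dimension from the query tensor itself instead of `cfg.num_attention_heads` / `cfg.head_dim`, so the same routine also works on the per-device head slices fed to it by the tensor-parallel fallback further down. A minimal standalone sketch of that shape-driven eager attention, assuming q/k/v share one head count and a lower-right aligned causal mask (the method in this file handles more cases than this sketch does):

```python
import torch

def eager_attention_sketch(q, k, v, scaling):
    # q, k, v: (batch, seq, heads, head_dim); heads and head_dim come from the
    # tensors themselves, so a per-device slice of the heads works unchanged.
    batch_size, q_len, num_heads, head_dim = q.shape
    kv_len = k.shape[1]
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))                # (batch, heads, seq, head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2)) * scaling   # (batch, heads, q_len, kv_len)
    # Lower-right aligned causal mask: the last query row sees the whole key length.
    mask = torch.triu(
        torch.full((q_len, kv_len), float("-inf"), dtype = q.dtype, device = q.device),
        diagonal = kv_len - q_len + 1,
    )
    attn_weights = torch.softmax(attn_weights + mask, dim = -1)
    out = torch.matmul(attn_weights, v).transpose(1, 2)             # (batch, q_len, heads, head_dim)
    return out.reshape(batch_size, q_len, num_heads * head_dim)
```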
@@ -955,8 +958,10 @@ def forward(self,
                 loras: list[ExLlamaV2Lora] | None = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
 
+        cfg = self.model.config
         global has_flash_attn
         global has_xformers
+        use_flash_attn = has_flash_attn and not cfg.no_flash_attn
 
         if isinstance(attn_params, ExLlamaV2Attention.PagedParams):
             return self.forward_paged(
@@ -968,7 +973,7 @@ def forward(self,
             )
 
         if self.is_tp:
-            if cache is not None:
+            if cache is not None and use_flash_attn:
                 return self.forward_tp(
                     hidden_states,
                     cache,
@@ -1002,7 +1007,6 @@ def forward(self,
                 **kwargs
             )
 
-        cfg = self.model.config
         constants = self.model.get_device_context(self.device_idx)
 
         batch_size, q_len, _ = hidden_states.shape
@@ -1193,7 +1197,10 @@ def forward_tp_old(
 
         assert self.q_handle is not None
         use_flash_attn = has_flash_attn and not cfg.no_flash_attn
-        assert use_flash_attn, "Tensor parallel inference requires flash-attn"
+        if not use_flash_attn:
+            assert has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping, \
+                "TP attention without flash-attn must use Torch SDPA with lower-right attention mask " \
+                "(use PyTorch 2.4.0+) and does not support logit softcapping."
 
         hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = cfg.head_dim)
 
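
Note: the relaxed assertion spells out what the TP path needs when flash-attn is unavailable: Torch SDPA with a lower-right aligned causal mask (so a short query block lines up with the end of the cached key sequence) and no logit softcapping. Presumably `has_lower_right_sdpa` reflects whether that mask helper is importable; a rough sketch of the mask in use, with illustrative shapes only:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right  # recent PyTorch; the assert message asks for 2.4.0+

# Illustrative shapes: 8 heads of dim 64, 4 cached positions plus 2 new query positions.
q = torch.randn(1, 8, 2, 64)      # (batch, heads, q_len, head_dim)
k = torch.randn(1, 8, 6, 64)      # (batch, heads, past_len + q_len, head_dim)
v = torch.randn(1, 8, 6, 64)

# Lower-right alignment lines the causal diagonal up with the *end* of the key
# sequence, which is what attention over a KV cache needs.
bias = causal_lower_right(q.shape[2], k.shape[2])
out = F.scaled_dot_product_attention(q, k, v, attn_mask = bias)
```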
@@ -1236,24 +1243,50 @@ def forward_tp_old(
             torch.cuda.set_stream(context.stream)
 
             if k_cache is not None:
-                attn_output = flash_attn_with_kvcache(
-                    q = q[idx],
-                    k = k[idx],
-                    v = v[idx],
-                    k_cache = k_cache[idx],
-                    v_cache = v_cache[idx],
-                    causal = True,
-                    softmax_scale = self.scaling,
-                    cache_seqlens = attn_params.past_len_tp[idx]
-                )
+                if use_flash_attn:
+                    attn_output = flash_attn_with_kvcache(
+                        q = q[idx],
+                        k = k[idx],
+                        v = v[idx],
+                        k_cache = k_cache[idx],
+                        v_cache = v_cache[idx],
+                        causal = True,
+                        softmax_scale = self.scaling,
+                        cache_seqlens = attn_params.past_len_tp[idx]
+                    )
+                else:
+                    cache_a = attn_params.past_len
+                    cache_b = attn_params.past_len + q_len
+                    k_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(k[idx])
+                    v_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(v[idx])
+                    attn_output = self._attn_torch(
+                        batch_size,
+                        q_len,
+                        q[idx],
+                        k_cache[idx][:batch_size, :cache_b, :, :],
+                        v_cache[idx][:batch_size, :cache_b, :, :],
+                        attn_params,
+                        cfg
+                    )
             else:
-                attn_output = flash_attn_func(
-                    q[idx],
-                    k[idx],
-                    v[idx],
-                    causal = True,
-                    softmax_scale = self.scaling,
-                )
+                if use_flash_attn:
+                    attn_output = flash_attn_func(
+                        q[idx],
+                        k[idx],
+                        v[idx],
+                        causal = True,
+                        softmax_scale = self.scaling,
+                    )
+                else:
+                    attn_output = self._attn_torch(
+                        batch_size,
+                        q_len,
+                        q[idx],
+                        k[idx],
+                        v[idx],
+                        attn_params,
+                        cfg
+                    )
 
             attn_output = attn_output.view(batch_size * q_len, (b - a) * cfg.head_dim * cfg.num_key_value_groups)
             attn_outputs.append(attn_output)
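
Note: in the non-flash branch with a cache, this step's keys and values are first copied into the per-device cache window `[cache_a:cache_b]`, and `_attn_torch` then attends over the whole prefix up to `cache_b`. A small sketch of that in-place window update on plain tensors (sizes and names below are made up for illustration; `v_cache` is updated the same way in the diff):

```python
import torch

# Hypothetical sizes: batch of 1, cache with room for 16 positions, 2 KV heads of dim 64.
batch_size, max_seq_len, kv_heads, head_dim = 1, 16, 2, 64
past_len, q_len = 5, 3

k_cache = torch.zeros(batch_size, max_seq_len, kv_heads, head_dim)
k_new = torch.randn(batch_size, q_len, kv_heads, head_dim)

cache_a, cache_b = past_len, past_len + q_len
k_cache[:batch_size, cache_a:cache_b, :, :].copy_(k_new)    # append this step's keys in place
k_states = k_cache[:batch_size, :cache_b, :, :]              # full prefix handed to attention
assert k_states.shape == (batch_size, past_len + q_len, kv_heads, head_dim)
```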