@@ -1211,13 +1211,18 @@ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q,
1211
1211
k , v , return_softmax_lse ):
1212
1212
assert isinstance (prefill , FlashInferPrefillMetadata )
1213
1213
assert prefill .prefill_main is not None
1214
- return prefill .prefill_main .run (
1214
+ ret = prefill .prefill_main .run (
1215
1215
q = q ,
1216
1216
k = k ,
1217
1217
v = v ,
1218
1218
return_lse = return_softmax_lse ,
1219
1219
)
1220
1220
1221
+ if isinstance (ret , tuple ):
1222
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1223
+ return ret [0 ], ret [1 ].transpose (0 , 1 ).contiguous ()
1224
+ return ret
1225
+
1221
1226
def _run_prefill_new_tokens_cudnn (self , prefill : MLACommonPrefillMetadata ,
1222
1227
q , k , v , return_softmax_lse ):
1223
1228
assert isinstance (prefill , CudnnPrefillMetadata )
@@ -1260,12 +1265,14 @@ def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
1260
1265
def _run_prefill_context_chunk_fi (self , prefill : MLACommonPrefillMetadata ,
1261
1266
chunk_idx : int , q , k , v ):
1262
1267
assert isinstance (prefill , FlashInferPrefillMetadata )
1263
- return prefill .prefill_chunks [chunk_idx ].run (
1268
+ attn_out , lse = prefill .prefill_chunks [chunk_idx ].run (
1264
1269
q = q ,
1265
1270
k = k ,
1266
1271
v = v ,
1267
1272
return_lse = True ,
1268
1273
)
1274
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1275
+ return attn_out , lse .transpose (0 , 1 ).contiguous ()
1269
1276
1270
1277
def _run_prefill_context_chunk_cudnn (self ,
1271
1278
prefill : MLACommonPrefillMetadata ,
0 commit comments