Skip to content

Commit af473f0

Browse files
authored
[bugfix] Fix Llama3/4 issues caused by FlashInfer 0.2.10 (#22426)
Signed-off-by: Po-Han Huang <[email protected]>
1 parent 157f9c1 commit af473f0

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,22 @@
66

77

88
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Return the number of tokens per CTA tile for the FlashInfer MoE kernel.

    FlashInfer 0.2.10 has issues with larger tile sizes, so the value is
    pinned to 8 for now; the arguments are accepted but currently unused.

    TODO: Revert to the dynamic calculation below once a FlashInfer release
    ships the necessary kernels::

        from flashinfer import next_positive_power_of_2
        # Guess tokens per expert assuming perfect expert distribution first.
        num_tokens_per_expert = (num_tokens * top_k) // num_experts
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
        # Cap to 8-64 tokens per CTA tile as it's the range supported by the
        # kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    """
    return 8
1826

1927

vllm/v1/attention/backends/flashinfer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,8 @@ def build(self,
524524
head_dim = self.kv_cache_spec.head_size
525525

526526
# currently prefill trtllm attention does not support fp8 kv cache
527-
prefill_use_trtllm = use_trtllm_attention(
527+
prefill_use_trtllm = not cache_dtype.startswith("fp8") \
528+
and use_trtllm_attention(
528529
num_prefill_tokens, max_seq_len, cache_dtype,
529530
num_qo_heads, num_kv_heads, head_dim)
530531
decode_use_trtllm = use_trtllm_attention(

0 commit comments

Comments (0)