Commit 6b6e987

[NVIDIA] flashinfer TRTLLM attention prefill token limit (#25998)
Signed-off-by: jasonlizhengjian <[email protected]>
1 parent 9c3c21c commit 6b6e987

File tree: 1 file changed, +12 -5 lines

vllm/utils/flashinfer.py: 12 additions & 5 deletions
@@ -283,11 +283,18 @@ def use_trtllm_attention(
 
     if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
-        use_trtllm = (
-            num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
-        )
-        if use_trtllm:
-            logger.warning_once("Using TRTLLM attention (auto-detected).")
+        if is_prefill:
+            # Prefill auto-detection
+            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
+            if use_trtllm:
+                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
+        else:
+            # Decode auto-detection
+            use_trtllm = (
+                num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
+            )
+            if use_trtllm:
+                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
         return use_trtllm
 
     # Environment variable is set to 1 - respect it
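
In short, auto-detection now branches on is_prefill: prefill requests keep only the sequence-length and KV-cache-dtype checks, while decode requests retain the original num_tokens <= 256 limit. Below is a minimal, self-contained sketch of that branching. The function name auto_detect_trtllm and its reduced signature are illustrative assumptions for this sketch; the real use_trtllm_attention in vllm/utils/flashinfer.py takes additional parameters and also handles the force_use_trtllm environment-variable override shown in the surrounding code.

# Minimal sketch of the new auto-detection branching (illustrative only:
# names and signature are simplified, not the actual vLLM API).

def auto_detect_trtllm(
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    is_prefill: bool,
) -> bool:
    if is_prefill:
        # Prefill: no token limit, only seq-len and KV-cache dtype checks.
        return max_seq_len <= 131072 and kv_cache_dtype == "auto"
    # Decode: the original num_tokens <= 256 limit still applies.
    return num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"

# A 4096-token prefill batch now qualifies for TRTLLM attention, while the
# same token count on the decode path still fails the 256-token limit.
assert auto_detect_trtllm(4096, 131072, "auto", is_prefill=True)
assert not auto_detect_trtllm(4096, 131072, "auto", is_prefill=False)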
