Commit 2a54824

Reject requests with prompt logprobs when using fast prefill
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 72c0780 commit 2a54824

3 files changed (11 additions, 28 deletions)

vllm/v1/attention/backends/utils.py (6 additions, 19 deletions)

@@ -69,7 +69,6 @@ class CommonAttentionMetadata:
 
     logits_indices_padded: Optional[torch.Tensor] = None
     num_logits_indices: Optional[int] = None
-    prompt_logprobs: Optional[bool] = None
 
     causal: bool = True
 
@@ -837,25 +836,13 @@ def build(self,
                   common_prefix_len: int,
                   common_attn_metadata: CommonAttentionMetadata,
                   fast_build: bool = False) -> AttentionMetadata:
-            # Either not set (None) or prompt_logprobs is False
-            if not common_attn_metadata.prompt_logprobs:
-                # Fast prefill path
-                new_common_attn_metadata =\
-                    make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
-                metadata = super(self.__class__,
-                                 self).build(common_prefix_len,
-                                             new_common_attn_metadata, fast_build)
-                return create_kv_sharing_fast_prefill_attn_metadata_subclass(
-                    metadata, common_attn_metadata)
-
-            # Default path:
-            # Either --kv-sharing-fast-prefill is not set or at least one request
-            # in the current scheduling round requests logprobs for prompt tokens
-            # which is not compatible with fast prefill
+            new_common_attn_metadata =\
+                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
             metadata = super(self.__class__,
-                             self).build(common_prefix_len, common_attn_metadata,
-                                         fast_build)
-            return metadata
+                             self).build(common_prefix_len,
+                                         new_common_attn_metadata, fast_build)
+            return create_kv_sharing_fast_prefill_attn_metadata_subclass(
+                metadata, common_attn_metadata)
 
     # Dynamically create a new attention backend that wraps the
     # underlying attention backend but applies
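
Note on the simplified build path: with the per-batch prompt_logprobs branch removed, the wrapped builder always produces fast-prefill metadata, and compatibility with prompt logprobs is now enforced once at request admission rather than per scheduling step. A condensed sketch of the resulting control flow (the class and base names below are illustrative stand-ins, not the repository's actual names; the two helper functions are the ones referenced in the diff):

    # Sketch only: `UnderlyingBuilder` is a placeholder for whatever attention
    # metadata builder the dynamic backend factory wraps.
    class FastPrefillMetadataBuilder(UnderlyingBuilder):

        def build(self, common_prefix_len, common_attn_metadata,
                  fast_build=False):
            # Always narrow the metadata to the logits-producing tokens;
            # requests asking for prompt logprobs were rejected earlier.
            narrowed = make_kv_sharing_fast_prefill_common_attn_metadata(
                common_attn_metadata)
            metadata = super().build(common_prefix_len, narrowed, fast_build)
            return create_kv_sharing_fast_prefill_attn_metadata_subclass(
                metadata, common_attn_metadata)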

vllm/v1/engine/async_llm.py (5 additions, 0 deletions)

@@ -335,6 +335,11 @@ async def generate(
         returning the RequestOutput back to the caller.
         """
 
+        if (self.vllm_config.cache_config.kv_sharing_fast_prefill
+                and sampling_params.prompt_logprobs):
+            raise ValueError(
+                "Fast prefill produces incorrect logprobs for prompt tokens")
+
         try:
             # We start the output_handler on the first call to generate() so
             # we can call __init__ before the event loop, which enables us
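
From the caller's side, a request that asks for prompt logprobs while KV-sharing fast prefill is enabled now fails fast with the ValueError above instead of silently falling back for that iteration. A minimal sketch of hitting the check (the engine construction and exact flag plumbing are assumptions and may differ by version; the error message is the one added in this commit):

    # Sketch only: assumes `engine` is an already-constructed v1 AsyncLLM
    # running with kv_sharing_fast_prefill enabled in its cache config.
    from vllm import SamplingParams

    params = SamplingParams(max_tokens=16, prompt_logprobs=5)

    async def probe(engine, prompt: str) -> None:
        try:
            async for out in engine.generate(prompt, params, request_id="req-0"):
                pass
        except ValueError as err:
            # Expected: "Fast prefill produces incorrect logprobs for prompt tokens"
            print(f"request rejected: {err}")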

vllm/v1/worker/gpu_model_runner.py (0 additions, 9 deletions)

@@ -846,14 +846,6 @@ def _prepare_inputs(
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
         spec_decode_common_attn_metadata = None
 
-        if (self.cache_config.kv_sharing_fast_prefill
-                and self.input_batch.num_prompt_logprobs):
-            logger.warning(
-                "Encountered at least one request with prompt_logprobs set "
-                "with --kv-sharing-fast-prefill enabled. Fast prefill doesn't "
-                "produce correct logits for prompt tokens, so fast prefill will"
-                " be disabled for this iteration.")
-
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -901,7 +893,6 @@ def _prepare_inputs(
             slot_mapping=slot_mapping,
             logits_indices_padded=logits_indices_padded,
             num_logits_indices=logits_indices.size(0),
-            prompt_logprobs=len(self.input_batch.num_prompt_logprobs) > 0,
             causal=True,
         )