
Commit 5b8077b

gmarinho2, gemini-code-assist[bot], and maxdebayser authored
Fix wrong truncate_prompt_tokens type hint (#22761)
Signed-off-by: Gabriel Marinho <[email protected]>
Signed-off-by: Gabriel Marinho <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Max de Bayser <[email protected]>
1 parent 038e9be commit 5b8077b

14 files changed: +101, -102 lines

vllm/entrypoints/llm.py

Lines changed: 18 additions & 23 deletions
@@ -51,7 +51,7 @@
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, Device, is_list_of
+from vllm.utils import Counter, Device, as_iter, is_list_of
 from vllm.v1.sample.logits_processor import LogitsProcessor
 
 if TYPE_CHECKING:
@@ -364,14 +364,6 @@ def generate(
             # Use default sampling params.
             sampling_params = self.get_default_sampling_params()
 
-        tokenization_kwargs: dict[str, Any] = {}
-        truncate_prompt_tokens = None
-        if isinstance(sampling_params, SamplingParams):
-            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
-
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens, tokenization_kwargs)
-
         # Add any modality specific loras to the corresponding prompts
         lora_request = self._get_modality_specific_lora_reqs(
             prompts, lora_request)
@@ -381,7 +373,6 @@ def generate(
             params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
 
@@ -871,6 +862,8 @@ def encode(
                 If `False`, no progress bar is created.
             lora_request: LoRA request to use for generation, if any.
             pooling_task: Override the pooling task to use.
+            tokenization_kwargs: overrides tokenization_kwargs set in
+                pooling_params
 
         Returns:
             A list of `PoolingRequestOutput` objects containing the
@@ -916,24 +909,17 @@ def encode(
             # Use default pooling params.
             pooling_params = PoolingParams()
 
-        if isinstance(pooling_params, PoolingParams):
-            pooling_params.verify(pooling_task, model_config)
-        else:
-            for pooling_param in pooling_params:
-                pooling_param.verify(pooling_task, model_config)
-
-        if tokenization_kwargs is None:
-            tokenization_kwargs = dict[str, Any]()
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens,
-                                  tokenization_kwargs)
+        for param in as_iter(pooling_params):
+            param.verify(pooling_task, model_config)
+            # for backwards compatibility
+            if truncate_prompt_tokens is not None:
+                param.truncate_prompt_tokens = truncate_prompt_tokens
 
         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
         )
 
         outputs = self._run_engine(use_tqdm=use_tqdm)
@@ -1385,7 +1371,6 @@ def _validate_and_add_requests(
         *,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
         priority: Optional[list[int]] = None,
     ) -> None:
         if isinstance(prompts, (str, dict)):
@@ -1412,7 +1397,17 @@ def _validate_and_add_requests(
             tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
             it = tqdm_func(it, desc="Adding requests")
 
+        model_config = self.llm_engine.model_config
+
         for i, prompt in enumerate(it):
+            param = params[i] if isinstance(params, Sequence) else params
+
+            tokenization_kwargs: dict[str, Any] = {}
+            _validate_truncation_size(model_config.max_model_len,
+                                      param.truncate_prompt_tokens,
+                                      tokenization_kwargs)
+
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
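
The upshot of this file's changes: truncation now rides on the params object and is validated per request inside _validate_and_add_requests, rather than being threaded through a separate tokenization_kwargs argument. A minimal offline sketch of the resulting call pattern (not part of this diff; the model name and prompt are illustrative, and task="embed" is assumed only to get a pooling-capable LLM):

    from vllm import LLM, PoolingParams

    # Illustrative embedding model; any pooling-capable checkpoint works here.
    llm = LLM(model="intfloat/e5-small-v2", task="embed")

    # Truncation now lives on the params object: keep only the last 128 prompt tokens.
    params = PoolingParams(truncate_prompt_tokens=128)
    outputs = llm.encode(["a very long document ..."], pooling_params=params)

    # The legacy encode(..., truncate_prompt_tokens=...) argument is still accepted
    # and is copied onto the params, per the "for backwards compatibility" hunk above.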

vllm/entrypoints/openai/protocol.py

Lines changed: 19 additions & 9 deletions
@@ -452,7 +452,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     min_tokens: int = 0
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
     bad_words: list[str] = Field(default_factory=list)
@@ -995,7 +995,7 @@ class CompletionRequest(OpenAIBaseModel):
     min_tokens: int = 0
     skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
     allowed_token_ids: Optional[list[int]] = None
     prompt_logprobs: Optional[int] = None
     # --8<-- [end:completion-sampling-params]
@@ -1325,8 +1325,10 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     # --8<-- [end:embedding-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             normalize=self.normalize)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize)
 
 
 class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1393,8 +1395,10 @@ def check_generation_prompt(cls, data):
         return data
 
     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             normalize=self.normalize)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize)
 
 
 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@@ -1430,7 +1434,9 @@ class ScoreRequest(OpenAIBaseModel):
     # --8<-- [end:score-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class RerankRequest(OpenAIBaseModel):
@@ -1460,7 +1466,9 @@ class RerankRequest(OpenAIBaseModel):
     # --8<-- [end:rerank-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class RerankDocument(BaseModel):
@@ -1618,7 +1626,9 @@ class ClassificationRequest(OpenAIBaseModel):
     # --8<-- [end:classification-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class ClassificationData(OpenAIBaseModel):
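
The ge=1 to ge=-1 change on ChatCompletionRequest and CompletionRequest is the type-hint fix named in the commit title: -1 is vLLM's sentinel for "truncate to the maximum length the model supports", and the old lower bound rejected it at request validation. A hedged client-side sketch (assumes a running vllm serve instance and the openai Python client; the URL, API key, and model name are placeholders):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    resp = client.chat.completions.create(
        model="Qwen/Qwen2.5-1.5B-Instruct",  # whatever model the server was launched with
        messages=[{"role": "user", "content": "a very long prompt ..."}],
        # vLLM-specific request field; -1 asks for truncation to the model's max length.
        extra_body={"truncate_prompt_tokens": -1},
    )
    print(resp.choices[0].message.content)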

vllm/entrypoints/openai/serving_chat.py

Lines changed: 0 additions & 1 deletion
@@ -237,7 +237,6 @@ async def create_chat_completion(
                     documents=request.documents,
                     chat_template_kwargs=request.chat_template_kwargs,
                     tool_parser=tool_parser,
-                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
             else:

vllm/entrypoints/openai/serving_classification.py

Lines changed: 0 additions & 13 deletions
@@ -61,7 +61,6 @@ async def _preprocess(
                 ctx.request,
                 ctx.tokenizer,
                 ctx.request.input,
-                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
             )
 
             return None
@@ -157,18 +156,6 @@ async def create_classify(
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
-    def _validate_request(
-        self,
-        ctx: ClassificationServeContext,
-    ) -> Optional[ErrorResponse]:
-        if error := super()._validate_request(ctx):
-            return error
-
-        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
-
-        return None
-
     @override
     def _create_pooling_params(
         self,

vllm/entrypoints/openai/serving_completion.py

Lines changed: 0 additions & 1 deletion
@@ -137,7 +137,6 @@ async def create_completion(
                     request,
                     tokenizer,
                     request.prompt,
-                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
             except ValueError as e:

vllm/entrypoints/openai/serving_embedding.py

Lines changed: 0 additions & 14 deletions
@@ -97,7 +97,6 @@ async def _preprocess(
                 # so there is no need to append extra tokens to the input
                 add_generation_prompt=False,
                 continue_final_message=False,
-                truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                 add_special_tokens=ctx.request.add_special_tokens,
             )
         else:
@@ -106,7 +105,6 @@ async def _preprocess(
                 ctx.request,
                 tokenizer,
                 ctx.request.input,
-                truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                 add_special_tokens=ctx.request.add_special_tokens,
             )
         return None
@@ -631,18 +629,6 @@ async def create_embedding(
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
-    def _validate_request(
-        self,
-        ctx: ServeContext[EmbeddingRequest],
-    ) -> Optional[ErrorResponse]:
-        if error := super()._validate_request(ctx):
-            return error
-
-        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
-
-        return None
-
     @override
     def _create_pooling_params(
         self,
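
With truncation carried by to_pooling_params() (see the protocol.py hunks above), the per-endpoint _validate_request overrides that copied truncate_prompt_tokens onto the serving context become redundant, which is why they are deleted here and in serving_classification.py. A small sketch of the new flow (field values are placeholders; assumes a checkout that includes this commit):

    from vllm.entrypoints.openai.protocol import EmbeddingCompletionRequest

    req = EmbeddingCompletionRequest(
        model="intfloat/e5-small-v2",      # placeholder model name
        input="a very long document ...",
        truncate_prompt_tokens=128,
    )

    # The request-level setting now travels on the PoolingParams object itself.
    params = req.to_pooling_params()
    assert params.truncate_prompt_tokens == 128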
