Commit 119f006

[Renderer] Clean up renderer code (#26216)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent a42d2df commit 119f006

5 files changed: +94, -134 lines

tests/entrypoints/openai/test_token_in_token_out.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ async def test_token_in_token_out_and_logprobs(server):
         prompt=token_ids,
         max_tokens=20,
         temperature=0,
-        echo=True,
+        echo=False,
         extra_body={
             "return_token_ids": True,
         },
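
For context, a minimal sketch of the request this test now issues: a token-ID prompt with echo disabled while return_token_ids is still requested through extra_body. The client construction, server URL, and model name are illustrative assumptions; the actual test obtains its client from the server fixture rather than building one.

import asyncio

import openai


async def main() -> None:
    # Assumed local vLLM OpenAI-compatible server; the test uses its fixture instead.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    completion = await client.completions.create(
        model="my-model",      # assumed model name
        prompt=[1, 2, 3, 4],   # token-in: the prompt is a list of token IDs
        max_tokens=20,
        temperature=0,
        echo=False,            # switched from True in this commit
        extra_body={"return_token_ids": True},  # vLLM-specific extension
    )
    print(completion.choices[0].text)


asyncio.run(main())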

tests/test_inputs.py

Lines changed: 13 additions & 13 deletions

@@ -4,7 +4,7 @@
 import pytest
 
 from vllm.inputs import zip_enc_dec_prompts
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import parse_raw_prompts
 
 pytestmark = pytest.mark.cpu_test
 
@@ -31,30 +31,30 @@
 ]
 
 
-def test_parse_single_batch_empty():
+def test_parse_raw_single_batch_empty():
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([])
+        parse_raw_prompts([])
 
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([[]])
+        parse_raw_prompts([[]])
 
 
 @pytest.mark.parametrize('string_input', STRING_INPUTS)
-def test_parse_single_batch_string_consistent(string_input: str):
-    assert parse_and_batch_prompt(string_input) \
-        == parse_and_batch_prompt([string_input])
+def test_parse_raw_single_batch_string_consistent(string_input: str):
+    assert parse_raw_prompts(string_input) \
+        == parse_raw_prompts([string_input])
 
 
 @pytest.mark.parametrize('token_input', TOKEN_INPUTS)
-def test_parse_single_batch_token_consistent(token_input: list[int]):
-    assert parse_and_batch_prompt(token_input) \
-        == parse_and_batch_prompt([token_input])
+def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
+    assert parse_raw_prompts(token_input) \
+        == parse_raw_prompts([token_input])
 
 
 @pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
-def test_parse_single_batch_string_slice(inputs_slice: slice):
-    assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
-        == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
+def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
+    assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] \
+        == parse_raw_prompts(STRING_INPUTS[inputs_slice])
 
 
 # yapf: disable

vllm/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 2 deletions

@@ -691,6 +691,5 @@ def _build_render_config(
             truncate_prompt_tokens=request.truncate_prompt_tokens,
             add_special_tokens=request.add_special_tokens,
             cache_salt=request.cache_salt,
-            needs_detokenization=bool(request.echo
-                                      and not request.return_token_ids),
+            needs_detokenization=bool(request.echo),
         )

vllm/entrypoints/renderer.py

Lines changed: 72 additions & 64 deletions

@@ -13,8 +13,9 @@
 
 from vllm.config import ModelConfig
 from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import TextPrompt as EngineTextPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AsyncMicrobatchTokenizer
 
@@ -41,6 +42,27 @@ class RenderConfig:
     needs_detokenization: Optional[bool] = False
     """If True, detokenize IDs back to text for inclusion in outputs."""
 
+    def verify_truncate_prompt_tokens(
+            self, model_config: ModelConfig) -> Optional[int]:
+        """Validate and normalize `truncate_prompt_tokens` parameter."""
+        truncate_prompt_tokens = self.truncate_prompt_tokens
+        if truncate_prompt_tokens is None:
+            return None
+
+        if truncate_prompt_tokens == 0:
+            return 0
+
+        if truncate_prompt_tokens < 0:
+            truncate_prompt_tokens = model_config.max_model_len
+
+        max_length = self.max_length
+        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
+            raise ValueError(
+                f"{truncate_prompt_tokens=} cannot be greater than "
+                f"{max_length=}. Please select a smaller truncation size.")
+
+        return truncate_prompt_tokens
+
 
 class BaseRenderer(ABC):
     """
@@ -74,7 +96,7 @@ async def render_prompt(
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -107,7 +129,7 @@ async def render_prompt_and_embeds(
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
@@ -189,62 +211,40 @@ async def render_prompt(
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
 
         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
 
-        # Parse and batch the input prompts
-        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
-
-        tasks = []
-        for prompt_input in batch_inputs:
-            if prompt_input["is_tokens"] is True:
-                # Token input
-                # Note: detokenization is needed when echo is enabled,
-                # where the input token IDs are decoded back to text.
-                task = self._maybe_detokenize(prompt_input["content"],
-                                              config.max_length,
-                                              truncate_prompt_tokens,
-                                              config.cache_salt,
-                                              config.needs_detokenization)
-            else:
-                # Text input
-                task = self._tokenize(prompt_input["content"],
-                                      config.max_length,
-                                      truncate_prompt_tokens,
-                                      config.add_special_tokens,
-                                      config.cache_salt)
-            tasks.append(task)
-
-        # Wait for all text tokenization to finish
-        if tasks:
-            tokenized_text_prompts = await asyncio.gather(*tasks)
-            return tokenized_text_prompts
-
-        return []
+        tasks = (self._create_prompt(
+            prompt_input,
+            config=config,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+        ) for prompt_input in parse_raw_prompts(prompt_or_prompts))
+
+        return await asyncio.gather(*tasks)
 
     async def render_prompt_and_embeds(
         self,
         *,
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
        Render text/token prompts and/or precomputed embedding prompts. At
        least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
        """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
 
@@ -265,29 +265,6 @@ async def render_prompt_and_embeds(
 
         return rendered
 
-    def _validate_and_normalize_truncate_tokens(
-        self,
-        truncate_prompt_tokens: Optional[int],
-        max_length: Optional[int],
-    ) -> Optional[int]:
-        """Validate and normalize truncate_prompt_tokens parameter."""
-        if truncate_prompt_tokens is None:
-            return None
-
-        if truncate_prompt_tokens == 0:
-            return 0
-
-        if truncate_prompt_tokens < 0:
-            truncate_prompt_tokens = self.model_config.max_model_len
-
-        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
-            raise ValueError(
-                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
-                f"cannot be greater than max_length ({max_length}). "
-                f"Please select a smaller truncation size.")
-
-        return truncate_prompt_tokens
-
     def _maybe_apply_truncation(
         self, token_ids: list[int],
         truncate_prompt_tokens: Optional[int]) -> list[int]:
@@ -299,7 +276,38 @@ def _maybe_apply_truncation(
 
         return token_ids[-truncate_prompt_tokens:]
 
-    async def _tokenize(
+    async def _create_prompt(
+        self,
+        prompt_input: Union[EngineTextPrompt, EngineTokensPrompt],
+        config: RenderConfig,
+        truncate_prompt_tokens: Optional[int],
+    ) -> EngineTokensPrompt:
+        prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
+
+        if prompt_token_ids is not None:
+            # NOTE: detokenization is needed when echo is enabled,
+            # where the input token IDs are decoded back to text.
+            return await self._create_prompt_from_token_ids(
+                prompt_token_ids,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.cache_salt,
+                config.needs_detokenization,
+            )
+
+        if prompt is not None:
+            return await self._create_prompt_from_text(
+                prompt,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.add_special_tokens,
+                config.cache_salt,
+            )
+
+        # TODO: Also handle embeds prompt using this method
+        raise NotImplementedError
+
+    async def _create_prompt_from_text(
         self,
         text: str,
         max_length: Optional[int],
@@ -330,7 +338,7 @@ async def _tokenize(
         return self._create_tokens_prompt(encoded.input_ids, max_length,
                                           cache_salt, text)
 
-    async def _maybe_detokenize(
+    async def _create_prompt_from_token_ids(
         self,
         token_ids: list[int],
         max_length: Optional[int],
@@ -343,7 +351,7 @@ async def _maybe_detokenize(
                                                  truncate_prompt_tokens)
 
         prompt = None
-        if needs_detokenization is True:
+        if needs_detokenization:
            async_tokenizer = self._get_async_tokenizer()
            prompt = await async_tokenizer.decode(token_ids)
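
This diff moves truncation validation from a private renderer helper onto RenderConfig itself and collapses the per-prompt branching into _create_prompt. A minimal sketch of how the relocated verify_truncate_prompt_tokens behaves, read off the added code above; the SimpleNamespace stands in for ModelConfig since only max_model_len is consulted, and the remaining RenderConfig fields are assumed to keep their defaults.

from types import SimpleNamespace

from vllm.entrypoints.renderer import RenderConfig

# Stand-in for ModelConfig; verify_truncate_prompt_tokens only reads max_model_len.
model_config = SimpleNamespace(max_model_len=4096)

# None passes through unchanged; the caller then skips truncation entirely.
assert RenderConfig(truncate_prompt_tokens=None) \
    .verify_truncate_prompt_tokens(model_config) is None

# 0 is returned as-is, which render_prompt() treats as "render nothing".
assert RenderConfig(truncate_prompt_tokens=0) \
    .verify_truncate_prompt_tokens(model_config) == 0

# A negative value means "truncate to the model's maximum length".
assert RenderConfig(truncate_prompt_tokens=-1) \
    .verify_truncate_prompt_tokens(model_config) == 4096

# Anything above max_length raises, mirroring the old private helper.
try:
    RenderConfig(truncate_prompt_tokens=100, max_length=10) \
        .verify_truncate_prompt_tokens(model_config)
except ValueError as exc:
    print(exc)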

vllm/inputs/parse.py

Lines changed: 7 additions & 54 deletions

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
 from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
-                    Union, cast, overload)
+                    Union, cast)
 
 from typing_extensions import TypeIs
 
@@ -16,34 +16,12 @@
     import torch
 
 
-class ParsedText(TypedDict):
-    content: str
-    is_tokens: Literal[False]
-
-
-class ParsedTokens(TypedDict):
-    content: list[int]
-    is_tokens: Literal[True]
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
-    ...
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
-    ...
-
-
-def parse_and_batch_prompt(
+def parse_raw_prompts(
     prompt: Union[str, list[str], list[int], list[list[int]]],
-) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
+) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]:
     if isinstance(prompt, str):
         # case 1: a string
-        return [ParsedText(content=prompt, is_tokens=False)]
+        return [TextPrompt(prompt=prompt)]
 
     if isinstance(prompt, list):
         if len(prompt) == 0:
@@ -52,24 +30,19 @@ def parse_and_batch_prompt(
         if is_list_of(prompt, str):
             # case 2: array of strings
             prompt = cast(list[str], prompt)
-            return [
-                ParsedText(content=elem, is_tokens=False) for elem in prompt
-            ]
+            return [TextPrompt(prompt=elem) for elem in prompt]
         if is_list_of(prompt, int):
             # case 3: array of tokens
             prompt = cast(list[int], prompt)
-            return [ParsedTokens(content=prompt, is_tokens=True)]
+            return [TokensPrompt(prompt_token_ids=prompt)]
         if is_list_of(prompt, list):
             prompt = cast(list[list[int]], prompt)
             if len(prompt[0]) == 0:
                 raise ValueError("please provide at least one prompt")
 
             if is_list_of(prompt[0], int):
                 # case 4: array of token arrays
-                return [
-                    ParsedTokens(content=elem, is_tokens=True)
-                    for elem in prompt
-                ]
+                return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
 
     raise TypeError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
@@ -99,26 +72,6 @@ class ParsedEmbedsPrompt(TypedDict):
                            ParsedTokensPrompt, ParsedEmbedsPrompt]
 
 
-@overload
-def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
-    ...
-
-
 def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
     if isinstance(prompt, str):
         return ParsedStrPrompt(type="str", content=prompt)
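
With the ParsedText and ParsedTokens wrappers removed, parse_raw_prompts now returns the engine's own TextPrompt and TokensPrompt TypedDicts, which is what get_prompt_components unpacks in the renderer. A quick sketch of the four input cases handled above; since these prompt types are TypedDicts, the expected values are written as the plain dicts their constructors produce.

from vllm.inputs.parse import parse_raw_prompts

# case 1: a single string -> one TextPrompt
assert parse_raw_prompts("Hello") == [{"prompt": "Hello"}]

# case 2: a list of strings -> one TextPrompt per element
assert parse_raw_prompts(["Hello", "world"]) == [{"prompt": "Hello"},
                                                 {"prompt": "world"}]

# case 3: a flat list of token IDs -> a single TokensPrompt
assert parse_raw_prompts([1, 2, 3]) == [{"prompt_token_ids": [1, 2, 3]}]

# case 4: a list of token-ID lists -> one TokensPrompt per inner list
assert parse_raw_prompts([[1, 2], [3, 4]]) == [{"prompt_token_ids": [1, 2]},
                                               {"prompt_token_ids": [3, 4]}]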
