 from vllm.config import ModelConfig
 from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import TextPrompt as EngineTextPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AsyncMicrobatchTokenizer

@@ -41,6 +42,27 @@ class RenderConfig:
     needs_detokenization: Optional[bool] = False
     """If True, detokenize IDs back to text for inclusion in outputs."""

+    def verify_truncate_prompt_tokens(
+            self, model_config: ModelConfig) -> Optional[int]:
+        """Validate and normalize `truncate_prompt_tokens` parameter."""
+        truncate_prompt_tokens = self.truncate_prompt_tokens
+        if truncate_prompt_tokens is None:
+            return None
+
+        if truncate_prompt_tokens == 0:
+            return 0
+
+        if truncate_prompt_tokens < 0:
+            truncate_prompt_tokens = model_config.max_model_len
+
+        max_length = self.max_length
+        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
+            raise ValueError(
+                f"{truncate_prompt_tokens=} cannot be greater than "
+                f"{max_length=}. Please select a smaller truncation size.")
+
+        return truncate_prompt_tokens
+

 class BaseRenderer(ABC):
     """
@@ -74,7 +96,7 @@ async def render_prompt(
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -107,7 +129,7 @@ async def render_prompt_and_embeds(
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
@@ -189,62 +211,40 @@ async def render_prompt(
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.

         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []

-        # Parse and batch the input prompts
-        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
-
-        tasks = []
-        for prompt_input in batch_inputs:
-            if prompt_input["is_tokens"] is True:
-                # Token input
-                # Note: detokenization is needed when echo is enabled,
-                # where the input token IDs are decoded back to text.
-                task = self._maybe_detokenize(prompt_input["content"],
-                                              config.max_length,
-                                              truncate_prompt_tokens,
-                                              config.cache_salt,
-                                              config.needs_detokenization)
-            else:
-                # Text input
-                task = self._tokenize(prompt_input["content"],
-                                      config.max_length,
-                                      truncate_prompt_tokens,
-                                      config.add_special_tokens,
-                                      config.cache_salt)
-            tasks.append(task)
-
-        # Wait for all text tokenization to finish
-        if tasks:
-            tokenized_text_prompts = await asyncio.gather(*tasks)
-            return tokenized_text_prompts
-
-        return []
+        tasks = (self._create_prompt(
+            prompt_input,
+            config=config,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+        ) for prompt_input in parse_raw_prompts(prompt_or_prompts))
+
+        return await asyncio.gather(*tasks)

     async def render_prompt_and_embeds(
         self,
         *,
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []

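The old per-input `is_tokens` branching is gone: `parse_raw_prompts` flattens all four accepted input shapes into individual prompt objects, and `asyncio.gather` returns results in argument order, so outputs still line up with inputs. A hedged sketch of the normalization, with return shapes inferred from how the diff consumes them:

```python
# Hedged: exact return types are inferred from usage in _create_prompt.
parse_raw_prompts("hi")           # -> [TextPrompt(prompt="hi")]
parse_raw_prompts(["hi", "yo"])   # -> one TextPrompt per string
parse_raw_prompts([1, 2, 3])      # -> [TokensPrompt(prompt_token_ids=[1, 2, 3])]
parse_raw_prompts([[1, 2], [3]])  # -> one TokensPrompt per token list
```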
@@ -265,29 +265,6 @@ async def render_prompt_and_embeds(

         return rendered

-    def _validate_and_normalize_truncate_tokens(
-        self,
-        truncate_prompt_tokens: Optional[int],
-        max_length: Optional[int],
-    ) -> Optional[int]:
-        """Validate and normalize truncate_prompt_tokens parameter."""
-        if truncate_prompt_tokens is None:
-            return None
-
-        if truncate_prompt_tokens == 0:
-            return 0
-
-        if truncate_prompt_tokens < 0:
-            truncate_prompt_tokens = self.model_config.max_model_len
-
-        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
-            raise ValueError(
-                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
-                f"cannot be greater than max_length ({max_length}). "
-                f"Please select a smaller truncation size.")
-
-        return truncate_prompt_tokens
-
     def _maybe_apply_truncation(
         self, token_ids: list[int],
         truncate_prompt_tokens: Optional[int]) -> list[int]:
@@ -299,7 +276,38 @@ def _maybe_apply_truncation(

             return token_ids[-truncate_prompt_tokens:]

-    async def _tokenize(
+    async def _create_prompt(
+        self,
+        prompt_input: Union[EngineTextPrompt, EngineTokensPrompt],
+        config: RenderConfig,
+        truncate_prompt_tokens: Optional[int],
+    ) -> EngineTokensPrompt:
+        prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
+
+        if prompt_token_ids is not None:
+            # NOTE: detokenization is needed when echo is enabled,
+            # where the input token IDs are decoded back to text.
+            return await self._create_prompt_from_token_ids(
+                prompt_token_ids,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.cache_salt,
+                config.needs_detokenization,
+            )
+
+        if prompt is not None:
+            return await self._create_prompt_from_text(
+                prompt,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.add_special_tokens,
+                config.cache_salt,
+            )
+
+        # TODO: Also handle embeds prompt using this method
+        raise NotImplementedError
+
+    async def _create_prompt_from_text(
         self,
         text: str,
         max_length: Optional[int],
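`_create_prompt` dispatches on whichever component `get_prompt_components` fills in; the third element of the returned tuple is the embeds slot, unused here (hence the TODO). A hedged sketch of that contract:

```python
# Hedged sketch: get_prompt_components(prompt) -> (text, token_ids, embeds).
prompt, prompt_token_ids, _ = get_prompt_components({"prompt": "hello"})
# prompt == "hello", prompt_token_ids is None -> text path

prompt, prompt_token_ids, _ = get_prompt_components(
    {"prompt_token_ids": [1, 2, 3]})
# prompt is None, prompt_token_ids == [1, 2, 3] -> token path
```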
@@ -330,7 +338,7 @@ async def _tokenize(
         return self._create_tokens_prompt(encoded.input_ids, max_length,
                                           cache_salt, text)

-    async def _maybe_detokenize(
+    async def _create_prompt_from_token_ids(
         self,
         token_ids: list[int],
         max_length: Optional[int],
@@ -343,7 +351,7 @@ async def _maybe_detokenize(
                                                  truncate_prompt_tokens)

         prompt = None
-        if needs_detokenization is True:
+        if needs_detokenization:
             async_tokenizer = self._get_async_tokenizer()
             prompt = await async_tokenizer.decode(token_ids)

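The final hunk drops the `is True` comparison; since `needs_detokenization` is declared `Optional[bool] = False`, plain truthiness treats `None` and `False` identically, so the simplification is behavior-preserving. A hedged end-to-end sketch of driving the renderer after this change, where `renderer` is a concrete `BaseRenderer` implementation supplied by the serving layer:

```python
# Hedged usage sketch; `renderer` and default RenderConfig fields are assumed.
async def demo(renderer) -> None:
    config = RenderConfig(
        max_length=8192,
        truncate_prompt_tokens=-1,  # normalized to model_config.max_model_len
        needs_detokenization=True,  # e.g. echo: decode token inputs back to text
    )
    engine_prompts = await renderer.render_prompt(
        prompt_or_prompts=[[101, 2023, 102], [101, 102]],
        config=config,
    )
    # One EngineTokensPrompt per input, in input order.
```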