
Commit 0b2d174

sfeng33 authored and xuebwang-amd committed
Extend renderer with embedding support and integrate completion endpoint (vllm-project#24405)
Signed-off-by: sfeng33 <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
1 parent a04e37b commit 0b2d174

File tree

9 files changed: +411, -310 lines changed

tests/entrypoints/openai/test_prompt_validation.py

Lines changed: 9 additions & 5 deletions
@@ -10,7 +10,7 @@
 import regex as re
 import torch
 
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.renderer import BaseRenderer
 
 from ...utils import RemoteOpenAIServer
 
@@ -27,12 +27,16 @@ async def test_empty_prompt():
     with RemoteOpenAIServer(model_name, server_args) as remote_server:
         client = remote_server.get_async_client()
 
-        with pytest.raises(openai.BadRequestError,
-                           match="decoder prompt cannot be empty"):
+        with pytest.raises(
+                openai.BadRequestError,
+                match=
+                "Either prompt or prompt_embeds must be provided and non-empty."
+        ):
             await client.completions.create(model=model_name,
                                             prompt="",
                                             max_tokens=5,
-                                            temperature=0.0)
+                                            temperature=0.0,
+                                            extra_body={"prompt_embeds": []})
 
 
 @pytest.mark.asyncio
@@ -83,7 +87,7 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
     buffer.seek(0)
     encoded_tensor = pybase64.b64encode(buffer.getvalue())
 
-    loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
+    loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor)
     assert len(loaded_prompt_embeds) == 1
     loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
     assert loaded_tensor.device.type == "cpu"

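For reference, a minimal client-side sketch of the flow these tests cover: serializing a tensor the way BaseRenderer.load_prompt_embeds expects (torch.save plus base64) and sending it through the completions endpoint via extra_body. The server URL, model name, and embedding shape are placeholders, and the sketch assumes a running vLLM server whose model accepts prompt embeddings.

import io

import pybase64
import torch
from openai import OpenAI


def encode_prompt_embeds(tensor: torch.Tensor) -> str:
    # Serialize with torch.save, then base64-encode: the same layout that
    # BaseRenderer.load_prompt_embeds decodes in the test above.
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    buffer.seek(0)
    return pybase64.b64encode(buffer.read()).decode("utf-8")


# Placeholder endpoint and model name; the hidden size must match the model.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
embeds = torch.randn(10, 768, dtype=torch.float32)  # (num_tokens, hidden_size)

completion = client.completions.create(
    model="my-model",
    prompt="",  # an empty prompt is allowed when prompt_embeds is non-empty
    max_tokens=5,
    extra_body={"prompt_embeds": encode_prompt_embeds(embeds)},
)
print(completion.choices[0].text)
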
tests/entrypoints/test_renderer.py

Lines changed: 133 additions & 0 deletions
@@ -1,13 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import io
 from dataclasses import dataclass
 from typing import Optional
 from unittest.mock import AsyncMock, MagicMock
 
+import pybase64
 import pytest
+import torch
 
 from vllm.entrypoints.renderer import CompletionRenderer
+from vllm.inputs.data import is_embeds_prompt
 
 
 @dataclass
@@ -178,3 +182,132 @@ async def test_no_tokenizer_for_text(self, mock_model_config):
         with pytest.raises(ValueError, match="No tokenizer available"):
             await renderer_no_tokenizer.render_prompt(
                 prompt_or_prompts="Hello world", max_length=100)
+
+    @pytest.mark.asyncio
+    async def test_token_input_with_needs_detokenization(
+            self, renderer, mock_async_tokenizer):
+        # When needs_detokenization=True for token inputs, renderer should
+        # use the async tokenizer to decode and include the original text
+        # in the returned prompt object.
+        mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        tokens = [1, 2, 3, 4]
+        results = await renderer.render_prompt(
+            prompt_or_prompts=tokens,
+            needs_detokenization=True,
+        )
+
+        assert len(results) == 1
+        assert results[0]["prompt_token_ids"] == tokens
+        assert results[0]["prompt"] == "decoded text"
+        mock_async_tokenizer.decode.assert_awaited_once()
+
+
+class TestRenderEmbedPrompt:
+
+    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
+        """Helper to create base64-encoded tensor bytes"""
+        buffer = io.BytesIO()
+        torch.save(tensor, buffer)
+        buffer.seek(0)
+        return pybase64.b64encode(buffer.read())
+
+    @pytest.mark.asyncio
+    async def test_single_prompt_embed(self, renderer):
+        # Create a test tensor
+        test_tensor = torch.randn(10, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes, cache_salt="test_salt")
+
+        assert len(results) == 1
+        assert is_embeds_prompt(results[0])
+        assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
+        assert results[0]["cache_salt"] == "test_salt"
+
+    @pytest.mark.asyncio
+    async def test_multiple_prompt_embeds(self, renderer):
+        # Create multiple test tensors
+        test_tensors = [
+            torch.randn(8, 512, dtype=torch.float32),
+            torch.randn(12, 512, dtype=torch.float32),
+        ]
+        embed_bytes_list = [
+            self._create_test_embed_bytes(t) for t in test_tensors
+        ]
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes_list)
+
+        assert len(results) == 2
+        for i, result in enumerate(results):
+            assert is_embeds_prompt(result)
+            assert torch.allclose(result["prompt_embeds"], test_tensors[i])
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_truncation(self, renderer):
+        # Create tensor with more tokens than truncation limit
+        test_tensor = torch.randn(20, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
+
+        assert len(results) == 1
+        # Should keep last 10 tokens
+        expected = test_tensor[-10:]
+        assert torch.allclose(results[0]["prompt_embeds"], expected)
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_different_dtypes(self, renderer):
+        # Test different supported dtypes
+        dtypes = [torch.float32, torch.float16, torch.bfloat16]
+
+        for dtype in dtypes:
+            test_tensor = torch.randn(5, 256, dtype=dtype)
+            embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+            results = await renderer.render_prompt_and_embeds(
+                prompt_embeds=embed_bytes)
+
+            assert len(results) == 1
+            assert results[0]["prompt_embeds"].dtype == dtype
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_squeeze_batch_dim(self, renderer):
+        # Test tensor with batch dimension gets squeezed
+        test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes)
+
+        assert len(results) == 1
+        # Should be squeezed to 2D
+        assert results[0]["prompt_embeds"].shape == (10, 768)
+
+    @pytest.mark.asyncio
+    async def test_both_prompts_and_embeds(self, renderer,
+                                           mock_async_tokenizer):
+        # Set up text tokenization
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 102, 103])
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        # Create embed
+        test_tensor = torch.randn(5, 256, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
+
+        assert len(results) == 2
+        # First should be embed prompt
+        assert is_embeds_prompt(results[0])
+        # Second should be tokens prompt
+        assert "prompt_token_ids" in results[1]
+        assert results[1]["prompt_token_ids"] == [101, 102, 103]

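As a quick aside, the truncation and batch-squeeze behavior asserted by test_prompt_embed_truncation and test_prompt_embed_squeeze_batch_dim boils down to a couple of tensor operations. The snippet below only illustrates the expected semantics; it is not code from the renderer itself.

import torch

# With truncate_prompt_tokens=10, the renderer is expected to keep the last
# 10 rows of a (num_tokens, hidden_size) embedding tensor.
embeds = torch.randn(20, 768)
truncate_prompt_tokens = 10
assert embeds[-truncate_prompt_tokens:].shape == (10, 768)

# A tensor saved with a leading batch dimension of 1, e.g. (1, 10, 768),
# should come back from the renderer as a 2D (10, 768) tensor.
batched = torch.randn(1, 10, 768)
assert batched.squeeze(0).shape == (10, 768)
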
tests/v1/entrypoints/openai/test_completion.py

Lines changed: 1 addition & 1 deletion
@@ -686,7 +686,7 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
 async def test_completion_with_empty_prompt_embeds(
         client: openai.AsyncOpenAI) -> None:
     """Test completion with empty prompt embeds."""
-    payload: dict[str, list] = {"prompt_embeds": []}
+    payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []}
     headers: dict[str, str] = {"Content-Type": "application/json"}
     # base_url = http://localhost:8000/v1/completions
     response = requests.post(f"{client.base_url}completions",

vllm/entrypoints/openai/protocol.py

Lines changed: 13 additions & 2 deletions
@@ -1270,9 +1270,20 @@ def validate_stream_options(cls, data):
     @model_validator(mode="before")
     @classmethod
     def validate_prompt_and_prompt_embeds(cls, data):
-        if data.get("prompt") is None and data.get("prompt_embeds") is None:
+        prompt = data.get("prompt")
+        prompt_embeds = data.get("prompt_embeds")
+
+        prompt_is_empty = (prompt is None
+                           or (isinstance(prompt, str) and prompt == ""))
+        embeds_is_empty = (prompt_embeds is None
+                           or (isinstance(prompt_embeds, list)
+                               and len(prompt_embeds) == 0))
+
+        if prompt_is_empty and embeds_is_empty:
             raise ValueError(
-                "At least one of `prompt` or `prompt_embeds` must be set.")
+                "Either prompt or prompt_embeds must be provided and non-empty."
+            )
+
         return data
 
     @model_validator(mode="before")
vllm/entrypoints/openai/serving_completion.py

Lines changed: 28 additions & 27 deletions
@@ -26,12 +26,8 @@
                                               PromptTokenUsageInfo,
                                               RequestResponseMetadata,
                                               UsageInfo)
-from vllm.entrypoints.openai.serving_engine import (
-    EmbedsPrompt as ServingEngineEmbedsPrompt)
 from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
-                                                    TextTokensPrompt,
-                                                    clamp_prompt_logprobs,
-                                                    is_text_tokens_prompt)
+                                                    clamp_prompt_logprobs)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.utils import get_max_tokens
@@ -132,12 +128,19 @@ async def create_completion(
             else:
                 tokenizer = await self.engine_client.get_tokenizer(lora_request
                                                                    )
-
-            request_prompts, engine_prompts = await self._preprocess_completion(
-                request,
-                tokenizer,
-                request.prompt,
+            renderer = self._get_renderer(tokenizer)
+            max_input_tokens_len = self.max_model_len - (request.max_tokens
+                                                         or 0)
+
+            engine_prompts = await renderer.render_prompt_and_embeds(
+                prompt_or_prompts=request.prompt,
+                prompt_embeds=request.prompt_embeds,
+                max_length=max_input_tokens_len,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
+                cache_salt=request.cache_salt,
+                needs_detokenization=bool(request.echo
+                                          and not request.return_token_ids),
             )
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
@@ -198,7 +201,7 @@ async def create_completion(
 
                 self._log_inputs(
                     request_id_item,
-                    request_prompts[i],
+                    engine_prompt,
                     params=sampling_params,
                     lora_request=lora_request,
                 )
@@ -249,7 +252,7 @@ async def create_completion(
         if stream:
             return self.completion_stream_generator(
                 request,
-                request_prompts,
+                engine_prompts,
                 result_generator,
                 request_id,
                 created_time,
@@ -273,11 +276,9 @@ async def create_completion(
                 # We did not pass it into vLLM engine to avoid being redundant
                 # with the inputs token IDs
                 if final_res.prompt is None:
-                    request_prompt = request_prompts[i]
-                    if is_text_tokens_prompt(request_prompt):
-                        final_res.prompt = request_prompt["prompt"]
-                    else:
-                        final_res.prompt = None
+                    engine_prompt = engine_prompts[i]
+                    final_res.prompt = None if is_embeds_prompt(
+                        engine_prompt) else engine_prompt.get("prompt")
 
             final_res_batch_checked = cast(list[RequestOutput],
                                            final_res_batch)
@@ -313,8 +314,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
-        request_prompts: list[Union[TextTokensPrompt,
-                                    ServingEngineEmbedsPrompt]],
+        engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]],
         result_generator: AsyncIterator[tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -350,14 +350,11 @@ async def completion_stream_generator(
                     num_cached_tokens = res.num_cached_tokens
                     first_iteration = False
 
-                    if res.prompt is not None:
-                        prompt_text = res.prompt
-                    else:
-                        request_prompt = request_prompts[prompt_idx]
-                        if is_text_tokens_prompt(request_prompt):
-                            prompt_text = request_prompt["prompt"]
-                        else:
-                            prompt_text = None
+                    prompt_text = res.prompt
+                    if prompt_text is None:
+                        engine_prompt = engine_prompts[prompt_idx]
+                        prompt_text = None if is_embeds_prompt(
+                            engine_prompt) else engine_prompt.get("prompt")
 
                 # Prompt details are excluded from later streamed outputs
                 if prompt_token_ids is not None:
@@ -378,6 +375,8 @@ async def completion_stream_generator(
                     assert request.max_tokens is not None
                     if request.echo and not has_echoed[i]:
                         assert prompt_token_ids is not None
+                        if request.return_token_ids:
+                            prompt_text = ""
                         assert prompt_text is not None
                         if request.max_tokens == 0:
                             # only return the prompt
@@ -525,6 +524,8 @@ def request_output_to_completion_response(
             for output in final_res.outputs:
                 assert request.max_tokens is not None
                 if request.echo:
+                    if request.return_token_ids:
+                        prompt_text = ""
                     assert prompt_text is not None
                     if request.max_tokens == 0:
                         token_ids = prompt_token_ids

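One behavioral detail worth noting in the new path: the renderer receives an input-length budget, so the prompt cannot tokenize to more than the model context window minus the requested completion length. A small worked example of that arithmetic, with placeholder numbers:

# Mirrors max_input_tokens_len = self.max_model_len - (request.max_tokens or 0).
# The values below are placeholders, not taken from this change.
max_model_len = 4096        # context window of the served model
request_max_tokens = 256    # request.max_tokens; treated as 0 when unset

max_input_tokens_len = max_model_len - (request_max_tokens or 0)
assert max_input_tokens_len == 3840  # budget passed as max_length to the renderer
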
vllm/entrypoints/openai/serving_embedding.py

Lines changed: 2 additions & 13 deletions
@@ -28,7 +28,6 @@
                                               TextTokensPrompt)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
@@ -290,7 +289,7 @@ def _is_text_tokens_prompt(self, prompt) -> bool:
     async def _create_single_prompt_generator(
         self,
         ctx: EmbeddingServeContext,
-        engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
+        engine_prompt: EngineTokensPrompt,
         pooling_params: PoolingParams,
         trace_headers: Optional[Mapping[str, str]],
         prompt_index: int,
@@ -303,12 +302,6 @@ async def _create_single_prompt_generator(
                          params=pooling_params,
                          lora_request=ctx.lora_request)
 
-        # Mypy has an existing bug related to inferring the variance
-        # of TypedDicts with `builtins.enumerate`:
-        # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
-        engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt],
-                             engine_prompt)
-
         # Return the original generator without wrapping
         return self.engine_client.encode(
             engine_prompt,
@@ -375,12 +368,8 @@ async def _prepare_generators(
                     continue
 
                 # Normal processing for short prompts or non-token prompts
-                # Cast engine_prompt to the expected type for mypy
-                engine_prompt_typed = cast(
-                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
-                    engine_prompt)
                 generator = await self._create_single_prompt_generator(
-                    ctx, engine_prompt_typed, pooling_params, trace_headers, i)
+                    ctx, engine_prompt, pooling_params, trace_headers, i)
                 generators.append(generator)
 
         from vllm.utils import merge_async_iterators
