Skip to content

Commit 943ee5f

Browse files
committed
Address comment and rebase to r25.10 (V1 API)
1 parent 961a7c3 commit 943ee5f

File tree

2 files changed

+42
-36
lines changed

2 files changed

+42
-36
lines changed

src/model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
)
4141

4242
from utils.metrics import VllmStatLoggerFactory
43-
from utils.vllm_backend_utils import TritonSamplingParams
43+
from utils.request import EmbedRequest, GenerateRequest
4444

4545
_VLLM_ENGINE_ARGS_FILENAME = "model.json"
4646
_MULTI_LORA_ARGS_FILENAME = "multi_lora.json"
@@ -249,6 +249,11 @@ def _init_engine(self):
249249
self._event_thread = None
250250
raise e
251251

252+
# Get supported tasks from the engine running in another thread
253+
self.supported_tasks = asyncio.run_coroutine_threadsafe(
254+
self._llm_engine.get_supported_tasks(), self._event_loop
255+
).result()
256+
252257
async def _run_llm_engine(self):
253258
# Counter to keep track of ongoing request counts.
254259
self._ongoing_request_count = 0
@@ -453,11 +458,11 @@ async def _infer(self, request):
453458
request_task_name = self._validate_request_task_name(request)
454459
if request_task_name == "generate":
455460
request = GenerateRequest(
456-
request, self._llm_engine.generate, self.output_dtype
461+
request, self._llm_engine.generate, self.output_dtype, self.logger
457462
)
458463
elif request_task_name == "embed":
459464
request = EmbedRequest(
460-
request, self._llm_engine.encode, self.output_dtype
465+
request, self._llm_engine.encode, self.output_dtype, self.logger
461466
)
462467
else:
463468
raise ValueError(
@@ -499,10 +504,9 @@ async def _infer(self, request):
499504
# Send each response if streaming.
500505
if request.stream:
501506
response = request.create_response(
502-
request_output_state,
503507
request_output,
508+
request_output_state,
504509
prepend_input=False,
505-
additional_outputs=request.additional_outputs,
506510
)
507511
flags = 0
508512
if request_output.finished:
@@ -515,10 +519,9 @@ async def _infer(self, request):
515519
if not request.stream:
516520
if request_task_name == "generate":
517521
response = request.create_response(
518-
request_output_state={},
519522
request_output=request_output,
523+
request_output_state={},
520524
prepend_input=request.prepend_input,
521-
additional_outputs=request.additional_outputs,
522525
)
523526
else:
524527
response = request.create_response(

src/utils/request.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,33 @@
2828
import json
2929
from abc import abstractmethod
3030
from io import BytesIO
31+
from typing import Callable
3132

3233
import numpy as np
3334
import triton_python_backend_utils as pb_utils
3435
from PIL import Image
3536
from vllm.inputs.data import TokensPrompt
3637
from vllm.lora.request import LoRARequest
38+
from vllm.outputs import (
39+
EmbeddingOutput,
40+
EmbeddingRequestOutput,
41+
PoolingRequestOutput,
42+
RequestOutput,
43+
)
3744
from vllm.pooling_params import PoolingParams
3845
from vllm.utils import random_uuid
3946

4047
from utils.vllm_backend_utils import TritonSamplingParams
4148

4249

4350
class RequestBase:
44-
def __init__(self, request, executor_callback, output_dtype):
51+
def __init__(
52+
self, request, executor_callback: Callable, output_dtype: np.dtype, logger
53+
):
4554
self.request = request
4655
self.executor_callback = executor_callback
4756
self.output_dtype = output_dtype
57+
self.logger = logger
4858
self.id = random_uuid()
4959
self.stream = False
5060
self.prepend_input = False
@@ -58,13 +68,15 @@ def execute(self):
5868
raise NotImplementedError
5969

6070
@abstractmethod
61-
def create_response(self, *args, **kwargs):
71+
def create_response(self, request_output, *args, **kwargs):
6272
raise NotImplementedError
6373

6474

6575
class GenerateRequest(RequestBase):
66-
def __init__(self, request, executor_callback, output_dtype):
67-
super().__init__(request, executor_callback, output_dtype)
76+
def __init__(
77+
self, request, executor_callback: Callable, output_dtype: np.dtype, logger
78+
):
79+
super().__init__(request, executor_callback, output_dtype, logger)
6880

6981
def _get_input_tensors(self):
7082
# prompt
@@ -166,7 +178,12 @@ async def execute(self):
166178
async for response in response_iterator:
167179
yield response
168180

169-
def create_response(self, request_output_state, request_output, prepend_input):
181+
def create_response(
182+
self,
183+
request_output: RequestOutput,
184+
request_output_state: dict,
185+
prepend_input: bool,
186+
):
170187
output_tensors = []
171188

172189
# text_output
@@ -278,8 +295,10 @@ def create_response(self, request_output_state, request_output, prepend_input):
278295

279296

280297
class EmbedRequest(RequestBase):
281-
def __init__(self, request, executor_callback, output_dtype):
282-
super().__init__(request, executor_callback, output_dtype)
298+
def __init__(
299+
self, request, executor_callback: Callable, output_dtype: np.dtype, logger
300+
):
301+
super().__init__(request, executor_callback, output_dtype, logger)
283302

284303
def _get_input_tensors(self):
285304
embedding_request = pb_utils.get_input_tensor_by_name(
@@ -338,32 +357,16 @@ def _to_pooling_params(self, embedding_request: dict):
338357
pooling_params = PoolingParams(dimensions=dims, task="embed")
339358
return pooling_params
340359

341-
def create_response(self, request_output):
360+
def create_response(self, request_output: PoolingRequestOutput[EmbeddingOutput]):
342361
output_tensors = []
362+
request_output = EmbeddingRequestOutput.from_base(request_output)
343363

344-
# Extract embedding vector from output
345-
# PoolingRequestOutput.outputs is a PoolingOutput with .data (torch.Tensor)
346-
pooling_data = request_output.outputs.data
347-
348-
# Convert torch tensor to numpy array then to list for JSON serialization
349-
if hasattr(pooling_data, "cpu"):
350-
# It's a torch tensor - move to CPU and convert to numpy
351-
embedding_array = pooling_data.cpu().numpy()
352-
else:
353-
# Already numpy or list
354-
embedding_array = np.array(pooling_data, dtype=np.float32)
355-
356-
# Create response tensor - for embeddings, we use text_output to return the vector
357-
# (This is a simplification - you may want to define a proper embedding output tensor)
358-
embedding_list = (
359-
embedding_array.tolist()
360-
if hasattr(embedding_array, "tolist")
361-
else list(embedding_array)
362-
)
363-
embedding_str = json.dumps(embedding_list)
364+
# Extract embedding list from output
365+
embedding: list[float] = request_output.outputs.embedding
364366
output_tensors.append(
365367
pb_utils.Tensor(
366-
"text_output", np.asarray([embedding_str], dtype=self.output_dtype)
368+
"text_output",
369+
np.asarray([json.dumps(embedding)], dtype=self.output_dtype),
367370
)
368371
)
369372

0 commit comments

Comments
 (0)