 import json
 from abc import abstractmethod
 from io import BytesIO
+from typing import Callable
 
 import numpy as np
 import triton_python_backend_utils as pb_utils
 from PIL import Image
 from vllm.inputs.data import TokensPrompt
 from vllm.lora.request import LoRARequest
+from vllm.outputs import (
+    EmbeddingOutput,
+    EmbeddingRequestOutput,
+    PoolingRequestOutput,
+    RequestOutput,
+)
 from vllm.pooling_params import PoolingParams
 from vllm.utils import random_uuid
 
 from utils.vllm_backend_utils import TritonSamplingParams
 
 
 class RequestBase:
-    def __init__(self, request, executor_callback, output_dtype):
+    def __init__(
+        self, request, executor_callback: Callable, output_dtype: np.dtype, logger
+    ):
         self.request = request
         self.executor_callback = executor_callback
         self.output_dtype = output_dtype
+        self.logger = logger
         self.id = random_uuid()
         self.stream = False
         self.prepend_input = False
@@ -58,13 +68,15 @@ def execute(self):
         raise NotImplementedError
 
     @abstractmethod
-    def create_response(self, *args, **kwargs):
+    def create_response(self, request_output, *args, **kwargs):
        raise NotImplementedError
 
 
 class GenerateRequest(RequestBase):
-    def __init__(self, request, executor_callback, output_dtype):
-        super().__init__(request, executor_callback, output_dtype)
+    def __init__(
+        self, request, executor_callback: Callable, output_dtype: np.dtype, logger
+    ):
+        super().__init__(request, executor_callback, output_dtype, logger)
 
     def _get_input_tensors(self):
         # prompt
@@ -166,7 +178,12 @@ async def execute(self):
         async for response in response_iterator:
             yield response
 
-    def create_response(self, request_output_state, request_output, prepend_input):
+    def create_response(
+        self,
+        request_output: RequestOutput,
+        request_output_state: dict,
+        prepend_input: bool,
+    ):
         output_tensors = []
 
         # text_output
@@ -278,8 +295,10 @@ def create_response(self, request_output_state, request_output, prepend_input):
 
 
 class EmbedRequest(RequestBase):
-    def __init__(self, request, executor_callback, output_dtype):
-        super().__init__(request, executor_callback, output_dtype)
+    def __init__(
+        self, request, executor_callback: Callable, output_dtype: np.dtype, logger
+    ):
+        super().__init__(request, executor_callback, output_dtype, logger)
 
     def _get_input_tensors(self):
         embedding_request = pb_utils.get_input_tensor_by_name(
@@ -338,32 +357,16 @@ def _to_pooling_params(self, embedding_request: dict):
         pooling_params = PoolingParams(dimensions=dims, task="embed")
         return pooling_params
 
-    def create_response(self, request_output):
+    def create_response(self, request_output: PoolingRequestOutput[EmbeddingOutput]):
         output_tensors = []
+        request_output = EmbeddingRequestOutput.from_base(request_output)
 
-        # Extract embedding vector from output
-        # PoolingRequestOutput.outputs is a PoolingOutput with .data (torch.Tensor)
-        pooling_data = request_output.outputs.data
-
-        # Convert torch tensor to numpy array then to list for JSON serialization
-        if hasattr(pooling_data, "cpu"):
-            # It's a torch tensor - move to CPU and convert to numpy
-            embedding_array = pooling_data.cpu().numpy()
-        else:
-            # Already numpy or list
-            embedding_array = np.array(pooling_data, dtype=np.float32)
-
-        # Create response tensor - for embeddings, we use text_output to return the vector
-        # (This is a simplification - you may want to define a proper embedding output tensor)
-        embedding_list = (
-            embedding_array.tolist()
-            if hasattr(embedding_array, "tolist")
-            else list(embedding_array)
-        )
-        embedding_str = json.dumps(embedding_list)
+        # Extract embedding list from output
+        embedding: list[float] = request_output.outputs.embedding
         output_tensors.append(
             pb_utils.Tensor(
-                "text_output", np.asarray([embedding_str], dtype=self.output_dtype)
+                "text_output",
+                np.asarray([json.dumps(embedding)], dtype=self.output_dtype),
             )
         )
 
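For reference, with this change EmbedRequest.create_response returns the embedding as a JSON-encoded list of floats inside the single-element text_output tensor. A minimal decoding sketch, assuming the caller already has the raw text_output array; the sample values and variable names below are illustrative and not part of this change:

import json
import numpy as np

# Hypothetical text_output content as built above: a one-element array whose
# single entry is a JSON-encoded list of floats.
text_output = np.asarray([json.dumps([0.12, -0.34, 0.56])], dtype=np.object_)

# Decode back into a float32 vector.
embedding = np.array(json.loads(text_output[0]), dtype=np.float32)
print(embedding.shape)  # -> (3,)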