diff --git a/src/model.py b/src/model.py
index 4145b71b..067b2161 100644
--- a/src/model.py
+++ b/src/model.py
@@ -25,27 +25,22 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import asyncio
-import base64
 import gc
 import json
 import os
 import queue
 import threading
-from io import BytesIO
 from typing import Dict, List
 
 import numpy as np
 import triton_python_backend_utils as pb_utils
-from PIL import Image
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args,
 )
-from vllm.lora.request import LoRARequest
-from vllm.utils import random_uuid
 
 from utils.metrics import VllmStatLoggerFactory
-from utils.vllm_backend_utils import TritonSamplingParams
+from utils.request import EmbedRequest, GenerateRequest
 
 _VLLM_ENGINE_ARGS_FILENAME = "model.json"
 _MULTI_LORA_ARGS_FILENAME = "multi_lora.json"
@@ -73,6 +68,7 @@ def auto_complete_config(cls, auto_complete_model_config):
     def _auto_complete_inputs_and_outputs(auto_complete_model_config):
         # Inputs expected by the backend.
         inputs = [
+            # TODO: Support array input
            {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
            {
                "name": "image",
@@ -128,6 +124,14 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
                "dims": [1],
                "optional": True,
            },
+            # Tentative input reserved for embedding requests from the OpenAI-compatible frontend; subject to change.
+            # WARN: Triton clients should never set this input directly.
+            {
+                "name": "embedding_request",
+                "data_type": "TYPE_STRING",
+                "dims": [1],
+                "optional": True,
+            },
         ]
         # Outputs expected by the backend.
         outputs = [
@@ -246,6 +250,11 @@ def _init_engine(self):
             self._event_thread = None
             raise e
 
+        # Get supported tasks from the engine running in another thread
+        self.supported_tasks = asyncio.run_coroutine_threadsafe(
+            self._llm_engine.get_supported_tasks(), self._event_loop
+        ).result()
+
     async def _run_llm_engine(self):
         # Counter to keep track of ongoing request counts.
self._ongoing_request_count = 0 @@ -395,6 +404,35 @@ def _response_loop(self): if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL: self._ongoing_request_count -= 1 + def respond_error(self, request, error_message, triton_error): + output_tensor = pb_utils.Tensor( + "text_output", + np.asarray([error_message], dtype=self.output_dtype), + ) + response = pb_utils.InferenceResponse( + output_tensors=[output_tensor], error=triton_error + ) + response_sender = request.get_response_sender() + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + + def _validate_request_task_name(self, request): + embedding_request = pb_utils.get_input_tensor_by_name( + request, "embedding_request" + ) + if embedding_request is None: + request_task_name = "generate" + else: + request_task_name = "embed" + + if request_task_name not in self.supported_tasks: + raise ValueError( + f"Model {self.args['model_name']} does not support '{request_task_name}' request" + ) + + return request_task_name + def execute(self, requests): if self._enable_health_check and not self._check_health(requests): return None @@ -404,11 +442,11 @@ def execute(self, requests): assert ( self._llm_engine_shutdown_event.is_set() is False ), "Cannot create tasks after shutdown has been requested" - coro = self._generate(request) + coro = self._infer(request) asyncio.run_coroutine_threadsafe(coro, self._event_loop) return None - async def _generate(self, request): + async def _infer(self, request): response_sender = request.get_response_sender() response_state = { "response_sender": response_sender, @@ -418,27 +456,21 @@ async def _generate(self, request): self._ongoing_request_count += 1 decrement_ongoing_request_count = True try: - request_id = random_uuid() - ( - prompt, - stream, - prepend_input, - parameters, - additional_outputs, - ) = self._get_input_tensors(request) - - sampling_params = TritonSamplingParams.from_dict(parameters, self.logger) - lora_name = sampling_params.lora_name - lora_request = None - if lora_name is not None: - lora_id = str(self.supported_loras.index(lora_name) + 1) - lora_int_id = int(lora_id) - lora_local_path = self.lora_repository[lora_name] - lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path) - - response_iterator = self._llm_engine.generate( - prompt, sampling_params, request_id, lora_request=lora_request - ) + request_task_name = self._validate_request_task_name(request) + if request_task_name == "generate": + request = GenerateRequest( + request, self._llm_engine.generate, self.output_dtype, self.logger + ) + elif request_task_name == "embed": + request = EmbedRequest( + request, self._llm_engine.encode, self.output_dtype, self.logger + ) + else: + raise ValueError( + f"VLLM backend does not support '{request_task_name}' request" + ) + + response_iterator = request.execute() request_output_state = {} async for request_output in response_iterator: @@ -446,14 +478,14 @@ async def _generate(self, request): # the response state if streaming. If not streaming, cancellation state # needs to be checked here. is_cancelled = response_state["is_cancelled"] - if not stream: + if not request.stream: is_cancelled = response_sender.is_cancelled() if is_cancelled: self.logger.log_info("[vllm] Cancelling the request") - await self._llm_engine.abort(request_id) + await self._llm_engine.abort(request.id) self.logger.log_info("[vllm] Successfully cancelled the request") - if stream: + if request.stream: # Add cancelled final response to response loop. 
response_state["last_response_generated"] = True response = pb_utils.InferenceResponse( @@ -471,12 +503,11 @@ async def _generate(self, request): break # Send each response if streaming. - if stream: - response = self._create_response( - request_output_state, + if request.stream: + response = request.create_response( request_output, + request_output_state, prepend_input=False, - additional_outputs=additional_outputs, ) flags = 0 if request_output.finished: @@ -486,15 +517,19 @@ async def _generate(self, request): self._response_queue.put_nowait((response_state, response, flags)) # Send the last response which contains all the outputs if not streaming. - if not stream: - response_sender.send( - self._create_response( + if not request.stream: + if request_task_name == "generate": + response = request.create_response( + request_output=request_output, request_output_state={}, + prepend_input=request.prepend_input, + ) + else: + response = request.create_response( request_output=request_output, - prepend_input=prepend_input, - additional_outputs=additional_outputs, - ), - flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL ) except Exception as e: @@ -515,191 +550,6 @@ async def _generate(self, request): if decrement_ongoing_request_count: self._ongoing_request_count -= 1 - def _get_input_tensors(self, request): - # prompt - prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0] - if isinstance(prompt, bytes): - prompt = prompt.decode("utf-8") - - # image - images = pb_utils.get_input_tensor_by_name(request, "image") - if images: - images_vllm = [] - for image_np in images.as_numpy(): - image_b = base64.b64decode(image_np.decode("utf-8")) - image_rgb = Image.open(BytesIO(image_b)).convert("RGB") - images_vllm.append(image_rgb) - if len(images_vllm) > 0: - prompt = { - "prompt": prompt, - "multi_modal_data": {"image": images_vllm}, - } - - # stream - stream = pb_utils.get_input_tensor_by_name(request, "stream") - if stream: - stream = stream.as_numpy()[0] - else: - stream = False - - # prepend_input / exclude_input_in_output - prepend_input = pb_utils.get_input_tensor_by_name( - request, "exclude_input_in_output" - ) - if prepend_input: - # When `exclude_input_in_output` is False, we want to prepend input prompt - # to output, thus prepend_input should be True, and vice versa. - prepend_input = not prepend_input.as_numpy()[0] - elif prepend_input is None and stream: - prepend_input = False - else: - prepend_input = True - if prepend_input and stream: - raise ValueError( - "When streaming, `exclude_input_in_output` = False is not allowed." - ) - - # parameters / sampling_parameters - # An alternative mechanism to receive serialized parameters as an input - # tensor, because request parameters are not yet supported via BLS. 
- sampling_parameters = pb_utils.get_input_tensor_by_name( - request, "sampling_parameters" - ) - if sampling_parameters: - parameters = sampling_parameters.as_numpy()[0].decode("utf-8") - else: - parameters = request.parameters() - - # additional outputs - additional_outputs = { - "return_finish_reason": None, - "return_cumulative_logprob": None, - "return_logprobs": None, - "return_num_input_tokens": None, - "return_num_output_tokens": None, - } - for tensor_name in additional_outputs.keys(): - tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) - if tensor: - tensor = bool(tensor.as_numpy()[0]) - else: - tensor = False - additional_outputs[tensor_name] = tensor - - return prompt, stream, prepend_input, parameters, additional_outputs - - def _create_response( - self, request_output_state, request_output, prepend_input, additional_outputs - ): - output_tensors = [] - - # text_output - prepend_prompt = "" - if "prev_lens_text_output" not in request_output_state: - # this is the first response - if prepend_input: - prepend_prompt = request_output.prompt - request_output_state["prev_lens_text_output"] = [0] * len( - request_output.outputs - ) - prev_lens = request_output_state["prev_lens_text_output"] - text_output = [ - (prepend_prompt + output.text[prev_len:]).encode("utf-8") - for output, prev_len in zip(request_output.outputs, prev_lens) - ] - request_output_state["prev_lens_text_output"] = [ - len(output.text) for output in request_output.outputs - ] - output_tensors.append( - pb_utils.Tensor( - "text_output", np.asarray(text_output, dtype=self.output_dtype) - ) - ) - - # finish_reason - if additional_outputs["return_finish_reason"]: - finish_reason = [ - str(output.finish_reason) for output in request_output.outputs - ] - output_tensors.append( - pb_utils.Tensor( - "finish_reason", np.asarray(finish_reason, dtype=np.object_) - ) - ) - - # cumulative_logprob - if additional_outputs["return_cumulative_logprob"]: - cumulative_logprob = [ - output.cumulative_logprob for output in request_output.outputs - ] - output_tensors.append( - pb_utils.Tensor( - "cumulative_logprob", - np.asarray(cumulative_logprob, dtype=np.float32), - ) - ) - - # logprobs - # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/sequence.py#L37-L58 - if additional_outputs["return_logprobs"]: - if "prev_lens_logprobs" not in request_output_state: - request_output_state["prev_lens_logprobs"] = [0] * len( - request_output.outputs - ) - logprobs = [] - for i in range(len(request_output.outputs)): - output = request_output.outputs[i] - if output.logprobs is None: - logprobs.append("null".encode("utf-8")) - continue - prev_len = request_output_state["prev_lens_logprobs"][i] - request_output_state["prev_lens_logprobs"][i] = len(output.logprobs) - logprobs_py = [] - for logprob_d_vllm in output.logprobs[prev_len:]: - logprob_d_py = {} - for token_id, logprob_vllm in logprob_d_vllm.items(): - logprob_d_py[token_id] = { - "logprob": logprob_vllm.logprob, - "rank": logprob_vllm.rank, - "decoded_token": logprob_vllm.decoded_token, - } - logprobs_py.append(logprob_d_py) - logprobs.append(json.dumps(logprobs_py).encode("utf-8")) - output_tensors.append( - pb_utils.Tensor("logprobs", np.asarray(logprobs, dtype=np.object_)) - ) - - # num_input_tokens - if additional_outputs["return_num_input_tokens"]: - num_input_tokens = len(request_output.prompt_token_ids) - output_tensors.append( - pb_utils.Tensor( - "num_input_tokens", np.asarray(num_input_tokens, dtype=np.uint32) - ) - ) - - # num_output_tokens - if 
additional_outputs["return_num_output_tokens"]: - if "prev_lens_num_output_tokens" not in request_output_state: - request_output_state["prev_lens_num_output_tokens"] = [0] * len( - request_output.outputs - ) - prev_lens = request_output_state["prev_lens_num_output_tokens"] - num_output_tokens = [ - (len(output.token_ids) - prev_len) - for output, prev_len in zip(request_output.outputs, prev_lens) - ] - request_output_state["prev_lens_num_output_tokens"] = [ - len(output.token_ids) for output in request_output.outputs - ] - output_tensors.append( - pb_utils.Tensor( - "num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32) - ) - ) - - return pb_utils.InferenceResponse(output_tensors=output_tensors) - def _verify_loras(self, request): # We will check if the requested lora exists here, if not we will send a # response with `LoRA not found` information. In this way we may avoid @@ -729,17 +579,7 @@ def _verify_loras(self, request): self.logger.log_info(f"[vllm] LoRA {lora_name} not found.") if lora_error is not None: - output_tensor = pb_utils.Tensor( - "text_output", - np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype), - ) - response = pb_utils.InferenceResponse( - output_tensors=[output_tensor], error=lora_error - ) - response_sender = request.get_response_sender() - response_sender.send( - response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL - ) + self.respond_error(request, lora_error.message, lora_error) else: verified_request = request return verified_request diff --git a/src/utils/request.py b/src/utils/request.py new file mode 100644 index 00000000..ff9b6c6b --- /dev/null +++ b/src/utils/request.py @@ -0,0 +1,388 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import base64 +import json +from abc import abstractmethod +from io import BytesIO +from typing import Callable + +import numpy as np +import triton_python_backend_utils as pb_utils +from PIL import Image +from vllm.inputs.data import TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.outputs import ( + EmbeddingOutput, + EmbeddingRequestOutput, + PoolingRequestOutput, + RequestOutput, +) +from vllm.pooling_params import PoolingParams +from vllm.utils import random_uuid + +from utils.vllm_backend_utils import TritonSamplingParams + + +class RequestBase: + def __init__( + self, request, executor_callback: Callable, output_dtype: np.dtype, logger + ): + self.request = request + self.executor_callback = executor_callback + self.output_dtype = output_dtype + self.logger = logger + self.id = random_uuid() + self.stream = False + self.prepend_input = False + + @abstractmethod + def _get_input_tensors(self): + raise NotImplementedError + + @abstractmethod + def execute(self): + raise NotImplementedError + + @abstractmethod + def create_response(self, request_output, *args, **kwargs): + raise NotImplementedError + + +class GenerateRequest(RequestBase): + def __init__( + self, request, executor_callback: Callable, output_dtype: np.dtype, logger + ): + super().__init__(request, executor_callback, output_dtype, logger) + + def _get_input_tensors(self): + # prompt + prompt = pb_utils.get_input_tensor_by_name( + self.request, "text_input" + ).as_numpy()[0] + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + + # image + images = pb_utils.get_input_tensor_by_name(self.request, "image") + if images: + images_vllm = [] + for image_np in images.as_numpy(): + image_b = base64.b64decode(image_np.decode("utf-8")) + image_rgb = Image.open(BytesIO(image_b)).convert("RGB") + images_vllm.append(image_rgb) + if len(images_vllm) > 0: + prompt = { + "prompt": prompt, + "multi_modal_data": {"image": images_vllm}, + } + + # stream + stream = pb_utils.get_input_tensor_by_name(self.request, "stream") + if stream: + stream = stream.as_numpy()[0] + else: + stream = False + + # prepend_input / exclude_input_in_output + prepend_input = pb_utils.get_input_tensor_by_name( + self.request, "exclude_input_in_output" + ) + if prepend_input: + # When `exclude_input_in_output` is False, we want to prepend input prompt + # to output, thus prepend_input should be True, and vice versa. + prepend_input = not prepend_input.as_numpy()[0] + elif prepend_input is None and stream: + prepend_input = False + else: + prepend_input = True + if prepend_input and stream: + raise ValueError( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) + + # parameters / sampling_parameters + # An alternative mechanism to receive serialized parameters as an input + # tensor, because request parameters are not yet supported via BLS. 
+ sampling_parameters = pb_utils.get_input_tensor_by_name( + self.request, "sampling_parameters" + ) + if sampling_parameters: + parameters = sampling_parameters.as_numpy()[0].decode("utf-8") + else: + parameters = self.request.parameters() + + # additional outputs + additional_outputs = { + "return_finish_reason": None, + "return_cumulative_logprob": None, + "return_logprobs": None, + "return_num_input_tokens": None, + "return_num_output_tokens": None, + } + for tensor_name in additional_outputs.keys(): + tensor = pb_utils.get_input_tensor_by_name(self.request, tensor_name) + if tensor: + tensor = bool(tensor.as_numpy()[0]) + else: + tensor = False + additional_outputs[tensor_name] = tensor + + return prompt, stream, prepend_input, parameters, additional_outputs + + async def execute(self): + ( + prompt, + self.stream, + self.prepend_input, + parameters, + self.additional_outputs, + ) = self._get_input_tensors() + + sampling_params = TritonSamplingParams.from_dict(parameters, self.logger) + lora_name = sampling_params.lora_name + lora_request = None + if lora_name is not None: + lora_id = str(self.supported_loras.index(lora_name) + 1) + lora_int_id = int(lora_id) + lora_local_path = self.lora_repository[lora_name] + lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path) + + response_iterator = self.executor_callback( + prompt, sampling_params, self.id, lora_request=lora_request + ) + + async for response in response_iterator: + yield response + + def create_response( + self, + request_output: RequestOutput, + request_output_state: dict, + prepend_input: bool, + ): + output_tensors = [] + + # text_output + prepend_prompt = "" + if "prev_lens_text_output" not in request_output_state: + # this is the first response + if prepend_input: + prepend_prompt = request_output.prompt + request_output_state["prev_lens_text_output"] = [0] * len( + request_output.outputs + ) + prev_lens = request_output_state["prev_lens_text_output"] + text_output = [ + (prepend_prompt + output.text[prev_len:]).encode("utf-8") + for output, prev_len in zip(request_output.outputs, prev_lens) + ] + request_output_state["prev_lens_text_output"] = [ + len(output.text) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "text_output", np.asarray(text_output, dtype=self.output_dtype) + ) + ) + + # finish_reason + if self.additional_outputs["return_finish_reason"]: + finish_reason = [ + str(output.finish_reason) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "finish_reason", np.asarray(finish_reason, dtype=np.object_) + ) + ) + + # cumulative_logprob + if self.additional_outputs["return_cumulative_logprob"]: + cumulative_logprob = [ + output.cumulative_logprob for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "cumulative_logprob", + np.asarray(cumulative_logprob, dtype=np.float32), + ) + ) + + # logprobs + # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/sequence.py#L37-L58 + if self.additional_outputs["return_logprobs"]: + if "prev_lens_logprobs" not in request_output_state: + request_output_state["prev_lens_logprobs"] = [0] * len( + request_output.outputs + ) + logprobs = [] + for i in range(len(request_output.outputs)): + output = request_output.outputs[i] + if output.logprobs is None: + logprobs.append("null".encode("utf-8")) + continue + prev_len = request_output_state["prev_lens_logprobs"][i] + request_output_state["prev_lens_logprobs"][i] = len(output.logprobs) + logprobs_py = [] + for 
logprob_d_vllm in output.logprobs[prev_len:]: + logprob_d_py = {} + for token_id, logprob_vllm in logprob_d_vllm.items(): + logprob_d_py[token_id] = { + "logprob": logprob_vllm.logprob, + "rank": logprob_vllm.rank, + "decoded_token": logprob_vllm.decoded_token, + } + logprobs_py.append(logprob_d_py) + logprobs.append(json.dumps(logprobs_py).encode("utf-8")) + output_tensors.append( + pb_utils.Tensor("logprobs", np.asarray(logprobs, dtype=np.object_)) + ) + + # num_input_tokens + if self.additional_outputs["return_num_input_tokens"]: + num_input_tokens = len(request_output.prompt_token_ids) + output_tensors.append( + pb_utils.Tensor( + "num_input_tokens", np.asarray(num_input_tokens, dtype=np.uint32) + ) + ) + + # num_output_tokens + if self.additional_outputs["return_num_output_tokens"]: + if "prev_lens_num_output_tokens" not in request_output_state: + request_output_state["prev_lens_num_output_tokens"] = [0] * len( + request_output.outputs + ) + prev_lens = request_output_state["prev_lens_num_output_tokens"] + num_output_tokens = [ + (len(output.token_ids) - prev_len) + for output, prev_len in zip(request_output.outputs, prev_lens) + ] + request_output_state["prev_lens_num_output_tokens"] = [ + len(output.token_ids) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32) + ) + ) + + return pb_utils.InferenceResponse(output_tensors=output_tensors) + + +class EmbedRequest(RequestBase): + def __init__( + self, request, executor_callback: Callable, output_dtype: np.dtype, logger + ): + super().__init__(request, executor_callback, output_dtype, logger) + + def _get_input_tensors(self): + embedding_request = pb_utils.get_input_tensor_by_name( + self.request, "embedding_request" + ).as_numpy()[0] + embedding_request = json.loads(embedding_request.decode("utf-8")) + # prompt + prompt = embedding_request["input"] + if isinstance(prompt, str): + pass # do nothing + elif ( + isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int) + ): + # Single list of token IDs + prompt = TokensPrompt(prompt_token_ids=prompt) + + # pooling_params + pooling_params = self._to_pooling_params(embedding_request) + + # additional outputs + additional_outputs = { + "return_num_input_tokens": None, + "return_num_output_tokens": None, + } + for tensor_name in additional_outputs.keys(): + tensor = pb_utils.get_input_tensor_by_name(self.request, tensor_name) + if tensor: + tensor = bool(tensor.as_numpy()[0]) + else: + tensor = False + additional_outputs[tensor_name] = tensor + + return prompt, pooling_params, additional_outputs + + async def execute(self): + ( + prompt, + pooling_params, + self.additional_outputs, + ) = self._get_input_tensors() + + # Create PoolingParams for embeddings + response_iterator = self.executor_callback(prompt, pooling_params, self.id) + + # Yield each response from the async iterator + async for response in response_iterator: + yield response + + def _to_pooling_params(self, embedding_request: dict): + pooling_params_dict = embedding_request.get("pooling_params", {}) + + pooling_params = PoolingParams(task="embed") + dims = None + if "dimensions" in pooling_params_dict: + dims = pooling_params_dict["dimensions"][0] + pooling_params = PoolingParams(dimensions=dims, task="embed") + return pooling_params + + def create_response(self, request_output: PoolingRequestOutput[EmbeddingOutput]): + output_tensors = [] + request_output = EmbeddingRequestOutput.from_base(request_output) + + # 
Extract embedding list from output
+        embedding: list[float] = request_output.outputs.embedding
+        output_tensors.append(
+            pb_utils.Tensor(
+                "text_output",
+                np.asarray([json.dumps(embedding)], dtype=self.output_dtype),
+            )
+        )
+
+        # num_input_tokens
+        if self.additional_outputs["return_num_input_tokens"]:
+            num_input_tokens = len(request_output.prompt_token_ids)
+            output_tensors.append(
+                pb_utils.Tensor(
+                    "num_input_tokens", np.asarray(num_input_tokens, dtype=np.uint32)
+                )
+            )
+
+        # For embeddings, num_output_tokens is 0 (no generation happened)
+        if self.additional_outputs["return_num_output_tokens"]:
+            output_tensors.append(
+                pb_utils.Tensor("num_output_tokens", np.asarray(0, dtype=np.uint32))
+            )
+
+        return pb_utils.InferenceResponse(output_tensors=output_tensors)
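The sketch below is a minimal, illustrative walkthrough of the embedding path added by this patch. It mirrors how EmbedRequest._get_input_tensors() and _to_pooling_params() parse the JSON that the OpenAI-compatible frontend serializes into the new "embedding_request" tensor; the payload values are invented for the example, and only the field names, TokensPrompt, and PoolingParams usage come from the code above.

# Illustrative only: mirrors the parsing in EmbedRequest; payload values are examples.
import json

from vllm.inputs.data import TokensPrompt
from vllm.pooling_params import PoolingParams

# JSON the OpenAI-compatible frontend would place in the "embedding_request" tensor.
embedding_request = json.loads(
    '{"input": "The quick brown fox", "pooling_params": {"dimensions": [256]}}'
)

# "input" may be a plain string or a single list of token IDs.
prompt = embedding_request["input"]
if isinstance(prompt, list) and prompt and isinstance(prompt[0], int):
    prompt = TokensPrompt(prompt_token_ids=prompt)

# "dimensions" optionally reduces the embedding dimensionality (model permitting).
dims_list = embedding_request.get("pooling_params", {}).get("dimensions")
dims = dims_list[0] if dims_list else None
pooling_params = PoolingParams(dimensions=dims, task="embed")

# EmbedRequest passes (prompt, pooling_params, request_id) to the engine's encode()
# callback; the resulting embedding comes back JSON-encoded in the "text_output" tensor.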