From a5c92e7f0e88f24b387522a3079e2aa52d6d251d Mon Sep 17 00:00:00 2001 From: Piotr Marcinkiewicz Date: Tue, 11 Jun 2024 09:32:30 +0200 Subject: [PATCH 1/4] Add client from PyTriton with simple coupled request --- src/python/library/tritonclient/_client.py | 25 +- src/python/library/tritonclient/hl/README.md | 48 + .../library/tritonclient/hl/__init__.py | 18 + src/python/library/tritonclient/hl/client.py | 834 ++++++++++++++++++ src/python/library/tritonclient/hl/common.py | 93 ++ .../library/tritonclient/hl/constants.py | 31 + .../library/tritonclient/hl/exceptions.py | 92 ++ .../library/tritonclient/hl/lw/__init__.py | 13 + src/python/library/tritonclient/hl/lw/grpc.py | 20 + src/python/library/tritonclient/hl/lw/http.py | 139 +++ .../library/tritonclient/hl/lw/infer_input.py | 355 ++++++++ .../library/tritonclient/hl/lw/utils.py | 16 + .../tritonclient/hl/model_config/__init__.py | 17 + .../tritonclient/hl/model_config/common.py | 93 ++ .../hl/model_config/model_config.py | 43 + .../tritonclient/hl/model_config/parser.py | 258 ++++++ .../tritonclient/hl/model_config/tensor.py | 57 ++ .../hl/model_config/triton_model_config.py | 68 ++ .../tritonclient/hl/triton_model_config.py | 68 ++ src/python/library/tritonclient/hl/utils.py | 384 ++++++++ .../library/tritonclient/hl/warnings.py | 26 + 21 files changed, 2696 insertions(+), 2 deletions(-) create mode 100644 src/python/library/tritonclient/hl/README.md create mode 100644 src/python/library/tritonclient/hl/__init__.py create mode 100755 src/python/library/tritonclient/hl/client.py create mode 100644 src/python/library/tritonclient/hl/common.py create mode 100644 src/python/library/tritonclient/hl/constants.py create mode 100644 src/python/library/tritonclient/hl/exceptions.py create mode 100755 src/python/library/tritonclient/hl/lw/__init__.py create mode 100755 src/python/library/tritonclient/hl/lw/grpc.py create mode 100755 src/python/library/tritonclient/hl/lw/http.py create mode 100755 src/python/library/tritonclient/hl/lw/infer_input.py create mode 100755 src/python/library/tritonclient/hl/lw/utils.py create mode 100644 src/python/library/tritonclient/hl/model_config/__init__.py create mode 100644 src/python/library/tritonclient/hl/model_config/common.py create mode 100644 src/python/library/tritonclient/hl/model_config/model_config.py create mode 100644 src/python/library/tritonclient/hl/model_config/parser.py create mode 100644 src/python/library/tritonclient/hl/model_config/tensor.py create mode 100644 src/python/library/tritonclient/hl/model_config/triton_model_config.py create mode 100644 src/python/library/tritonclient/hl/triton_model_config.py create mode 100644 src/python/library/tritonclient/hl/utils.py create mode 100644 src/python/library/tritonclient/hl/warnings.py diff --git a/src/python/library/tritonclient/_client.py b/src/python/library/tritonclient/_client.py index cdf8fa211..2b5ba612f 100755 --- a/src/python/library/tritonclient/_client.py +++ b/src/python/library/tritonclient/_client.py @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from tritonclient.utils import raise_error +from tritonclient.hl import ModelClient class InferenceServerClientBase: def __init__(self): @@ -85,6 +86,26 @@ def unregister_plugin(self): self._plugin = None -class Client(InferenceServerClientBase): - def __init__(self) -> None: +# # Change url to 'http://localhost:8000' for utilizing HTTP client +# client = Client(url='grpc://loacalhost:8001') +# +# input_tensor_as_numpy = np.array(...) 
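+# # For example (hypothetical values - the dtype and shape must match the deployed model):
+# input_tensor_as_numpy = np.array([[1, 2, 3, 4]], dtype=np.int32)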
+# +# # Infer should be async similar to the exising Python APIs +# responses = client.model('simple').infer(inputs={'input': input_tensor_as_numpy}) +# +# for response in responses: +# numpy_array = np.asarray(response.outputs['output']) +# +# client.close() + + + +class Client: + def __init__(self, url: str) -> None: + self._client_url = url super().__init__() + + def model(self, name: str) -> ModelClient: + return ModelClient(url=self._client_url, model_name=name) + diff --git a/src/python/library/tritonclient/hl/README.md b/src/python/library/tritonclient/hl/README.md new file mode 100644 index 000000000..a5664f5cb --- /dev/null +++ b/src/python/library/tritonclient/hl/README.md @@ -0,0 +1,48 @@ + +Just for test install PyTriton client: + +```bash +pip install nvidia-pytriton +``` + +It is possible to test new client using PyTriton server: + +```python +import time +import numpy as np +from pytriton.model_config import ModelConfig, Tensor +from pytriton.triton import Triton, TritonConfig +from pytriton.decorators import batch + +@batch +def identity(input): + return {"output": input} + + +triton = Triton() +triton.bind( + model_name="identity", + infer_func=identity, + inputs=[Tensor(name="input", dtype=np.bytes_, shape=(1,))], + outputs=[Tensor(name="output", dtype=np.bytes_, shape=(1,))], + strict=False, +) +triton.run() +``` + + +You can test new client with simple request: + +```python +import numpy as np +from tritonclient._client import Client + +Client("localhost:8000").model("identity").infer(inputs={"input": np.char.encode([["a"]], "utf-8")} +``` + +Expected output: + +```python +{'output': array(['a'], dtype=' str: + # pytype complained about creating generator in __init__ method + # so we create it lazily + if getattr(self, "_request_id_generator", None) is None: + self._request_id_generator = itertools.count(0) + return str(next(self._request_id_generator)) + + def _get_init_extra_args(self): + timeout = self._inference_timeout_s # pytype: disable=attribute-error + # The inference timeout is used for both the HTTP and the GRPC protocols. However, + # the way the timeout is passed to the client differs depending on the protocol. + # For the HTTP protocol, the timeout is set in the ``__init__`` method as ``network_timeout`` + # and ``connection_timeout``. For the GRPC protocol, the timeout + # is passed to the infer method as ``client_timeout``. + # Both protocols support timeouts correctly and will raise an exception + # if the network request or the inference process takes longer than the timeout. + # This is a design choice of the underlying tritonclient library. + + if self._triton_url.scheme != "http": + return {} + + kwargs = { + # This value sets the maximum time allowed for each network request in both model loading and inference process + "network_timeout": timeout, + # This value sets the maximum time allowed for establishing a connection to the server. + # We use the inference timeout here instead of the init timeout because the init timeout + # is meant for waiting for the model to be ready. The connection timeout should be shorter + # than the init timeout because it only checks if connection is established (e.g. 
correct port) + "connection_timeout": timeout, + } + return kwargs + + def _monkey_patch_client(self): + pass + + def _get_model_config_extra_args(self): + # For the GRPC protocol, the timeout must be passed to the each request as client_timeout + # model_config doesn't yet support timeout but it is planned for the future + # grpc_network_timeout_s will be used for model_config + return {} + + def _handle_lazy_init(self): + raise NotImplementedError + + +def _run_once_per_lib(f): + def wrapper(_self): + if _self._triton_client_lib not in wrapper.patched: + wrapper.patched.add(_self._triton_client_lib) + return f(_self) + + wrapper.patched = set() + return wrapper + + +class ModelClient(BaseModelClient): + """Synchronous client for model deployed on the Triton Inference Server.""" + + def __init__( + self, + url: str, + model_name: str, + model_version: Optional[str] = None, + *, + lazy_init: bool = True, + init_timeout_s: Optional[float] = None, + inference_timeout_s: Optional[float] = None, + model_config: Optional[TritonModelConfig] = None, + ensure_model_is_ready: bool = True, + ): + """Inits ModelClient for given model deployed on the Triton Inference Server. + + If `lazy_init` argument is False, model configuration will be read + from inference server during initialization. + + Common usage: + + ```python + client = ModelClient("localhost", "BERT") + result_dict = client.infer_sample(input1_sample, input2_sample) + client.close() + ``` + + Client supports also context manager protocol: + + ```python + with ModelClient("localhost", "BERT") as client: + result_dict = client.infer_sample(input1_sample, input2_sample) + ``` + + The creation of client requires connection to the server and downloading model configuration. You can create client from existing client using the same class: + + ```python + client = ModelClient.from_existing_client(existing_client) + ``` + + Args: + url: The Triton Inference Server url, e.g. 'grpc://localhost:8001'. + In case no scheme is provided http scheme will be used as default. + In case no port is provided default port for given scheme will be used - + 8001 for grpc scheme, 8000 for http scheme. + model_name: name of the model to interact with. + model_version: version of the model to interact with. + If model_version is None inference on latest model will be performed. + The latest versions of the model are numerically the greatest version numbers. + lazy_init: if initialization should be performed just before sending first request to inference server. + init_timeout_s: timeout for maximum waiting time in loop, which sends retry requests ask if model is ready. It is applied at initialization time only when `lazy_init` argument is False. Default is to do retry loop at first inference. + inference_timeout_s: timeout in seconds for the model inference process. + If non passed default 60 seconds timeout will be used. + For HTTP client it is not only inference timeout but any client request timeout + - get model config, is model loaded. For GRPC client it is only inference timeout. + model_config: model configuration. If not passed, it will be read from inference server during initialization. + ensure_model_is_ready: if model should be checked if it is ready before first inference request. + + Raises: + PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. + PyTritonClientTimeoutError: + if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`. 
+ PyTritonClientUrlParseError: In case of problems with parsing url. + """ + super().__init__( + url=url, + model_name=model_name, + model_version=model_version, + lazy_init=lazy_init, + init_timeout_s=init_timeout_s, + inference_timeout_s=inference_timeout_s, + model_config=model_config, + ensure_model_is_ready=ensure_model_is_ready, + ) + + def get_lib(self): + """Returns tritonclient library for given scheme.""" + return {"grpc": tritonclient.hl.lw.grpc, "http": tritonclient.hl.lw.http}[self._triton_url.scheme.lower()] + + def __enter__(self): + """Create context for using ModelClient as a context manager.""" + return self + + def __exit__(self, *_): + """Close resources used by ModelClient instance when exiting from the context.""" + self.close() + + def load_model(self, config: Optional[str] = None, files: Optional[dict] = None): + """Load model on the Triton Inference Server. + + Args: + config: str - Optional JSON representation of a model config provided for + the load request, if provided, this config will be used for + loading the model. + files: dict - Optional dictionary specifying file path (with "file:" prefix) in + the override model directory to the file content as bytes. + The files will form the model directory that the model will be + loaded from. If specified, 'config' must be provided to be + the model configuration of the override model directory. + """ + self._general_client.load_model(self._model_name, config=config, files=files) + + def unload_model(self): + """Unload model from the Triton Inference Server.""" + self._general_client.unload_model(self._model_name) + + def close(self): + """Close resources used by ModelClient. + + This method closes the resources used by the ModelClient instance, + including the Triton Inference Server connections. + Once this method is called, the ModelClient instance should not be used again. + """ + _LOGGER.debug("Closing ModelClient") + try: + if self._general_client is not None: + self._general_client.close() + if self._infer_client is not None: + self._infer_client.close() + self._general_client = None + self._infer_client = None + except Exception as e: + _LOGGER.error(f"Error while closing ModelClient resources: {e}") + raise e + + def wait_for_model(self, timeout_s: float): + """Wait for the Triton Inference Server and the deployed model to be ready. + + Args: + timeout_s: timeout in seconds to wait for the server and model to be ready. + + Raises: + PyTritonClientTimeoutError: If the server and model are not ready before the given timeout. + PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable. + KeyboardInterrupt: If the hosting process receives SIGINT. + PyTritonClientClosedError: If the ModelClient is closed. + """ + if self._general_client is None: + raise PyTritonClientClosedError("ModelClient is closed") + wait_for_model_ready(self._general_client, self._model_name, self._model_version, timeout_s=timeout_s) + + @property + def is_batching_supported(self): + """Checks if model supports batching. + + Also waits for server to get into readiness state. + """ + return self.model_config.max_batch_size > 0 + + def wait_for_server(self, timeout_s: float): + """Wait for Triton Inference Server readiness. + + Args: + timeout_s: timeout to server get into readiness state. + + Raises: + PyTritonClientTimeoutError: If server is not in readiness state before given timeout. 
+ KeyboardInterrupt: If hosting process receives SIGINT + """ + wait_for_server_ready(self._general_client, timeout_s=timeout_s) + + @property + def model_config(self) -> TritonModelConfig: + """Obtain the configuration of the model deployed on the Triton Inference Server. + + This method waits for the server to get into readiness state before obtaining the model configuration. + + Returns: + TritonModelConfig: configuration of the model deployed on the Triton Inference Server. + + Raises: + PyTritonClientTimeoutError: If the server and model are not in readiness state before the given timeout. + PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable. + KeyboardInterrupt: If the hosting process receives SIGINT. + PyTritonClientClosedError: If the ModelClient is closed. + """ + if not self._model_config: + if self._general_client is None: + raise PyTritonClientClosedError("ModelClient is closed") + + self._model_config = get_model_config( + self._general_client, self._model_name, self._model_version, timeout_s=self._init_timeout_s + ) + return self._model_config + + def infer_sample( + self, + *inputs, + parameters: Optional[Dict[str, Union[str, int, bool]]] = None, + headers: Optional[Dict[str, Union[str, int, bool]]] = None, + **named_inputs, + ) -> Dict[str, np.ndarray]: + """Run synchronous inference on a single data sample. + + Typical usage: + + ```python + client = ModelClient("localhost", "MyModel") + result_dict = client.infer_sample(input1, input2) + client.close() + ``` + + Inference inputs can be provided either as positional or keyword arguments: + + ```python + result_dict = client.infer_sample(input1, input2) + result_dict = client.infer_sample(a=input1, b=input2) + ``` + + Args: + *inputs: Inference inputs provided as positional arguments. + parameters: Custom inference parameters. + headers: Custom inference headers. + **named_inputs: Inference inputs provided as named arguments. + + Returns: + Dictionary with inference results, where dictionary keys are output names. + + Raises: + PyTritonClientValueError: If mixing of positional and named arguments passing detected. + PyTritonClientTimeoutError: If the wait time for the server and model being ready exceeds `init_timeout_s` or + inference request time exceeds `inference_timeout_s`. + PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable. + PyTritonClientInferenceServerError: If an error occurred on the inference callable or Triton Inference Server side. + """ + _verify_inputs_args(inputs, named_inputs) + _verify_parameters(parameters) + _verify_parameters(headers) + + if self.is_batching_supported: + if inputs: + inputs = tuple(data[np.newaxis, ...] for data in inputs) + elif named_inputs: + named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()} + + result = self._infer(inputs or named_inputs, parameters, headers) + + return self._debatch_result(result) + + def infer_batch( + self, + *inputs, + parameters: Optional[Dict[str, Union[str, int, bool]]] = None, + headers: Optional[Dict[str, Union[str, int, bool]]] = None, + **named_inputs, + ) -> Dict[str, np.ndarray]: + """Run synchronous inference on batched data. 
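+
+        Unlike `infer_sample`, the passed arrays are expected to already include the batch
+        (first) dimension, e.g. a hypothetical `np.zeros((8, 3), dtype=np.float32)` would be
+        a batch of eight 3-element samples.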
+ + Typical usage: + + ```python + client = ModelClient("localhost", "MyModel") + result_dict = client.infer_batch(input1, input2) + client.close() + ``` + + Inference inputs can be provided either as positional or keyword arguments: + + ```python + result_dict = client.infer_batch(input1, input2) + result_dict = client.infer_batch(a=input1, b=input2) + ``` + + Args: + *inputs: Inference inputs provided as positional arguments. + parameters: Custom inference parameters. + headers: Custom inference headers. + **named_inputs: Inference inputs provided as named arguments. + + Returns: + Dictionary with inference results, where dictionary keys are output names. + + Raises: + PyTritonClientValueError: If mixing of positional and named arguments passing detected. + PyTritonClientTimeoutError: If the wait time for the server and model being ready exceeds `init_timeout_s` or + inference request time exceeds `inference_timeout_s`. + PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable. + PyTritonClientInferenceServerError: If an error occurred on the inference callable or Triton Inference Server side. + PyTritonClientModelDoesntSupportBatchingError: If the model doesn't support batching. + PyTritonClientValueError: if mixing of positional and named arguments passing detected. + PyTritonClientTimeoutError: + in case of first method call, `lazy_init` argument is False + and wait time for server and model being ready exceeds `init_timeout_s` or + inference time exceeds `inference_timeout_s` passed to `__init__`. + PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. + PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side, + """ + _verify_inputs_args(inputs, named_inputs) + _verify_parameters(parameters) + _verify_parameters(headers) + + if not self.is_batching_supported: + raise PyTritonClientModelDoesntSupportBatchingError( + f"Model {self.model_config.model_name} doesn't support batching - use infer_sample method instead" + ) + + return self._infer(inputs or named_inputs, parameters, headers) + + def infer(self, inputs): + """Run synchronous batch inference using a single dictionary of inputs.""" + return self.infer_batch(**inputs) + + + def _wait_and_init_model_config(self, init_timeout_s: float): + if self._general_client is None: + raise PyTritonClientClosedError("ModelClient is closed") + + should_finish_before_s = time.time() + init_timeout_s + self.wait_for_model(init_timeout_s) + self._model_ready = True + timeout_s = max(0.0, should_finish_before_s - time.time()) + self._model_config = get_model_config( + self._general_client, self._model_name, self._model_version, timeout_s=timeout_s + ) + + def _create_request(self, inputs: _IOType): + if self._infer_client is None: + raise PyTritonClientClosedError("ModelClient is closed") + + if not self._model_ready: + self._wait_and_init_model_config(self._init_timeout_s) + + if isinstance(inputs, Tuple): + inputs = {input_spec.name: input_data for input_spec, input_data in zip(self.model_config.inputs, inputs)} + + inputs_wrapped = [] + + # to help pytype to obtain variable type + inputs: Dict[str, np.ndarray] + + for input_name, input_data in inputs.items(): + if input_data.dtype == object and not isinstance(input_data.reshape(-1)[0], bytes): + raise RuntimeError( + f"Numpy array for {input_name!r} input with dtype=object should contain encoded strings \ + \\(e.g. into utf-8\\). 
Element type: {type(input_data.reshape(-1)[0])}" + ) + if input_data.dtype.type == np.str_: + raise RuntimeError( + "Unicode inputs are not supported. " + f"Encode numpy array for {input_name!r} input (ex. with np.char.encode(array, 'utf-8'))." + ) + triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype) + infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype) + infer_input.set_data_from_numpy(input_data) + inputs_wrapped.append(infer_input) + + outputs_wrapped = [ + self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in self.model_config.outputs + ] + return inputs_wrapped, outputs_wrapped + + def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]: + import tritonclient.http + import tritonclient.utils + if self.model_config.decoupled: + raise PyTritonClientInferenceServerError("Model config is decoupled. Use DecoupledModelClient instead.") + + inputs_wrapped, outputs_wrapped = self._create_request(inputs) + + try: + _LOGGER.debug("Sending inference request to Triton Inference Server") + response = self._infer_client.infer( + model_name=self._model_name, + model_version=self._model_version or "", + inputs=inputs_wrapped, + headers=headers, + outputs=outputs_wrapped, + request_id=self._next_request_id, + parameters=parameters, + **self._get_infer_extra_args(), + ) + except tritonclient.utils.InferenceServerException as e: + # tritonclient.grpc raises execption with message containing "Deadline Exceeded" for timeout + if "Deadline Exceeded" in e.message(): + raise PyTritonClientTimeoutError( + f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}" + ) from e + + raise PyTritonClientInferenceServerError( + f"Error occurred during inference request. Message: {e.message()}" + ) from e + except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout + message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" + _LOGGER.error(message) + raise PyTritonClientTimeoutError(message) from e + except OSError as e: # tritonclient.http raises socket.error for connection error + message = f"Timeout occurred during inference request. 
Timeout: {self._inference_timeout_s} s Message: {e}" + _LOGGER.error(message) + raise PyTritonClientTimeoutError(message) from e + + if isinstance(response, tritonclient.http.InferResult): + outputs = { + output["name"]: response.as_numpy(output["name"]) for output in response.get_response()["outputs"] + } + else: + outputs = {output.name: response.as_numpy(output.name) for output in response.get_response().outputs} + + return outputs + + def _get_numpy_result(self, result): + import tritonclient.grpc + if isinstance(result, tritonclient.grpc.InferResult): + result = {output.name: result.as_numpy(output.name) for output in result.get_response().outputs} + else: + result = {output["name"]: result.as_numpy(output["name"]) for output in result.get_response()["outputs"]} + return result + + def _debatch_result(self, result): + if self.is_batching_supported: + result = {name: data[0] for name, data in result.items()} + return result + + def _handle_lazy_init(self): + if not self._lazy_init: + self._wait_and_init_model_config(self._init_timeout_s) + + def _get_infer_extra_args(self): + if self._triton_url.scheme == "http": + return {} + # For the GRPC protocol, the timeout is passed to the infer method as client_timeout + # This timeout applies to the whole inference process and each network request + + # The ``infer`` supports also timeout argument for both GRPC and HTTP. + # It is applied at server side and supported only for dynamic batching. + # However, it is not used here yet and planned for future release + kwargs = {"client_timeout": self._inference_timeout_s} + return kwargs + + @_run_once_per_lib + def _monkey_patch_client(self): + """Monkey patch InferenceServerClient to catch error in __del__.""" + _LOGGER.info(f"Patch ModelClient {self._triton_url.scheme}") + if not hasattr(self._triton_client_lib.InferenceServerClient, "__del__"): + return + + old_del = self._triton_client_lib.InferenceServerClient.__del__ + + def _monkey_patched_del(self): + """Monkey patched del.""" + try: + old_del(self) + except gevent.exceptions.InvalidThreadUseError: + _LOGGER.info("gevent.exceptions.InvalidThreadUseError in __del__ of InferenceServerClient") + except Exception as e: + _LOGGER.error("Exception in __del__ of InferenceServerClient: %s", e) + + self._triton_client_lib.InferenceServerClient.__del__ = _monkey_patched_del + + +class InferenceServerClientBase: + def __init__(self): + self._plugin = None + + def _call_plugin(self, request): + """Called by the subclasses before sending a request to the + network. + """ + if self._plugin != None: + self._plugin(request) + + def register_plugin(self, plugin): + """Register a Client Plugin. + + Parameters + ---------- + plugin : InferenceServerClientPlugin + A client plugin + + Raises + ------ + InferenceServerException + If a plugin is already registered. + """ + + if self._plugin is None: + self._plugin = plugin + else: + raise_error( + "A plugin is already registered. Please " + "unregister the previous plugin first before" + " registering a new plugin." + ) + + def plugin(self): + """Retrieve the registered plugin if any. + + Returns + ------ + InferenceServerClientPlugin or None + """ + return self._plugin + + def unregister_plugin(self): + """Unregister a plugin. + + Raises + ------ + InferenceServerException + If no plugin has been registered. 
+ """ + if self._plugin is None: + raise_error("No plugin has been registered.") + + self._plugin = None + + +# # Change url to 'http://localhost:8000' for utilizing HTTP client +# client = Client(url='grpc://loacalhost:8001') +# +# input_tensor_as_numpy = np.array(...) +# +# # Infer should be async similar to the exising Python APIs +# responses = client.model('simple').infer(inputs={'input': input_tensor_as_numpy}) +# +# for response in responses: +# numpy_array = np.asarray(response.outputs['output']) +# +# client.close() + + +class Client(InferenceServerClientBase): + def __init__(self, url: str) -> None: + self._client_url = url + super().__init__() + + def model(self, name: str) -> ModelClient: + return ModelClient(url=self._client_url, model_name=name) + diff --git a/src/python/library/tritonclient/hl/common.py b/src/python/library/tritonclient/hl/common.py new file mode 100644 index 000000000..1d58024be --- /dev/null +++ b/src/python/library/tritonclient/hl/common.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common structures for internal and external ModelConfig.""" + +import dataclasses +import enum +from typing import Dict, Optional + + +class DeviceKind(enum.Enum): + """Device kind for model deployment. + + Args: + KIND_AUTO: Automatically select the device for model deployment. + KIND_CPU: Model is deployed on CPU. + KIND_GPU: Model is deployed on GPU. + """ + + KIND_AUTO = "KIND_AUTO" + KIND_CPU = "KIND_CPU" + KIND_GPU = "KIND_GPU" + + +class TimeoutAction(enum.Enum): + """Timeout action definition for timeout_action QueuePolicy field. + + Args: + REJECT: Reject the request and return error message accordingly. + DELAY: Delay the request until all other requests at the same (or higher) priority levels + that have not reached their timeouts are processed. + """ + + REJECT = "REJECT" + DELAY = "DELAY" + + +@dataclasses.dataclass +class QueuePolicy: + """Model queue policy configuration. + + More in Triton Inference Server [documentation] + [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1037 + + Args: + timeout_action: The action applied to timed-out request. + default_timeout_microseconds: The default timeout for every request, in microseconds. + allow_timeout_override: Whether individual request can override the default timeout value. + max_queue_size: The maximum queue size for holding requests. + """ + + timeout_action: TimeoutAction = TimeoutAction.REJECT + default_timeout_microseconds: int = 0 + allow_timeout_override: bool = False + max_queue_size: int = 0 + + +@dataclasses.dataclass +class DynamicBatcher: + """Dynamic batcher configuration. 
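+
+    A minimal construction sketch (the field values below are illustrative only):
+
+    ```python
+    batcher = DynamicBatcher(
+        max_queue_delay_microseconds=100,
+        preferred_batch_size=[2, 4, 8],
+    )
+    ```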
+ + More in Triton Inference Server [documentation] + [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1104 + + Args: + max_queue_delay_microseconds: The maximum time, in microseconds, a request will be delayed in + the scheduling queue to wait for additional requests for batching. + preferred_batch_size: Preferred batch sizes for dynamic batching. + preserve_ordering : Should the dynamic batcher preserve the ordering of responses to + match the order of requests received by the scheduler. + priority_levels: The number of priority levels to be enabled for the model. + default_priority_level: The priority level used for requests that don't specify their priority. + default_queue_policy: The default queue policy used for requests. + priority_queue_policy: Specify the queue policy for the priority level. + """ + + max_queue_delay_microseconds: int = 0 + preferred_batch_size: Optional[list] = None + preserve_ordering: bool = False + priority_levels: int = 0 + default_priority_level: int = 0 + default_queue_policy: Optional[QueuePolicy] = None + priority_queue_policy: Optional[Dict[int, QueuePolicy]] = None diff --git a/src/python/library/tritonclient/hl/constants.py b/src/python/library/tritonclient/hl/constants.py new file mode 100644 index 000000000..49f8723c8 --- /dev/null +++ b/src/python/library/tritonclient/hl/constants.py @@ -0,0 +1,31 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# noqa: D104 +"""Constants for pytriton.""" + +import os +import pathlib + +DEFAULT_HTTP_PORT = 8000 +DEFAULT_GRPC_PORT = 8001 +DEFAULT_METRICS_PORT = 8002 +TRITON_LOCAL_IP = "127.0.0.1" +TRITON_CONTEXT_FIELD_NAME = "triton_context" +TRITON_PYTHON_BACKEND_INTERPRETER_DIRNAME = "python_backend_interpreter" +DEFAULT_TRITON_STARTUP_TIMEOUT_S = 120 +CREATE_TRITON_CLIENT_TIMEOUT_S = 10 + +__DEFAULT_PYTRITON_HOME = os.path.join(os.getenv("XDG_CACHE_HOME", "$HOME/.cache"), "pytriton") +__PYTRITON_HOME = os.path.expanduser(os.path.expandvars(os.getenv("PYTRITON_HOME", __DEFAULT_PYTRITON_HOME))) +PYTRITON_HOME = pathlib.Path(__PYTRITON_HOME).resolve().absolute() diff --git a/src/python/library/tritonclient/hl/exceptions.py b/src/python/library/tritonclient/hl/exceptions.py new file mode 100644 index 000000000..6619b4a31 --- /dev/null +++ b/src/python/library/tritonclient/hl/exceptions.py @@ -0,0 +1,92 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Exceptions thrown in pytriton.client module.""" + + +class PyTritonClientError(Exception): + """Generic pytriton client exception.""" + + def __init__(self, message: str): + """Initialize exception with message. + + Args: + message: Error message + """ + self._message = message + + def __str__(self) -> str: + """String representation of error. + + Returns: + Message content + """ + return self._message + + @property + def message(self): + """Get the exception message. + + Returns: + The message associated with this exception, or None if no message. + + """ + return self._message + + +class PyTritonClientValueError(PyTritonClientError): + """Generic error raised in case of incorrect values are provided into API.""" + + pass + + +class PyTritonClientInvalidUrlError(PyTritonClientValueError): + """Error raised when provided Triton Inference Server url is invalid.""" + + pass + + +class PyTritonClientTimeoutError(PyTritonClientError): + """Timeout occurred during communication with the Triton Inference Server.""" + + pass + + +class PyTritonClientModelUnavailableError(PyTritonClientError): + """Model with given name and version is unavailable on the given Triton Inference Server.""" + + pass + + +class PyTritonClientClosedError(PyTritonClientError): + """Error raised in case of trying to use closed client.""" + + pass + + +class PyTritonClientModelDoesntSupportBatchingError(PyTritonClientError): + """Error raised in case of trying to infer batch on model not supporting batching.""" + + pass + + +class PyTritonClientInferenceServerError(PyTritonClientError): + """Error raised in case of error on inference callable or Triton Inference Server side.""" + + pass + + +class PyTritonClientQueueFullError(PyTritonClientError): + """Error raised in case of trying to push request to full queue.""" + + pass diff --git a/src/python/library/tritonclient/hl/lw/__init__.py b/src/python/library/tritonclient/hl/lw/__init__.py new file mode 100755 index 000000000..41e3b269d --- /dev/null +++ b/src/python/library/tritonclient/hl/lw/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/library/tritonclient/hl/lw/grpc.py b/src/python/library/tritonclient/hl/lw/grpc.py new file mode 100755 index 000000000..559d58f14 --- /dev/null +++ b/src/python/library/tritonclient/hl/lw/grpc.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""New Tritonclient GRPC not implemented yet.""" + +class InferenceServerClient: + """Client to perform grpc communication with the Triton. + """ + pass \ No newline at end of file diff --git a/src/python/library/tritonclient/hl/lw/http.py b/src/python/library/tritonclient/hl/lw/http.py new file mode 100755 index 000000000..92905d6dd --- /dev/null +++ b/src/python/library/tritonclient/hl/lw/http.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import requests +import numpy as np + +from tritonclient.hl.lw.infer_input import InferInput + +class InferOutput(): + """Output from the inference server.""" + def __init__(self, name): + """Initialize the output.""" + self.name = name + + +class InferResponse(): + """Response from the inference server.""" + + def __init__(self, outputs): + """Initialize the response.""" + self._rest_outputs = outputs + self.outputs = [InferOutput(response['name']) for response in outputs['outputs']] + + def get_response(self): + """Get the response.""" + return self + + def as_numpy(self, name): + """Get the response as numpy.""" + for response in self._rest_outputs['outputs']: + if response['name'] == name: + return np.array(response['data']) + return None + +class InferenceServerClient(): + """Client to perform http communication with the Triton. + """ + + def __init__(self, url, **kwargs): + self.url = url + + def is_server_ready(self): + """Check if the server is ready. + """ + response = requests.get("http://" + self.url + "/v2/health/live") + return response.status_code == 200 + + def is_server_live(self): + """Check if the server is ready. + """ + response = requests.get("http://" + self.url + "/v2/health/ready") + return response.status_code == 200 + + def is_model_ready(self, model_name, model_version): + """Check if the model is ready. + """ + model_version = model_version if model_version else "1" + request = "http://" + self.url + "/v2/models/{}/versions/{}/ready".format(model_name, model_version) + response = requests.get(request) + + return response.status_code == 200 + + def get_model_config(self, model_name, model_version): + """Get the model configuration. + """ + model_version = model_version if model_version else "1" + request = "http://" + self.url + "/v2/models/{}/versions/{}/config".format(model_name, model_version) + response = requests.get(request) + + return response.json() + + # In [13]: import requests + # ...: import json + # ...: + # ...: # Define the server URL + # ...: server_url = "http://localhost:8000/v2/models/identity/versions/1/infer" + # ...: + # ...: # Prepare the input data + # ...: input_string = "Hello Triton Inference Server!" 
+ # ...: + # ...: # Triton requires the data to be in a specific format + # ...: inputs = [ + # ...: { + # ...: "name": "input", + # ...: "shape": [1, 1], # Adjust the shape to include the batch dimension + # ...: "datatype": "BYTES", + # ...: "data": [input_string] + # ...: } + # ...: ] + # ...: + # ...: # Prepare the request payload + # ...: payload = { + # ...: "inputs": inputs + # ...: } + # ...: + # ...: # Send the request + # ...: response = requests.post(server_url, json=payload) + # ...: + # ...: # Check the response status + # ...: if response.status_code == 200: + # ...: result = response.json() + # ...: print("Inference result:", result) + # ...: else: + # ...: print("Failed to get inference result:", response.status_code, response.text) + # ...: + # Inference result: {'model_name': 'identity', 'model_version': '1', 'outputs': [{'name': 'output', 'datatype': 'BYTES', 'shape': [1, 1], 'data': ['Hello Triton Inference Server!']}]} + # In [14]: + + def infer(self, model_name, model_version, inputs, headers, outputs, request_id, parameters): + """Perform inference. + """ + model_version = model_version if model_version else "1" + request = "http://" + self.url + "/v2/models/{}/versions/{}/infer".format(model_name, model_version) + print(request) + + inputs_for_json = [input_value.to_dict() for input_value in inputs] + + print(inputs_for_json) + + ## TODO: Support setting outputs, request_id, parameters and headers + response = requests.post(request, json={"inputs": inputs_for_json}) + + return InferResponse(response.json()) + + +class InferRequestedOutput(): + def __init__(self, name): + self.name = name diff --git a/src/python/library/tritonclient/hl/lw/infer_input.py b/src/python/library/tritonclient/hl/lw/infer_input.py new file mode 100755 index 000000000..076458bb3 --- /dev/null +++ b/src/python/library/tritonclient/hl/lw/infer_input.py @@ -0,0 +1,355 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import requests +import numpy as np + +def raise_error(message): + """Raise an InferenceServerException with the specified message.""" + raise Exception(message) + +def triton_to_np_dtype(dtype): + """Converts a Triton dtype to a numpy dtype.""" + if dtype == "BOOL": + return bool + elif dtype == "INT8": + return np.int8 + elif dtype == "INT16": + return np.int16 + elif dtype == "INT32": + return np.int32 + elif dtype == "INT64": + return np.int64 + elif dtype == "UINT8": + return np.uint8 + elif dtype == "UINT16": + return np.uint16 + elif dtype == "UINT32": + return np.uint32 + elif dtype == "UINT64": + return np.uint64 + elif dtype == "FP16": + return np.float16 + elif dtype == "FP32" or dtype == "BF16": + return np.float32 + elif dtype == "FP64": + return np.float64 + elif dtype == "BYTES": + return np.object_ + return None + +def serialize_byte_tensor(input_tensor): + """ + Serializes a bytes tensor into a flat numpy array of length prepended + bytes. The numpy array should use dtype of np.object. 
For np.bytes, + numpy will remove trailing zeros at the end of byte sequence and because + of this it should be avoided. + + Parameters + ---------- + input_tensor : np.array + The bytes tensor to serialize. + + Returns + ------- + serialized_bytes_tensor : np.array + The 1-D numpy array of type uint8 containing the serialized bytes in row-major form. + + Raises + ------ + InferenceServerException + If unable to serialize the given tensor. + """ + + if input_tensor.size == 0: + return np.empty([0], dtype=np.object_) + + # If the input is a tensor of string/bytes objects, then must flatten those into + # a 1-dimensional array containing the 4-byte byte size followed by the + # actual element bytes. All elements are concatenated together in row-major + # order. + + if (input_tensor.dtype != np.object_) and (input_tensor.dtype.type != np.bytes_): + raise_error("cannot serialize bytes tensor: invalid datatype") + + flattened_ls = [] + # 'C' order is row-major. + for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"): + # If directly passing bytes to BYTES type, + # don't convert it to str as Python will encode the + # bytes which may distort the meaning + if input_tensor.dtype == np.object_: + if type(obj.item()) == bytes: + s = obj.item() + else: + s = str(obj.item()).encode("utf-8") + else: + s = obj.item() + flattened_ls.append(struct.pack(" 0: + for obj in np.nditer( + input_tensor, flags=["refs_ok"], order="C" + ): + # We need to convert the object to string using utf-8, + # if we want to use the binary_data=False. JSON requires + # the input to be a UTF-8 string. + if input_tensor.dtype == np.object_: + if type(obj.item()) == bytes: + self._data.append(str(obj.item(), encoding="utf-8")) + else: + self._data.append(str(obj.item())) + else: + self._data.append(str(obj.item(), encoding="utf-8")) + except UnicodeDecodeError: + raise_error( + f'Failed to encode "{obj.item()}" using UTF-8. Please use binary_data=True, if' + " you want to pass a byte array." + ) + else: + self._data = [val.item() for val in input_tensor.flatten()] + else: + self._data = None + if self._datatype == "BYTES": + serialized_output = serialize_byte_tensor(input_tensor) + if serialized_output.size > 0: + self._raw_data = serialized_output.item() + else: + self._raw_data = b"" + elif self._datatype == "BF16": + serialized_output = serialize_bf16_tensor(input_tensor) + if serialized_output.size > 0: + self._raw_data = serialized_output.item() + else: + self._raw_data = b"" + else: + self._raw_data = input_tensor.tobytes() + self._parameters["binary_data_size"] = len(self._raw_data) + return self + + + def _get_binary_data(self): + """Returns the raw binary data if available + + Returns + ------- + bytes + The raw data for the input tensor + """ + return self._raw_data + + def _get_tensor(self): + """Retrieve the underlying input as json dict. 
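+
+        For example, an input named "input" with shape [1, 1] and datatype "BYTES" holding
+        non-binary string data would (illustratively) be rendered as
+        ``{"name": "input", "shape": [1, 1], "datatype": "BYTES", "data": ["a"]}``;
+        the "data" field is omitted when raw binary data is attached instead.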
+ + Returns + ------- + dict + The underlying tensor specification as dict + """ + tensor = {"name": self._name, "shape": self._shape, "datatype": self._datatype} + if self._parameters: + tensor["parameters"] = self._parameters + + if ( + self._parameters.get("shared_memory_region") is None + and self._raw_data is None + ): + if self._data is not None: + tensor["data"] = self._data + return tensor + + def to_dict(self): + return { + "name": self.name(), + "shape": self.shape(), + "datatype": self.datatype(), + "data": self._get_tensor()["data"], + } + +class InferRequestedOutput(): + def __init__(self, name): + self.name = name diff --git a/src/python/library/tritonclient/hl/lw/utils.py b/src/python/library/tritonclient/hl/lw/utils.py new file mode 100755 index 000000000..0a15b1911 --- /dev/null +++ b/src/python/library/tritonclient/hl/lw/utils.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class InferenceServerException(Exception): + pass \ No newline at end of file diff --git a/src/python/library/tritonclient/hl/model_config/__init__.py b/src/python/library/tritonclient/hl/model_config/__init__.py new file mode 100644 index 000000000..9698bf59e --- /dev/null +++ b/src/python/library/tritonclient/hl/model_config/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# noqa: D104 +from .common import DeviceKind, DynamicBatcher, QueuePolicy, TimeoutAction # noqa: F401 +from .model_config import ModelConfig # noqa: F401 +from .tensor import Tensor # noqa: F401 diff --git a/src/python/library/tritonclient/hl/model_config/common.py b/src/python/library/tritonclient/hl/model_config/common.py new file mode 100644 index 000000000..1d58024be --- /dev/null +++ b/src/python/library/tritonclient/hl/model_config/common.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Common structures for internal and external ModelConfig.""" + +import dataclasses +import enum +from typing import Dict, Optional + + +class DeviceKind(enum.Enum): + """Device kind for model deployment. + + Args: + KIND_AUTO: Automatically select the device for model deployment. + KIND_CPU: Model is deployed on CPU. + KIND_GPU: Model is deployed on GPU. + """ + + KIND_AUTO = "KIND_AUTO" + KIND_CPU = "KIND_CPU" + KIND_GPU = "KIND_GPU" + + +class TimeoutAction(enum.Enum): + """Timeout action definition for timeout_action QueuePolicy field. + + Args: + REJECT: Reject the request and return error message accordingly. + DELAY: Delay the request until all other requests at the same (or higher) priority levels + that have not reached their timeouts are processed. + """ + + REJECT = "REJECT" + DELAY = "DELAY" + + +@dataclasses.dataclass +class QueuePolicy: + """Model queue policy configuration. + + More in Triton Inference Server [documentation] + [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1037 + + Args: + timeout_action: The action applied to timed-out request. + default_timeout_microseconds: The default timeout for every request, in microseconds. + allow_timeout_override: Whether individual request can override the default timeout value. + max_queue_size: The maximum queue size for holding requests. + """ + + timeout_action: TimeoutAction = TimeoutAction.REJECT + default_timeout_microseconds: int = 0 + allow_timeout_override: bool = False + max_queue_size: int = 0 + + +@dataclasses.dataclass +class DynamicBatcher: + """Dynamic batcher configuration. + + More in Triton Inference Server [documentation] + [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1104 + + Args: + max_queue_delay_microseconds: The maximum time, in microseconds, a request will be delayed in + the scheduling queue to wait for additional requests for batching. + preferred_batch_size: Preferred batch sizes for dynamic batching. + preserve_ordering : Should the dynamic batcher preserve the ordering of responses to + match the order of requests received by the scheduler. + priority_levels: The number of priority levels to be enabled for the model. + default_priority_level: The priority level used for requests that don't specify their priority. + default_queue_policy: The default queue policy used for requests. + priority_queue_policy: Specify the queue policy for the priority level. + """ + + max_queue_delay_microseconds: int = 0 + preferred_batch_size: Optional[list] = None + preserve_ordering: bool = False + priority_levels: int = 0 + default_priority_level: int = 0 + default_queue_policy: Optional[QueuePolicy] = None + priority_queue_policy: Optional[Dict[int, QueuePolicy]] = None diff --git a/src/python/library/tritonclient/hl/model_config/model_config.py b/src/python/library/tritonclient/hl/model_config/model_config.py new file mode 100644 index 000000000..9d446478e --- /dev/null +++ b/src/python/library/tritonclient/hl/model_config/model_config.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model configurations. + +Dataclasses with specialized deployment paths for models on Triton. The purpose of this module is to provide clear options +to configure models of given types. + +The dataclasses are exposed in the user API. +""" + +import dataclasses + +#from tritonclient.hl.model_config import DynamicBatcher + + +@dataclasses.dataclass +class ModelConfig: + """Additional model configuration for running model through Triton Inference Server. + + Args: + batching: Flag to enable/disable batching for model. + max_batch_size: The maximal batch size that would be handled by model. + batcher: Configuration of Dynamic Batching for the model. + response_cache: Flag to enable/disable response cache for the model + decoupled: Flag to enable/disable decoupled from requests execution + """ + + batching: bool = True + max_batch_size: int = 4 + #batcher: DynamicBatcher = dataclasses.field(default_factory=DynamicBatcher) + response_cache: bool = False + decoupled: bool = False diff --git a/src/python/library/tritonclient/hl/model_config/parser.py b/src/python/library/tritonclient/hl/model_config/parser.py new file mode 100644 index 000000000..83d74cb9f --- /dev/null +++ b/src/python/library/tritonclient/hl/model_config/parser.py @@ -0,0 +1,258 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ModelConfigParser class definition. + +Provide functionality to parse the Triton model configuration stored in file or form of dictionary into the object of +class ModelConfig. 
+ + Examples of use: + + # Parse from dict + model_config = ModelConfigParser.from_dict(model_config_dict) + + # Parse from file + model_config = ModelConfigParser.from_file("/path/to/config.pbtxt") + +""" + +import json +import logging +import pathlib +from typing import Dict + +import numpy as np +from google.protobuf import json_format, text_format # pytype: disable=pyi-error + +#from pytriton.exceptions import PyTritonModelConfigError + +from .common import QueuePolicy, TimeoutAction +from .triton_model_config import DeviceKind, DynamicBatcher, ResponseCache, TensorSpec, TritonModelConfig + +try: + import tritonclient.grpc as grpc_client + from tritonclient import utils as client_utils # noqa: F401 +except ImportError: + try: + import tritonclientutils as client_utils # noqa: F401 + import tritongrpcclient as grpc_client + except ImportError: + client_utils = None + grpc_client = None + +LOGGER = logging.getLogger(__name__) + + +class ModelConfigParser: + """Provide functionality to parse dictionary or file to ModelConfig object.""" + + @classmethod + def from_dict(cls, model_config_dict: Dict) -> TritonModelConfig: + """Create ModelConfig from configuration stored in dictionary. + + Args: + model_config_dict: Dictionary with model config + + Returns: + A ModelConfig object with data parsed from the dictionary + """ + LOGGER.debug(f"Parsing Triton config model from dict: \n{json.dumps(model_config_dict, indent=4)}") + + if model_config_dict.get("max_batch_size", 0) > 0: + batching = True + else: + batching = False + + dynamic_batcher_config = model_config_dict.get("dynamic_batching") + if dynamic_batcher_config is not None: + batcher = cls._parse_dynamic_batching(dynamic_batcher_config) + else: + batcher = None + + instance_group = { + DeviceKind(entry["kind"]): entry.get("count") for entry in model_config_dict.get("instance_group", []) + } + + decoupled = model_config_dict.get("model_transaction_policy", {}).get("decoupled", False) + + backend_parameters_config = model_config_dict.get("parameters", []) + if isinstance(backend_parameters_config, list): + # If the backend_parameters_config is a list of strings, use them as keys with empty values + LOGGER.debug(f"backend_parameters_config is a list of strings: {backend_parameters_config}") + backend_parameters = {name: "" for name in backend_parameters_config} + elif isinstance(backend_parameters_config, dict): + # If the backend_parameters_config is a dictionary, use the key and "string_value" fields as key-value pairs + LOGGER.debug(f"backend_parameters_config is a dictionary: {backend_parameters_config}") + backend_parameters = { + name: backend_parameters_config[name]["string_value"] for name in backend_parameters_config + } + else: + # Otherwise, raise an error + LOGGER.error( + f"Invalid type {type(backend_parameters_config)} for backend_parameters_config: {backend_parameters_config}" + ) + raise TypeError(f"Invalid type for backend_parameters_config: {type(backend_parameters_config)}") + + inputs = [ + cls.rewrite_io_spec(item, "input", idx) for idx, item in enumerate(model_config_dict.get("input", [])) + ] or None + outputs = [ + cls.rewrite_io_spec(item, "output", idx) for idx, item in enumerate(model_config_dict.get("output", [])) + ] or None + + response_cache_config = model_config_dict.get("response_cache") + if response_cache_config: + response_cache = cls._parse_response_cache(response_cache_config) + else: + response_cache = None + + return TritonModelConfig( + model_name=model_config_dict["name"], + batching=batching, + 
max_batch_size=model_config_dict.get("max_batch_size", 0), + batcher=batcher, + inputs=inputs, + outputs=outputs, + instance_group=instance_group, + decoupled=decoupled, + backend_parameters=backend_parameters, + response_cache=response_cache, + ) + + @classmethod + def from_file(cls, *, config_path: pathlib.Path) -> TritonModelConfig: + """Create ModelConfig from configuration stored in file. + + Args: + config_path: location of file with model config + + Returns: + A ModelConfig object with data parsed from the file + """ + from tritonclient.grpc import model_config_pb2 # pytype: disable=import-error + + LOGGER.debug(f"Parsing Triton config model config_path={config_path}") + + with config_path.open("r") as config_file: + payload = config_file.read() + model_config_proto = text_format.Parse(payload, model_config_pb2.ModelConfig()) + + model_config_dict = json_format.MessageToDict(model_config_proto, preserving_proto_field_name=True) + return ModelConfigParser.from_dict(model_config_dict=model_config_dict) + + @classmethod + def rewrite_io_spec(cls, item: Dict, io_type: str, idx: int) -> TensorSpec: + """Rewrite the IO Spec provided in form of dictionary to TensorSpec. + + Args: + item: IO data for input + io_type: Type of the IO (input or output) + idx: Index of IO + + Returns: + TensorSpec with input or output data + """ + name = item.get("name") + if not name: + raise PyTritonModelConfigError(f"Name for {io_type} at index {idx} not provided.") + + data_type = item.get("data_type") + if not data_type: + raise PyTritonModelConfigError(f"Data type for {io_type} with name `{name}` not defined.") + + data_type_val = data_type.split("_") + if len(data_type_val) != 2: + raise PyTritonModelConfigError( + f"Invalid data type `{data_type}` for {io_type} with name `{name}` not defined. " + "The expected name is TYPE_{type}." + ) + + data_type = data_type_val[1] + if data_type == "STRING": + dtype = np.bytes_ + else: + dtype = client_utils.triton_to_np_dtype(data_type) + if dtype is None: + raise PyTritonModelConfigError(f"Unsupported data type `{data_type}` for {io_type} with name `{name}`") + + dtype = np.dtype("bool") if dtype is bool else dtype + + dims = item.get("dims", []) + if not dims: + raise PyTritonModelConfigError(f"Dimension for {io_type} with name `{name}` not defined.") + + shape = tuple(int(s) for s in dims) + + optional = item.get("optional", False) + return TensorSpec(name=item["name"], shape=shape, dtype=dtype, optional=optional) + + @classmethod + def _parse_dynamic_batching(cls, dynamic_batching_config: Dict) -> DynamicBatcher: + """Parse config to create DynamicBatcher object. 
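+
+        Illustrative `dynamic_batching` fragment accepted by this parser (values are made up
+        and only a subset of the recognized keys is shown):
+
+            {
+                "preferred_batch_size": [2, 4],
+                "max_queue_delay_microseconds": 100,
+                "default_queue_policy": {"timeout_action": "REJECT", "max_queue_size": 8},
+            }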
+ + Args: + dynamic_batching_config: Configuration of dynamic batcher from config + + Returns: + DynamicBatcher object with configuration + """ + default_queue_policy = None + default_queue_policy_config = dynamic_batching_config.get("default_queue_policy") + if default_queue_policy_config: + default_queue_policy = QueuePolicy( + timeout_action=TimeoutAction( + default_queue_policy_config.get("timeout_action", TimeoutAction.REJECT.value) + ), + default_timeout_microseconds=int(default_queue_policy_config.get("default_timeout_microseconds", 0)), + allow_timeout_override=bool(default_queue_policy_config.get("allow_timeout_override", False)), + max_queue_size=int(default_queue_policy_config.get("max_queue_size", 0)), + ) + + priority_queue_policy = None + priority_queue_policy_config = dynamic_batching_config.get("priority_queue_policy") + if priority_queue_policy_config: + priority_queue_policy = {} + for priority, queue_policy_config in priority_queue_policy_config.items(): + queue_policy = QueuePolicy( + timeout_action=TimeoutAction(queue_policy_config.get("timeout_action", TimeoutAction.REJECT.value)), + default_timeout_microseconds=int(queue_policy_config.get("default_timeout_microseconds", 0)), + allow_timeout_override=bool(queue_policy_config.get("allow_timeout_override", False)), + max_queue_size=int(queue_policy_config.get("max_queue_size", 0)), + ) + priority_queue_policy[int(priority)] = queue_policy + + batcher = DynamicBatcher( + preferred_batch_size=dynamic_batching_config.get("preferred_batch_size"), + max_queue_delay_microseconds=int(dynamic_batching_config.get("max_queue_delay_microseconds", 0)), + preserve_ordering=bool(dynamic_batching_config.get("preserve_ordering", False)), + priority_levels=int(dynamic_batching_config.get("priority_levels", 0)), + default_priority_level=int(dynamic_batching_config.get("default_priority_level", 0)), + default_queue_policy=default_queue_policy, + priority_queue_policy=priority_queue_policy, + ) + return batcher + + @classmethod + def _parse_response_cache(cls, response_cache_config: Dict) -> ResponseCache: + """Parse config for response cache. + + Args: + response_cache_config: response cache configuration + + Returns: + ResponseCache object with configuration + """ + response_cache = ResponseCache( + enable=bool(response_cache_config["enable"]), + ) + return response_cache diff --git a/src/python/library/tritonclient/hl/model_config/tensor.py b/src/python/library/tritonclient/hl/model_config/tensor.py new file mode 100644 index 000000000..ded9050c6 --- /dev/null +++ b/src/python/library/tritonclient/hl/model_config/tensor.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tensor object definition. + +Describe the model input or output. 
+ + Examples of use: + + # Minimal constructors + tensor = Tensor(dtype=np.bytes_, shape=(-1,)) + tensor = Tensor(dtype=np.float32, shape=(-1,)) + + # Type definition from existing object + a = np.array([1, 2, 3, 4]) + tensor = Tensor(dtype=a.dtype, shape=(-1,)) + + # Custom name + tensor = Tensor(name="data", dtype=np.float32, shape=(16,)) +""" + +import dataclasses +from typing import Optional, Type, Union + +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class Tensor: + """Model input and output definition for Triton deployment. + + Args: + shape: Shape of the input/output tensor. + dtype: Data type of the input/output tensor. + name: Name of the input/output of model. + optional: Flag to mark if input is optional. + """ + + shape: tuple + dtype: Union[np.dtype, Type[np.dtype], Type[object]] + name: Optional[str] = None + optional: Optional[bool] = False + + def __post_init__(self): + """Override object values on post init or field override.""" + if isinstance(self.dtype, np.dtype): + object.__setattr__(self, "dtype", self.dtype.type) # pytype: disable=attribute-error diff --git a/src/python/library/tritonclient/hl/model_config/triton_model_config.py b/src/python/library/tritonclient/hl/model_config/triton_model_config.py new file mode 100644 index 000000000..87aa276c3 --- /dev/null +++ b/src/python/library/tritonclient/hl/model_config/triton_model_config.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ModelConfig related objects.""" + +import dataclasses +from typing import Dict, Optional, Sequence, Type, Union + +import numpy as np + +from .common import DeviceKind, DynamicBatcher + + +@dataclasses.dataclass +class ResponseCache: + """Model response cache configuration. + + More in Triton Inference Server [documentation] + [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1765 + """ + + enable: bool + + +@dataclasses.dataclass +class TensorSpec: + """Stores specification of single tensor. This includes name, shape and dtype.""" + + name: str + shape: tuple + dtype: Union[Type[np.dtype], Type[object]] + optional: Optional[bool] = False + + +@dataclasses.dataclass +class TritonModelConfig: + """Triton Model Config dataclass for simplification and specialization of protobuf config generation. 
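+
+    Example (illustrative; only `model_name` is required, the remaining fields have defaults):
+
+        config = TritonModelConfig(model_name="identity", max_batch_size=8)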
+ + More in Triton Inference Server [documentation] + [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto + """ + + model_name: str + model_version: int = 1 + max_batch_size: int = 4 + batching: bool = True + batcher: Optional[DynamicBatcher] = None + instance_group: Dict[DeviceKind, Optional[int]] = dataclasses.field(default_factory=lambda: {}) + decoupled: bool = False + backend_parameters: Dict[str, str] = dataclasses.field(default_factory=lambda: {}) + inputs: Optional[Sequence[TensorSpec]] = None + outputs: Optional[Sequence[TensorSpec]] = None + response_cache: Optional[ResponseCache] = None + + @property + def backend(self) -> str: + """Return backend parameter.""" + return "python" diff --git a/src/python/library/tritonclient/hl/triton_model_config.py b/src/python/library/tritonclient/hl/triton_model_config.py new file mode 100644 index 000000000..8949a409f --- /dev/null +++ b/src/python/library/tritonclient/hl/triton_model_config.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ModelConfig related objects.""" + +import dataclasses +from typing import Dict, Optional, Sequence, Type, Union + +import numpy as np + +from tritonclient.hl.common import DeviceKind, DynamicBatcher + + +@dataclasses.dataclass +class ResponseCache: + """Model response cache configuration. + + More in Triton Inference Server [documentation] + [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1765 + """ + + enable: bool + + +@dataclasses.dataclass +class TensorSpec: + """Stores specification of single tensor. This includes name, shape and dtype.""" + + name: str + shape: tuple + dtype: Union[Type[np.dtype], Type[object]] + optional: Optional[bool] = False + + +@dataclasses.dataclass +class TritonModelConfig: + """Triton Model Config dataclass for simplification and specialization of protobuf config generation. 
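+
+    Example (illustrative; only `model_name` is required, the remaining fields have defaults):
+
+        config = TritonModelConfig(model_name="identity", max_batch_size=8)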
+ + More in Triton Inference Server [documentation] + [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto + """ + + model_name: str + model_version: int = 1 + max_batch_size: int = 4 + batching: bool = True + batcher: Optional[DynamicBatcher] = None + instance_group: Dict[DeviceKind, Optional[int]] = dataclasses.field(default_factory=lambda: {}) + decoupled: bool = False + backend_parameters: Dict[str, str] = dataclasses.field(default_factory=lambda: {}) + inputs: Optional[Sequence[TensorSpec]] = None + outputs: Optional[Sequence[TensorSpec]] = None + response_cache: Optional[ResponseCache] = None + + @property + def backend(self) -> str: + """Return backend parameter.""" + return "python" diff --git a/src/python/library/tritonclient/hl/utils.py b/src/python/library/tritonclient/hl/utils.py new file mode 100644 index 000000000..7b4f98cb4 --- /dev/null +++ b/src/python/library/tritonclient/hl/utils.py @@ -0,0 +1,384 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility module supporting model clients.""" + +import dataclasses +import enum +import logging +import socket +import sys +import time +import urllib +import warnings +from typing import Optional, Any + +import tritonclient.hl.lw.grpc +import tritonclient.hl.lw.http +#import tritonclient.http.aio +from grpc import RpcError +from tritonclient.hl.lw.utils import InferenceServerException + +from tritonclient.hl.exceptions import PyTritonClientInvalidUrlError, PyTritonClientTimeoutError +from tritonclient.hl.warnings import NotSupportedTimeoutWarning +from tritonclient.hl.constants import DEFAULT_GRPC_PORT, DEFAULT_HTTP_PORT +from tritonclient.hl.model_config.parser import ModelConfigParser + +_LOGGER = logging.getLogger(__name__) + + +_DEFAULT_NETWORK_TIMEOUT_S = 60.0 # 1min +_DEFAULT_WAIT_FOR_SERVER_READY_TIMEOUT_S = 60.0 # 1min +_DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S = 300.0 # 5min + +LATEST_MODEL_VERSION = "" + + +# Special value for model_version argument. If model_version is None, the latest version of the model is returned. + + +class ModelState(enum.Enum): + """Describe model state in Triton. 
+
+    Attributes:
+        LOADING: Loading of model
+        UNLOADING: Unloading of model
+        UNAVAILABLE: Model is missing or could not be loaded
+        READY: Model is ready for inference
+    """
+
+    LOADING = "LOADING"
+    UNLOADING = "UNLOADING"
+    UNAVAILABLE = "UNAVAILABLE"
+    READY = "READY"
+
+
+def parse_http_response(models):
+    """Parse model repository index response from Triton Inference Server for HTTP."""
+    models_states = {}
+    _LOGGER.debug("Parsing model repository index entries:")
+    for model in models:
+        _LOGGER.debug(f"    name={model.get('name')} version={model.get('version')} state={model.get('state')}")
+        if not model.get("version"):
+            continue
+
+        model_state = ModelState(model["state"]) if model.get("state") else ModelState.LOADING
+        models_states[(model["name"], model["version"])] = model_state
+
+    return models_states
+
+
+def parse_grpc_response(models):
+    """Parse model repository index response from Triton Inference Server for GRPC."""
+    models_states = {}
+    _LOGGER.debug("Parsing model repository index entries:")
+    for model in models:
+        _LOGGER.debug(f"    name={model.name} version={model.version} state={model.state}")
+        if not model.version:
+            continue
+
+        model_state = ModelState(model.state) if model.state else ModelState.LOADING
+        models_states[(model.name, model.version)] = model_state
+
+    return models_states
+
+
+def get_model_state(
+    client: Any,
+    model_name: str,
+    model_version: Optional[str] = None,
+) -> ModelState:
+    """Obtain the state of the model deployed on the Triton Inference Server.
+
+    Args:
+        client: Triton Inference Server client to use for communication
+        model_name: name of the model whose state we're requesting.
+        model_version:
+            version of the model whose state we're requesting.
+            If model_version is None, the state of the latest model is returned.
+            The latest versions of the model are the numerically greatest version numbers.
+
+    Returns:
+        Model state. ModelState.UNAVAILABLE is returned if a model with the given name and version is not found.
+
+    """
+    repository_index = client.get_model_repository_index()
+    if isinstance(repository_index, list):
+        models_states = parse_http_response(models=repository_index)
+    else:
+        models_states = parse_grpc_response(models=repository_index.models)
+
+    if model_version is None:
+        requested_model_states = {
+            version: state for (name, version), state in models_states.items() if name == model_name
+        }
+        if not requested_model_states:
+            return ModelState.UNAVAILABLE
+        else:
+            requested_model_states = sorted(requested_model_states.items(), key=lambda item: int(item[0]))
+            _latest_version, latest_version_state = requested_model_states[-1]
+            return latest_version_state
+    else:
+        state = models_states.get((model_name, model_version), ModelState.UNAVAILABLE)
+        return state
+
+
+def get_model_config(
+    client: Any,
+    model_name: str,
+    model_version: Optional[str] = None,
+    timeout_s: Optional[float] = None,
+):
+    """Obtain the configuration of a model deployed on the Triton Inference Server.
+
+    Function waits for server readiness.
+
+    Typical use:
+
+        client = tritonclient.grpc.Client("localhost:8001")
+        model_config = get_model_config(client, "MyModel", "1", 60.0)
+        model_config = get_model_config(client, "MyModel")
+
+    Args:
+        client: Triton Inference Server client to use for communication
+        model_name: name of the model whose configuration we're requesting.
+        model_version:
+            version of the model whose configuration we're requesting.
+            If model_version is None, the configuration of the latest model is returned.
+            The latest versions of the model are the numerically greatest version numbers.
+        timeout_s: timeout for obtaining the model configuration. Default value is 300.0 s.
+
+    Returns:
+        Configuration of requested model.
+
+    Raises:
+        PyTritonClientTimeoutError: If obtaining the model configuration didn't finish before the given timeout.
+        PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
+    """
+    wait_for_model_ready(client, model_name=model_name, model_version=model_version, timeout_s=timeout_s)
+
+    model_version = model_version or ""
+
+    _LOGGER.debug(f"Obtaining model {model_name} config")
+    if isinstance(client, tritonclient.hl.lw.grpc.InferenceServerClient):
+        response = client.get_model_config(model_name, model_version, as_json=True)
+        model_config = response["config"]
+    else:
+        model_config = client.get_model_config(model_name, model_version)
+    model_config = ModelConfigParser.from_dict(model_config)
+    _LOGGER.debug(f"Model config: {model_config}")
+    return model_config
+
+
+def _warn_on_too_big_network_timeout(client: Any, timeout_s: float):
+    pass
+    # if isinstance(client, tritonclient.hl.lw.http.InferenceServerClient):
+    #     connection_pool = client._client_stub._connection_pool
+    #     network_reldiff_s = (connection_pool.network_timeout - timeout_s) / timeout_s
+    #     connection_reldiff_s = (connection_pool.connection_timeout - timeout_s) / timeout_s
+    #     rtol = 0.001
+    #     if network_reldiff_s > rtol or connection_reldiff_s > rtol:
+    #         warnings.warn(
+    #             "Client network and/or connection timeout is smaller than requested timeout_s. This may cause unexpected behavior. "
+    #             f"network_timeout={connection_pool.network_timeout} "
+    #             f"connection_timeout={connection_pool.connection_timeout} "
+    #             f"timeout_s={timeout_s}",
+    #             NotSupportedTimeoutWarning,
+    #             stacklevel=1,
+    #         )
+
+
+def wait_for_server_ready(
+    client: Any,
+    timeout_s: Optional[float] = None,
+):
+    """Waits for Triton Inference Server to be ready.
+
+    Typical use:
+
+        client = tritonclient.http.Client("localhost:8000")
+        wait_for_server_ready(client, timeout_s=600.0)
+
+    Args:
+        client: Triton Inference Server client to use for communication
+        timeout_s: timeout for the server to get into readiness state. Default value is 60.0 s.
+
+    Raises:
+        PyTritonClientTimeoutError: If waiting for server readiness didn't finish before the given timeout.
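+
+    Example (illustrative timeout handling):
+
+        try:
+            wait_for_server_ready(client, timeout_s=60.0)
+        except PyTritonClientTimeoutError:
+            ...  # the server did not become ready within 60 s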
+ """ + timeout_s = timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_SERVER_READY_TIMEOUT_S + should_finish_before_s = time.time() + timeout_s + _warn_on_too_big_network_timeout(client, timeout_s) + + def _is_server_ready(): + try: + return client.is_server_ready() and client.is_server_live() + except InferenceServerException: + return False + except (RpcError, ConnectionError, socket.gaierror): # GRPC and HTTP clients raises these errors + return False + except Exception as e: + _LOGGER.exception(f"Exception while checking server readiness: {e}") + raise e + + timeout_s = max(0.0, should_finish_before_s - time.time()) + _LOGGER.debug(f"Waiting for server to be ready (timeout={timeout_s})") + is_server_ready = _is_server_ready() + while not is_server_ready: + time.sleep(min(1.0, timeout_s)) + is_server_ready = _is_server_ready() + if not is_server_ready and time.time() >= should_finish_before_s: + raise PyTritonClientTimeoutError("Waiting for server to be ready timed out.") + + +def wait_for_model_ready( + client: Any, + model_name: str, + model_version: Optional[str] = None, + timeout_s: Optional[float] = None, +): + """Wait for Triton Inference Server to be ready. + + Args: + client: Triton Inference Server client to use for communication. + model_name: name of the model to wait for readiness. + model_version: + version of the model to wait for readiness. + If model_version is None waiting for latest version of the model. + The latest versions of the model are the numerically greatest version numbers. + timeout_s: timeout to server and model get into readiness state. Default value is 300.0 s. + + Raises: + PyTritonClientTimeoutError: If server readiness didn't finish before given timeout. + """ + model_version = model_version or "" + model_version_msg = model_version or LATEST_MODEL_VERSION + timeout_s = timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S + should_finish_before_s = time.time() + timeout_s + + wait_for_server_ready(client, timeout_s=timeout_s) + timeout_s = max(0.0, should_finish_before_s - time.time()) + _LOGGER.debug(f"Waiting for model {model_name}/{model_version_msg} to be ready (timeout={timeout_s})") + is_model_ready = client.is_model_ready(model_name, model_version) + while not is_model_ready: + time.sleep(min(1.0, timeout_s)) + is_model_ready = client.is_model_ready(model_name, model_version) + + if not is_model_ready and time.time() >= should_finish_before_s: + raise PyTritonClientTimeoutError( + f"Waiting for model {model_name}/{model_version_msg} to be ready timed out." + ) + + +def create_client_from_url(url: str, network_timeout_s: Optional[float] = None) -> Any: # type: ignore + """Create Triton Inference Server client. + + Args: + url: url of the server to connect to. + If url doesn't contain scheme (e.g. "localhost:8001") http scheme is added. + If url doesn't contain port (e.g. "localhost") default port for given scheme is added. + network_timeout_s: timeout for client commands. Default value is 60.0 s. + + Returns: + Triton Inference Server client. + + Raises: + PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid. 
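+
+    Example (illustrative, assuming a Triton Inference Server reachable on localhost):
+
+        client = create_client_from_url("grpc://localhost:8001", network_timeout_s=60.0)
+        client = create_client_from_url("localhost")  # http scheme and port 8000 are added by default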
+ """ + url = TritonUrl.from_url(url) + triton_client_lib = {"grpc": tritonclient.grpc, "http": tritonclient.http}[url.scheme] + + if url.scheme == "grpc": + # by default grpc client has very large number of timeout, thus we want to make it equal to http client timeout + network_timeout_s = _DEFAULT_NETWORK_TIMEOUT_S if network_timeout_s is None else network_timeout_s + warnings.warn( + f"tritonclient.grpc doesn't support timeout for other commands than infer. Ignoring network_timeout: {network_timeout_s}.", + NotSupportedTimeoutWarning, + stacklevel=1, + ) + + triton_client_init_kwargs = {} + if network_timeout_s is not None: + triton_client_init_kwargs.update( + **{ + "grpc": {}, + "http": {"connection_timeout": network_timeout_s, "network_timeout": network_timeout_s}, + }[url.scheme] + ) + + _LOGGER.debug(f"Creating InferenceServerClient for {url.with_scheme} with {triton_client_init_kwargs}") + return triton_client_lib.InferenceServerClient(url.without_scheme, **triton_client_init_kwargs) + + +@dataclasses.dataclass +class TritonUrl: + """TritonUrl class for parsing Triton Inference Server url. + + Attributes: + scheme: scheme of the url (http or grpc) + hostname: hostname of the url + port: port of the url + + Examples: + triton_url = TritonUrl.from_url("localhost:8000") + triton_url.with_scheme + >>> "http://localhost:8000" + triton_url.without_scheme + >>> "localhost:8000" + triton_url.scheme, triton_url.hostname, triton_url.port + >>> ("http", "localhost", 8000) + """ + + scheme: str + hostname: str + port: int + + @classmethod + def from_url(cls, url): + """Parse triton url and create TritonUrl instance. + + Returns: + TritonUrl object with scheme, hostname and port. + """ + if not isinstance(url, str): + raise PyTritonClientInvalidUrlError(f"Invalid url {url}. Url must be a string.") + try: + parsed_url = urllib.parse.urlparse(url) + # change in py3.9+ + # https://github.com/python/cpython/commit/5a88d50ff013a64fbdb25b877c87644a9034c969 + if sys.version_info < (3, 9) and not parsed_url.scheme and "://" in parsed_url.path: + raise ValueError(f"Invalid url {url}. Only grpc and http are supported.") + if (not parsed_url.scheme and "://" not in parsed_url.path) or ( + sys.version_info >= (3, 9) and parsed_url.scheme and not parsed_url.netloc + ): + _LOGGER.debug(f"Adding http scheme to {url}") + parsed_url = urllib.parse.urlparse(f"http://{url}") + + scheme = parsed_url.scheme.lower() + if scheme not in ["grpc", "http"]: + raise ValueError(f"Invalid scheme {scheme}. Only grpc and http are supported.") + + port = parsed_url.port or {"grpc": DEFAULT_GRPC_PORT, "http": DEFAULT_HTTP_PORT}[scheme] + except ValueError as e: + raise PyTritonClientInvalidUrlError(f"Invalid url {url}") from e + return cls(scheme, parsed_url.hostname, port) + + @property + def with_scheme(self): + """Get Triton Inference Server url with scheme.""" + return f"{self.scheme}://{self.hostname}:{self.port}" + + @property + def without_scheme(self): + """Get Triton Inference Server url without scheme.""" + return f"{self.hostname}:{self.port}" diff --git a/src/python/library/tritonclient/hl/warnings.py b/src/python/library/tritonclient/hl/warnings.py new file mode 100644 index 000000000..7e121689e --- /dev/null +++ b/src/python/library/tritonclient/hl/warnings.py @@ -0,0 +1,26 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Warnings for pytriton module.""" + + +class PyTritonWarning(UserWarning): + """Base warning for pytriton module.""" + + pass + + +class NotSupportedTimeoutWarning(PyTritonWarning): + """A warning for client, which doesn't support timeout configuration.""" + + pass From 414e877f230c0b3f1203f7e5708aea216ae19240 Mon Sep 17 00:00:00 2001 From: Piotr Marcinkiewicz Date: Tue, 11 Jun 2024 23:34:48 +0200 Subject: [PATCH 2/4] Decoupled model PoC with gemerate_stream endpoint --- src/python/library/tritonclient/_client.py | 17 +- src/python/library/tritonclient/hl/README.md | 117 ++++++++++- .../library/tritonclient/hl/__init__.py | 1 + src/python/library/tritonclient/hl/client.py | 198 ++++++++++++++++-- src/python/library/tritonclient/hl/lw/grpc.py | 8 +- src/python/library/tritonclient/hl/lw/http.py | 172 +++++++++++---- .../library/tritonclient/hl/lw/infer_input.py | 2 + .../tritonclient/hl/model_config/parser.py | 24 ++- 8 files changed, 470 insertions(+), 69 deletions(-) diff --git a/src/python/library/tritonclient/_client.py b/src/python/library/tritonclient/_client.py index 2b5ba612f..78a1ffff1 100755 --- a/src/python/library/tritonclient/_client.py +++ b/src/python/library/tritonclient/_client.py @@ -27,7 +27,8 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from tritonclient.utils import raise_error -from tritonclient.hl import ModelClient +from tritonclient.hl import ModelClient, DecoupledModelClient +from typing import Union class InferenceServerClientBase: def __init__(self): @@ -99,13 +100,21 @@ def unregister_plugin(self): # # client.close() - + class Client: def __init__(self, url: str) -> None: self._client_url = url super().__init__() - def model(self, name: str) -> ModelClient: - return ModelClient(url=self._client_url, model_name=name) + def model(self, name: str) -> Union[ModelClient, DecoupledModelClient]: + client = ModelClient(url=self._client_url, model_name=name) + if client.model_config.decoupled: + try: + decoupled_client = DecoupledModelClient.from_existing_client(client) + finally: + client.close() + return decoupled_client + else: + return client diff --git a/src/python/library/tritonclient/hl/README.md b/src/python/library/tritonclient/hl/README.md index a5664f5cb..74084678a 100644 --- a/src/python/library/tritonclient/hl/README.md +++ b/src/python/library/tritonclient/hl/README.md @@ -1,10 +1,14 @@ +## Dependencies + Just for test install PyTriton client: ```bash pip install nvidia-pytriton ``` +## Non-decoupled PyTriton client + It is possible to test new client using PyTriton server: ```python @@ -33,16 +37,127 @@ triton.run() You can test new client with simple request: + ```python import numpy as np from tritonclient._client import Client -Client("localhost:8000").model("identity").infer(inputs={"input": np.char.encode([["a"]], "utf-8")} +client = Client("localhost:8000").model("identity") + +result = client.infer(inputs={"input": np.char.encode([["a"]], "utf-8")}) +``` + + + + Expected output: + ```python {'output': array(['a'], dtype=' +```python +triton = Triton() +triton.bind( + model_name="decoupled_identity", + 
infer_func=_infer_fn, + inputs=[ + Tensor(name="input", dtype=np.int32, shape=(-1,)), + # Shape with a batch dimension (-1) to support variable-sized batches. + ], + outputs=[ + Tensor(name="output", dtype=np.int32, shape=(-1,)), + # Output shape with a batch dimension (-1). + ], + config=ModelConfig(decoupled=True), +) +``` + +Start Triton: + + +```python +triton.run() +``` + + + + +User client for itegration over decoupled results: + + +```python +import numpy as np +from tritonclient._client import Client + +client = Client("localhost:8000").model("decoupled_identity") + +results = [] + +# Test fails with 500 error + +for result in client.infer(inputs={"input": np.array([1], dtype=np.int32)}): + print(result) + results.append(result) +``` + + + + +Expected output: + + +```python +{'output': array([0.001])} +{'output': array([0.001])} +{'output': array([0.001])} +``` \ No newline at end of file diff --git a/src/python/library/tritonclient/hl/__init__.py b/src/python/library/tritonclient/hl/__init__.py index 4ccc42221..0068e76d8 100644 --- a/src/python/library/tritonclient/hl/__init__.py +++ b/src/python/library/tritonclient/hl/__init__.py @@ -15,4 +15,5 @@ from .client import ( ModelClient, # noqa: F401 + DecoupledModelClient, # noqa: F401 ) diff --git a/src/python/library/tritonclient/hl/client.py b/src/python/library/tritonclient/hl/client.py index 16dfea44b..7de41b0f3 100755 --- a/src/python/library/tritonclient/hl/client.py +++ b/src/python/library/tritonclient/hl/client.py @@ -656,8 +656,8 @@ def _create_request(self, inputs: _IOType): return inputs_wrapped, outputs_wrapped def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]: - import tritonclient.http - import tritonclient.utils + #import tritonclient.http + #import tritonclient.utils if self.model_config.decoupled: raise PyTritonClientInferenceServerError("Model config is decoupled. 
Use DecoupledModelClient instead.") @@ -694,21 +694,23 @@ def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]: _LOGGER.error(message) raise PyTritonClientTimeoutError(message) from e - if isinstance(response, tritonclient.http.InferResult): - outputs = { - output["name"]: response.as_numpy(output["name"]) for output in response.get_response()["outputs"] - } - else: - outputs = {output.name: response.as_numpy(output.name) for output in response.get_response().outputs} + ## FIXME: Why it is necessary + #if isinstance(response, tritonclient.http.InferResult): + # outputs = { + # output["name"]: response.as_numpy(output["name"]) for output in response.get_response()["outputs"] + # } + #else: + outputs = {output.name: response.as_numpy(output.name) for output in response.get_response().outputs} return outputs def _get_numpy_result(self, result): - import tritonclient.grpc - if isinstance(result, tritonclient.grpc.InferResult): - result = {output.name: result.as_numpy(output.name) for output in result.get_response().outputs} - else: - result = {output["name"]: result.as_numpy(output["name"]) for output in result.get_response()["outputs"]} + # FIXME: Investigate if still it works for coupled model + #import tritonclient.hl.lw.grpc + #if isinstance(result, tritonclient.hl.lw.grpc.InferResult): + result = {output.name: result.as_numpy(output.name) for output in result.get_response().outputs} + #else: + # result = {output["name"]: result.as_numpy(output["name"]) for output in result.get_response()["outputs"]} return result def _debatch_result(self, result): @@ -752,6 +754,176 @@ def _monkey_patched_del(self): self._triton_client_lib.InferenceServerClient.__del__ = _monkey_patched_del +class DecoupledModelClient(ModelClient): + """Synchronous client for decoupled model deployed on the Triton Inference Server.""" + + def __init__( + self, + url: str, + model_name: str, + model_version: Optional[str] = None, + *, + lazy_init: bool = True, + init_timeout_s: Optional[float] = None, + inference_timeout_s: Optional[float] = None, + model_config: Optional[TritonModelConfig] = None, + ensure_model_is_ready: bool = True, + ): + """Inits DecoupledModelClient for given decoupled model deployed on the Triton Inference Server. + + Common usage: + + ```python + client = DecoupledModelClient("localhost", "BERT") + for response in client.infer_sample(input1_sample, input2_sample): + print(response) + client.close() + ``` + + Args: + url: The Triton Inference Server url, e.g. `grpc://localhost:8001`. + In case no scheme is provided http scheme will be used as default. + In case no port is provided default port for given scheme will be used - + 8001 for grpc scheme, 8000 for http scheme. + model_name: name of the model to interact with. + model_version: version of the model to interact with. + If model_version is None inference on latest model will be performed. + The latest versions of the model are numerically the greatest version numbers. + lazy_init: if initialization should be performed just before sending first request to inference server. + init_timeout_s: timeout in seconds for the server and model to be ready. If not passed, the default timeout of 300 seconds will be used. + inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used. + model_config: model configuration. If not passed, it will be read from inference server during initialization. 
+ ensure_model_is_ready: if model should be checked if it is ready before first inference request. + + Raises: + PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. + PyTritonClientTimeoutError: + if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`. + PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid. + """ + super().__init__( + url, + model_name, + model_version, + lazy_init=lazy_init, + init_timeout_s=init_timeout_s, + inference_timeout_s=inference_timeout_s, + model_config=model_config, + ensure_model_is_ready=ensure_model_is_ready, + ) + # Let's use generate endpoints for PoC + #if self._triton_url.scheme == "http": + # raise PyTritonClientValueError("DecoupledModelClient is only supported for grpc protocol") + self._queue = Queue() + self._lock = Lock() + + def close(self): + """Close resources used by DecoupledModelClient.""" + _LOGGER.debug("Closing DecoupledModelClient") + if self._lock.acquire(blocking=False): + try: + super().close() + finally: + self._lock.release() + else: + _LOGGER.warning("DecoupledModelClient is stil streaming answers") + self._infer_client.stop_stream(False) + super().close() + + def _infer(self, inputs: _IOType, parameters, headers): + if not self._lock.acquire(blocking=False): + raise PyTritonClientInferenceServerError("Inference is already in progress") + if not self.model_config.decoupled: + raise PyTritonClientInferenceServerError("Model config is coupled. Use ModelClient instead.") + + inputs_wrapped, outputs_wrapped = self._create_request(inputs) + if parameters is not None: + raise PyTritonClientValueError("DecoupledModelClient does not support parameters") + if headers is not None: + raise PyTritonClientValueError("DecoupledModelClient does not support headers") + try: + _LOGGER.debug("Sending inference request to Triton Inference Server") + if self._infer_client._stream is None: + self._infer_client.start_stream(callback=lambda result, error: self._response_callback(result, error)) + + self._infer_client.async_stream_infer( + model_name=self._model_name, + model_version=self._model_version or "", + inputs=inputs_wrapped, + outputs=outputs_wrapped, + request_id=self._next_request_id, + enable_empty_final_response=True, + **self._get_infer_extra_args(), + ) + except tritonclient.utils.InferenceServerException as e: + # tritonclient.grpc raises execption with message containing "Deadline Exceeded" for timeout + if "Deadline Exceeded" in e.message(): + raise PyTritonClientTimeoutError( + f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}" + ) from e + + raise PyTritonClientInferenceServerError( + f"Error occurred during inference request. Message: {e.message()}" + ) from e + except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout + message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" + _LOGGER.error(message) + raise PyTritonClientTimeoutError(message) from e + except OSError as e: # tritonclient.http raises socket.error for connection error + message = f"Timeout occurred during inference request. 
Timeout: {self._inference_timeout_s} s Message: {e}" + _LOGGER.error(message) + raise PyTritonClientTimeoutError(message) from e + _LOGGER.debug("Returning response iterator") + return self._create_response_iterator() + + def _response_callback(self, response, error): + _LOGGER.debug(f"Received response from Triton Inference Server: {response}") + if error: + _LOGGER.error(f"Error occurred during inference request. Message: {error}") + self._queue.put(error) + else: + actual_response = response.get_response() + # Check if the object is not None + triton_final_response = actual_response.parameters.get("triton_final_response") + if triton_final_response and triton_final_response.bool_param: + self._queue.put(None) + else: + result = self._get_numpy_result(response) + self._queue.put(result) + + def _create_response_iterator(self): + try: + while True: + try: + item = self._queue.get(self._inference_timeout_s) + except Empty as e: + message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s" + _LOGGER.error(message) + raise PyTritonClientTimeoutError(message) from e + if isinstance(item, Exception): + if hasattr(item, "message"): + message = f"Error occurred during inference request. Message: {item.message()}" + else: + message = f"Error occurred during inference request. Message: {item}" + _LOGGER.error(message) + raise PyTritonClientInferenceServerError(message) from item + + if item is None: + break + yield item + finally: + self._lock.release() + + def _debatch_result(self, result): + if self.is_batching_supported: + result = ({name: data[0] for name, data in result_.items()} for result_ in result) + return result + + def _get_infer_extra_args(self): + # kwargs = super()._get_infer_extra_args() + kwargs = {} + # kwargs["enable_empty_final_response"] = True + return kwargs class InferenceServerClientBase: def __init__(self): diff --git a/src/python/library/tritonclient/hl/lw/grpc.py b/src/python/library/tritonclient/hl/lw/grpc.py index 559d58f14..10448326e 100755 --- a/src/python/library/tritonclient/hl/lw/grpc.py +++ b/src/python/library/tritonclient/hl/lw/grpc.py @@ -17,4 +17,10 @@ class InferenceServerClient: """Client to perform grpc communication with the Triton. 
""" - pass \ No newline at end of file + pass + +class InferResult: + """Result from the inference server.""" + pass + +from tritonclient.hl.lw.http import InferResponse diff --git a/src/python/library/tritonclient/hl/lw/http.py b/src/python/library/tritonclient/hl/lw/http.py index 92905d6dd..542cee74b 100755 --- a/src/python/library/tritonclient/hl/lw/http.py +++ b/src/python/library/tritonclient/hl/lw/http.py @@ -14,9 +14,14 @@ import requests import numpy as np +import threading +import json +import logging from tritonclient.hl.lw.infer_input import InferInput +_LOGGER = logging.getLogger(__name__) + class InferOutput(): """Output from the inference server.""" def __init__(self, name): @@ -27,10 +32,29 @@ def __init__(self, name): class InferResponse(): """Response from the inference server.""" - def __init__(self, outputs): + class Parameter: + """Parameter for the response.""" + def __init__(self, parameter): + """Initialize the parameter.""" + self.bool_param = parameter + + class Parameters: + """Parameters for the response.""" + def __init__(self, parameters): + """Initialize the parameters.""" + self.parameters = parameters + + def get(self, key): + """Get the key.""" + return InferResponse.Parameter(self.parameters.get(key)) + + def __init__(self, outputs=None, parameters=None): """Initialize the response.""" + if outputs is None: + outputs = {"outputs": []} self._rest_outputs = outputs self.outputs = [InferOutput(response['name']) for response in outputs['outputs']] + self.parameters = InferResponse.Parameters(parameters) def get_response(self): """Get the response.""" @@ -43,12 +67,50 @@ def as_numpy(self, name): return np.array(response['data']) return None +def _array_to_first_element(arr): + """ + Convert a NumPy array to its first element if it contains only one element. + Raise a ValueError if the array contains more than one element. + + Parameters: + arr (np.ndarray): The input NumPy array. + + Returns: + The first element of the array if it contains only one element. + + Raises: + ValueError: If the array contains more than one element. + """ + if arr.size == 1: + return arr.item() + else: + raise ValueError("Array contains more than one element") + +def _process_chunk(chunk): + """Process the chunk of data received.""" + # Decode the byte string to a regular string + chunk_str = chunk.decode('utf-8') + + # Strip the "data: " prefix + if chunk_str.startswith('data: '): + chunk_str = chunk_str[len('data: '):] + + # Load the JSON string into a Python dictionary + chunk_json = json.loads(chunk_str) + return chunk_json + + class InferenceServerClient(): """Client to perform http communication with the Triton. """ def __init__(self, url, **kwargs): self.url = url + self._stream = None + self._callback = None + self._event = threading.Event() + self._exception = None + def is_server_ready(self): """Check if the server is ready. @@ -80,59 +142,91 @@ def get_model_config(self, model_name, model_version): return response.json() - # In [13]: import requests - # ...: import json - # ...: - # ...: # Define the server URL - # ...: server_url = "http://localhost:8000/v2/models/identity/versions/1/infer" - # ...: - # ...: # Prepare the input data - # ...: input_string = "Hello Triton Inference Server!" 
- # ...: - # ...: # Triton requires the data to be in a specific format - # ...: inputs = [ - # ...: { - # ...: "name": "input", - # ...: "shape": [1, 1], # Adjust the shape to include the batch dimension - # ...: "datatype": "BYTES", - # ...: "data": [input_string] - # ...: } - # ...: ] - # ...: - # ...: # Prepare the request payload - # ...: payload = { - # ...: "inputs": inputs - # ...: } - # ...: - # ...: # Send the request - # ...: response = requests.post(server_url, json=payload) - # ...: - # ...: # Check the response status - # ...: if response.status_code == 200: - # ...: result = response.json() - # ...: print("Inference result:", result) - # ...: else: - # ...: print("Failed to get inference result:", response.status_code, response.text) - # ...: - # Inference result: {'model_name': 'identity', 'model_version': '1', 'outputs': [{'name': 'output', 'datatype': 'BYTES', 'shape': [1, 1], 'data': ['Hello Triton Inference Server!']}]} - # In [14]: def infer(self, model_name, model_version, inputs, headers, outputs, request_id, parameters): """Perform inference. """ model_version = model_version if model_version else "1" request = "http://" + self.url + "/v2/models/{}/versions/{}/infer".format(model_name, model_version) - print(request) inputs_for_json = [input_value.to_dict() for input_value in inputs] - print(inputs_for_json) ## TODO: Support setting outputs, request_id, parameters and headers response = requests.post(request, json={"inputs": inputs_for_json}) return InferResponse(response.json()) + def start_stream(self, callback): + """Start the stream. + """ + self._callback = callback + + + def process_chunk(self, chunk, error): + """Process the chunk of data received.""" + if self._callback: + self._callback(chunk, error) + + def stream_request(self, model_name, model_version, inputs_for_stream): + request_url = f"http://{self.url}/v2/models/{model_name}/generate_stream" + headers = {"Content-Type": "application/json"} + + try: + with requests.post(request_url, json=inputs_for_stream, headers=headers, stream=True) as response: + if response.status_code != 200: + _LOGGER.debug(f"Request failed with status code: {response}") + self._exception = Exception(f"Request failed with status code: {response.status_code}") + self._event.set() + return + + self._event.set() # Signal that the first response was received successfully + + try: + for line in response.iter_lines(): + if line: # Filter out keep-alive new lines + response_json = _process_chunk(line) + outputs = [] + for key, value in response_json.items(): + if key not in ["model_name", "model_version"]: + outputs.append({"name": key, "data": [value]}) + outputs_struct = {"outputs": outputs} + response = InferResponse(outputs_struct, parameters={"triton_final_response": False}) + self.process_chunk(response, None) + except Exception as e: + _LOGGER.debug(f"Some error occurred while processing the response: {e}") + self.process_chunk(None, e) + response = InferResponse(outputs=None, parameters={"triton_final_response": True}) + self.process_chunk(response, None) + except Exception as e: + _LOGGER.debug(f"Some error occurred while processing the response: {e}") + self._exception = e + self._event.set() + + def async_stream_infer(self, model_name, model_version, inputs, outputs, request_id, enable_empty_final_response, **kwargs): + """Perform inference.""" + model_version = model_version if model_version else "1" + inputs_for_stream = {input_value.name(): _array_to_first_element(input_value._np_data) for input_value in inputs} + + 
self._event.clear() + self._exception = None + self._stream_thread = threading.Thread(target=self.stream_request, args=(model_name, model_version, inputs_for_stream)) + self._stream_thread.start() + self._event.wait() # Block until the first 200 response or error is returned + + if self._exception: + raise self._exception + + def close(self): + """Close the stream and join the thread.""" + if self._stream is not None: + self._stream.join() + + def close(self): + """Close the client. + """ + pass + class InferRequestedOutput(): def __init__(self, name): diff --git a/src/python/library/tritonclient/hl/lw/infer_input.py b/src/python/library/tritonclient/hl/lw/infer_input.py index 076458bb3..2c04d4d37 100755 --- a/src/python/library/tritonclient/hl/lw/infer_input.py +++ b/src/python/library/tritonclient/hl/lw/infer_input.py @@ -222,6 +222,7 @@ def set_data_from_numpy(self, input_tensor): InferenceServerException If failed to set data for the tensor. """ + self._np_data = input_tensor binary_data=False @@ -349,6 +350,7 @@ def to_dict(self): "datatype": self.datatype(), "data": self._get_tensor()["data"], } + class InferRequestedOutput(): def __init__(self, name): diff --git a/src/python/library/tritonclient/hl/model_config/parser.py b/src/python/library/tritonclient/hl/model_config/parser.py index 83d74cb9f..d5e669734 100644 --- a/src/python/library/tritonclient/hl/model_config/parser.py +++ b/src/python/library/tritonclient/hl/model_config/parser.py @@ -34,21 +34,23 @@ class ModelConfig. import numpy as np from google.protobuf import json_format, text_format # pytype: disable=pyi-error +from tritonclient.hl.lw.infer_input import triton_to_np_dtype + #from pytriton.exceptions import PyTritonModelConfigError from .common import QueuePolicy, TimeoutAction from .triton_model_config import DeviceKind, DynamicBatcher, ResponseCache, TensorSpec, TritonModelConfig -try: - import tritonclient.grpc as grpc_client - from tritonclient import utils as client_utils # noqa: F401 -except ImportError: - try: - import tritonclientutils as client_utils # noqa: F401 - import tritongrpcclient as grpc_client - except ImportError: - client_utils = None - grpc_client = None +# try: +# import tritonclient.hl.lw.grpc as grpc_client +# from tritonclient import utils as client_utils # noqa: F401 +# except ImportError: +# try: +# import tritonclientutils as client_utils # noqa: F401 +# import tritongrpcclient as grpc_client +# except ImportError: +# client_utils = None +# grpc_client = None LOGGER = logging.getLogger(__name__) @@ -181,7 +183,7 @@ def rewrite_io_spec(cls, item: Dict, io_type: str, idx: int) -> TensorSpec: if data_type == "STRING": dtype = np.bytes_ else: - dtype = client_utils.triton_to_np_dtype(data_type) + dtype = triton_to_np_dtype(data_type) if dtype is None: raise PyTritonModelConfigError(f"Unsupported data type `{data_type}` for {io_type} with name `{name}`") From 991d7309b2225027df7c62168073886946d1d481 Mon Sep 17 00:00:00 2001 From: Piotr Marcinkiewicz Date: Mon, 17 Jun 2024 18:03:38 +0200 Subject: [PATCH 3/4] Fix pre-commit issues and attributes --- src/python/library/tritonclient/_client.py | 19 +- src/python/library/tritonclient/hl/README.md | 2 +- .../library/tritonclient/hl/__init__.py | 6 +- src/python/library/tritonclient/hl/client.py | 211 ++++++++++++------ .../library/tritonclient/hl/constants.py | 8 +- .../library/tritonclient/hl/lw/__init__.py | 0 src/python/library/tritonclient/hl/lw/grpc.py | 8 +- src/python/library/tritonclient/hl/lw/http.py | 157 ++++++++----- 
.../library/tritonclient/hl/lw/infer_input.py | 15 +- .../library/tritonclient/hl/lw/utils.py | 3 +- .../hl/model_config/model_config.py | 4 +- .../tritonclient/hl/model_config/parser.py | 124 +++++++--- .../tritonclient/hl/model_config/tensor.py | 4 +- .../hl/model_config/triton_model_config.py | 4 +- .../tritonclient/hl/triton_model_config.py | 5 +- src/python/library/tritonclient/hl/utils.py | 111 ++++++--- 16 files changed, 466 insertions(+), 215 deletions(-) mode change 100755 => 100644 src/python/library/tritonclient/hl/client.py mode change 100755 => 100644 src/python/library/tritonclient/hl/lw/__init__.py mode change 100755 => 100644 src/python/library/tritonclient/hl/lw/grpc.py mode change 100755 => 100644 src/python/library/tritonclient/hl/lw/http.py mode change 100755 => 100644 src/python/library/tritonclient/hl/lw/infer_input.py mode change 100755 => 100644 src/python/library/tritonclient/hl/lw/utils.py diff --git a/src/python/library/tritonclient/_client.py b/src/python/library/tritonclient/_client.py index 78a1ffff1..d6d189095 100755 --- a/src/python/library/tritonclient/_client.py +++ b/src/python/library/tritonclient/_client.py @@ -25,10 +25,11 @@ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from typing import Union + +from tritonclient.hl import DecoupledModelClient, ModelClient from tritonclient.utils import raise_error -from tritonclient.hl import ModelClient, DecoupledModelClient -from typing import Union class InferenceServerClientBase: def __init__(self): @@ -89,24 +90,23 @@ def unregister_plugin(self): # # Change url to 'http://localhost:8000' for utilizing HTTP client # client = Client(url='grpc://loacalhost:8001') -# +# # input_tensor_as_numpy = np.array(...) -# -# # Infer should be async similar to the exising Python APIs +# +# # Infer should be async similar to the existing Python APIs # responses = client.model('simple').infer(inputs={'input': input_tensor_as_numpy}) -# +# # for response in responses: # numpy_array = np.asarray(response.outputs['output']) -# +# # client.close() - class Client: def __init__(self, url: str) -> None: self._client_url = url super().__init__() - + def model(self, name: str) -> Union[ModelClient, DecoupledModelClient]: client = ModelClient(url=self._client_url, model_name=name) if client.model_config.decoupled: @@ -117,4 +117,3 @@ def model(self, name: str) -> Union[ModelClient, DecoupledModelClient]: return decoupled_client else: return client - diff --git a/src/python/library/tritonclient/hl/README.md b/src/python/library/tritonclient/hl/README.md index 74084678a..6297eb460 100644 --- a/src/python/library/tritonclient/hl/README.md +++ b/src/python/library/tritonclient/hl/README.md @@ -160,4 +160,4 @@ Expected output: {'output': array([0.001])} {'output': array([0.001])} {'output': array([0.001])} -``` \ No newline at end of file +``` \ No newline at end of file diff --git a/src/python/library/tritonclient/hl/__init__.py b/src/python/library/tritonclient/hl/__init__.py index 0068e76d8..e78e2d1c4 100644 --- a/src/python/library/tritonclient/hl/__init__.py +++ b/src/python/library/tritonclient/hl/__init__.py @@ -13,7 +13,5 @@ # limitations under the License. 
# noqa: D104 -from .client import ( - ModelClient, # noqa: F401 - DecoupledModelClient, # noqa: F401 -) +from .client import DecoupledModelClient # noqa: F401 +from .client import ModelClient # noqa: F401 diff --git a/src/python/library/tritonclient/hl/client.py b/src/python/library/tritonclient/hl/client.py old mode 100755 new mode 100644 index 7de41b0f3..4e8112b8c --- a/src/python/library/tritonclient/hl/client.py +++ b/src/python/library/tritonclient/hl/client.py @@ -44,20 +44,11 @@ from threading import Lock, Thread from typing import Any, Dict, Optional, Tuple, Union -#import gevent +# import gevent import numpy as np - -# Old client tritonclient imports -# import tritonclient.grpc -# import tritonclient.grpc.aio -# import tritonclient.http -# import tritonclient.http.aio -# import tritonclient.utils - import tritonclient.hl.lw.grpc import tritonclient.hl.lw.http import tritonclient.hl.lw.utils - from tritonclient.hl.exceptions import ( PyTritonClientClosedError, PyTritonClientInferenceServerError, @@ -66,6 +57,7 @@ PyTritonClientTimeoutError, PyTritonClientValueError, ) +from tritonclient.hl.triton_model_config import TritonModelConfig from tritonclient.hl.utils import ( _DEFAULT_NETWORK_TIMEOUT_S, _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S, @@ -75,7 +67,14 @@ wait_for_server_ready, ) from tritonclient.hl.warnings import NotSupportedTimeoutWarning -from tritonclient.hl.triton_model_config import TritonModelConfig + +# Old client tritonclient imports +# import tritonclient.grpc +# import tritonclient.grpc.aio +# import tritonclient.http +# import tritonclient.http.aio +# import tritonclient.utils + _LOGGER = logging.getLogger(__name__) @@ -91,10 +90,14 @@ def _verify_inputs_args(inputs, named_inputs): if not inputs and not named_inputs: raise PyTritonClientValueError("Provide input data") if not bool(inputs) ^ bool(named_inputs): - raise PyTritonClientValueError("Use either positional either keyword method arguments convention") + raise PyTritonClientValueError( + "Use either positional either keyword method arguments convention" + ) -def _verify_parameters(parameters_or_headers: Optional[Dict[str, Union[str, int, bool]]] = None): +def _verify_parameters( + parameters_or_headers: Optional[Dict[str, Union[str, int, bool]]] = None +): if parameters_or_headers is None: return if not isinstance(parameters_or_headers, dict): @@ -103,7 +106,9 @@ def _verify_parameters(parameters_or_headers: Optional[Dict[str, Union[str, int, if not isinstance(key, str): raise PyTritonClientValueError("Parameter/header key must be a string") if not isinstance(value, (str, int, bool)): - raise PyTritonClientValueError("Parameter/header value must be a string, integer or boolean") + raise PyTritonClientValueError( + "Parameter/header value must be a string, integer or boolean" + ) class BaseModelClient: @@ -152,19 +157,28 @@ def __init__( if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`. PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid. 
""" - self._init_timeout_s = _DEFAULT_SYNC_INIT_TIMEOUT_S if init_timeout_s is None else init_timeout_s - self._inference_timeout_s = DEFAULT_INFERENCE_TIMEOUT_S if inference_timeout_s is None else inference_timeout_s + self._init_timeout_s = ( + _DEFAULT_SYNC_INIT_TIMEOUT_S if init_timeout_s is None else init_timeout_s + ) + self._inference_timeout_s = ( + DEFAULT_INFERENCE_TIMEOUT_S + if inference_timeout_s is None + else inference_timeout_s + ) self._network_timeout_s = min(_DEFAULT_NETWORK_TIMEOUT_S, self._init_timeout_s) - self._general_client = self.create_client_from_url(url, network_timeout_s=self._network_timeout_s) - self._infer_client = self.create_client_from_url(url, network_timeout_s=self._inference_timeout_s) + self._general_client = self.create_client_from_url( + url, network_timeout_s=self._network_timeout_s + ) + self._infer_client = self.create_client_from_url( + url, network_timeout_s=self._inference_timeout_s + ) self._model_name = model_name self._model_version = model_version self._request_id_generator = itertools.count(0) - if model_config is not None: self._model_config = model_config self._model_ready = None if ensure_model_is_ready else True @@ -208,7 +222,9 @@ def from_existing_client(cls, existing_client: "BaseModelClient"): return new_client - def create_client_from_url(self, url: str, network_timeout_s: Optional[float] = None): + def create_client_from_url( + self, url: str, network_timeout_s: Optional[float] = None + ): """Create Triton Inference Server client. Args: @@ -229,7 +245,11 @@ def create_client_from_url(self, url: str, network_timeout_s: Optional[float] = if self._triton_url.scheme == "grpc": # by default grpc client has very large number of timeout, thus we want to make it equal to http client timeout - network_timeout_s = _DEFAULT_NETWORK_TIMEOUT_S if network_timeout_s is None else network_timeout_s + network_timeout_s = ( + _DEFAULT_NETWORK_TIMEOUT_S + if network_timeout_s is None + else network_timeout_s + ) warnings.warn( f"tritonclient.grpc doesn't support timeout for other commands than infer. 
Ignoring network_timeout: {network_timeout_s}.", NotSupportedTimeoutWarning, @@ -241,7 +261,9 @@ def create_client_from_url(self, url: str, network_timeout_s: Optional[float] = _LOGGER.debug( f"Creating InferenceServerClient for {self._triton_url.with_scheme} with {triton_client_init_kwargs}" ) - return self._triton_client_lib.InferenceServerClient(self._url, **triton_client_init_kwargs) + return self._triton_client_lib.InferenceServerClient( + self._url, **triton_client_init_kwargs + ) def get_lib(self): """Returns tritonclient library for given scheme.""" @@ -381,7 +403,9 @@ def __init__( def get_lib(self): """Returns tritonclient library for given scheme.""" - return {"grpc": tritonclient.hl.lw.grpc, "http": tritonclient.hl.lw.http}[self._triton_url.scheme.lower()] + return {"grpc": tritonclient.hl.lw.grpc, "http": tritonclient.hl.lw.http}[ + self._triton_url.scheme.lower() + ] def __enter__(self): """Create context for using ModelClient as a context manager.""" @@ -443,7 +467,12 @@ def wait_for_model(self, timeout_s: float): """ if self._general_client is None: raise PyTritonClientClosedError("ModelClient is closed") - wait_for_model_ready(self._general_client, self._model_name, self._model_version, timeout_s=timeout_s) + wait_for_model_ready( + self._general_client, + self._model_name, + self._model_version, + timeout_s=timeout_s, + ) @property def is_batching_supported(self): @@ -485,7 +514,10 @@ def model_config(self) -> TritonModelConfig: raise PyTritonClientClosedError("ModelClient is closed") self._model_config = get_model_config( - self._general_client, self._model_name, self._model_version, timeout_s=self._init_timeout_s + self._general_client, + self._model_name, + self._model_version, + timeout_s=self._init_timeout_s, ) return self._model_config @@ -537,7 +569,9 @@ def infer_sample( if inputs: inputs = tuple(data[np.newaxis, ...] for data in inputs) elif named_inputs: - named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()} + named_inputs = { + name: data[np.newaxis, ...] 
for name, data in named_inputs.items() + } result = self._infer(inputs or named_inputs, parameters, headers) @@ -606,7 +640,6 @@ def infer(self, inputs): """Run synchronous batch inference using a single dictionary of inputs.""" return self.infer_batch(**inputs) - def _wait_and_init_model_config(self, init_timeout_s: float): if self._general_client is None: raise PyTritonClientClosedError("ModelClient is closed") @@ -616,7 +649,10 @@ def _wait_and_init_model_config(self, init_timeout_s: float): self._model_ready = True timeout_s = max(0.0, should_finish_before_s - time.time()) self._model_config = get_model_config( - self._general_client, self._model_name, self._model_version, timeout_s=timeout_s + self._general_client, + self._model_name, + self._model_version, + timeout_s=timeout_s, ) def _create_request(self, inputs: _IOType): @@ -627,7 +663,10 @@ def _create_request(self, inputs: _IOType): self._wait_and_init_model_config(self._init_timeout_s) if isinstance(inputs, Tuple): - inputs = {input_spec.name: input_data for input_spec, input_data in zip(self.model_config.inputs, inputs)} + inputs = { + input_spec.name: input_data + for input_spec, input_data in zip(self.model_config.inputs, inputs) + } inputs_wrapped = [] @@ -635,7 +674,9 @@ def _create_request(self, inputs: _IOType): inputs: Dict[str, np.ndarray] for input_name, input_data in inputs.items(): - if input_data.dtype == object and not isinstance(input_data.reshape(-1)[0], bytes): + if input_data.dtype == object and not isinstance( + input_data.reshape(-1)[0], bytes + ): raise RuntimeError( f"Numpy array for {input_name!r} input with dtype=object should contain encoded strings \ \\(e.g. into utf-8\\). Element type: {type(input_data.reshape(-1)[0])}" @@ -646,20 +687,25 @@ def _create_request(self, inputs: _IOType): f"Encode numpy array for {input_name!r} input (ex. with np.char.encode(array, 'utf-8'))." ) triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype) - infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype) + infer_input = self._triton_client_lib.InferInput( + input_name, input_data.shape, triton_dtype + ) infer_input.set_data_from_numpy(input_data) inputs_wrapped.append(infer_input) outputs_wrapped = [ - self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in self.model_config.outputs + self._triton_client_lib.InferRequestedOutput(output_spec.name) + for output_spec in self.model_config.outputs ] return inputs_wrapped, outputs_wrapped def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]: - #import tritonclient.http - #import tritonclient.utils + # import tritonclient.http + # import tritonclient.utils if self.model_config.decoupled: - raise PyTritonClientInferenceServerError("Model config is decoupled. Use DecoupledModelClient instead.") + raise PyTritonClientInferenceServerError( + "Model config is decoupled. Use DecoupledModelClient instead." + ) inputs_wrapped, outputs_wrapped = self._create_request(inputs) @@ -676,7 +722,7 @@ def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]: **self._get_infer_extra_args(), ) except tritonclient.utils.InferenceServerException as e: - # tritonclient.grpc raises execption with message containing "Deadline Exceeded" for timeout + # tritonclient.grpc raises exception with message containing "Deadline Exceeded" for timeout if "Deadline Exceeded" in e.message(): raise PyTritonClientTimeoutError( f"Timeout occurred during inference request. 
Timeout: {self._inference_timeout_s} s. Message: {e.message()}" @@ -685,31 +731,41 @@ def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]: raise PyTritonClientInferenceServerError( f"Error occurred during inference request. Message: {e.message()}" ) from e - except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout + except ( + socket.timeout + ) as e: # tritonclient.http raises socket.timeout for timeout message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" _LOGGER.error(message) raise PyTritonClientTimeoutError(message) from e - except OSError as e: # tritonclient.http raises socket.error for connection error + except ( + OSError + ) as e: # tritonclient.http raises socket.error for connection error message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" _LOGGER.error(message) raise PyTritonClientTimeoutError(message) from e ## FIXME: Why it is necessary - #if isinstance(response, tritonclient.http.InferResult): + # if isinstance(response, tritonclient.http.InferResult): # outputs = { # output["name"]: response.as_numpy(output["name"]) for output in response.get_response()["outputs"] # } - #else: - outputs = {output.name: response.as_numpy(output.name) for output in response.get_response().outputs} + # else: + outputs = { + output.name: response.as_numpy(output.name) + for output in response.get_response().outputs + } return outputs def _get_numpy_result(self, result): # FIXME: Investigate if still it works for coupled model - #import tritonclient.hl.lw.grpc - #if isinstance(result, tritonclient.hl.lw.grpc.InferResult): - result = {output.name: result.as_numpy(output.name) for output in result.get_response().outputs} - #else: + # import tritonclient.hl.lw.grpc + # if isinstance(result, tritonclient.hl.lw.grpc.InferResult): + result = { + output.name: result.as_numpy(output.name) + for output in result.get_response().outputs + } + # else: # result = {output["name"]: result.as_numpy(output["name"]) for output in result.get_response()["outputs"]} return result @@ -748,12 +804,15 @@ def _monkey_patched_del(self): try: old_del(self) except gevent.exceptions.InvalidThreadUseError: - _LOGGER.info("gevent.exceptions.InvalidThreadUseError in __del__ of InferenceServerClient") + _LOGGER.info( + "gevent.exceptions.InvalidThreadUseError in __del__ of InferenceServerClient" + ) except Exception as e: _LOGGER.error("Exception in __del__ of InferenceServerClient: %s", e) self._triton_client_lib.InferenceServerClient.__del__ = _monkey_patched_del + class DecoupledModelClient(ModelClient): """Synchronous client for decoupled model deployed on the Triton Inference Server.""" @@ -812,7 +871,7 @@ def __init__( ensure_model_is_ready=ensure_model_is_ready, ) # Let's use generate endpoints for PoC - #if self._triton_url.scheme == "http": + # if self._triton_url.scheme == "http": # raise PyTritonClientValueError("DecoupledModelClient is only supported for grpc protocol") self._queue = Queue() self._lock = Lock() @@ -834,17 +893,27 @@ def _infer(self, inputs: _IOType, parameters, headers): if not self._lock.acquire(blocking=False): raise PyTritonClientInferenceServerError("Inference is already in progress") if not self.model_config.decoupled: - raise PyTritonClientInferenceServerError("Model config is coupled. Use ModelClient instead.") + raise PyTritonClientInferenceServerError( + "Model config is coupled. Use ModelClient instead." 
+ ) inputs_wrapped, outputs_wrapped = self._create_request(inputs) if parameters is not None: - raise PyTritonClientValueError("DecoupledModelClient does not support parameters") + raise PyTritonClientValueError( + "DecoupledModelClient does not support parameters" + ) if headers is not None: - raise PyTritonClientValueError("DecoupledModelClient does not support headers") + raise PyTritonClientValueError( + "DecoupledModelClient does not support headers" + ) try: _LOGGER.debug("Sending inference request to Triton Inference Server") if self._infer_client._stream is None: - self._infer_client.start_stream(callback=lambda result, error: self._response_callback(result, error)) + self._infer_client.start_stream( + callback=lambda result, error: self._response_callback( + result, error + ) + ) self._infer_client.async_stream_infer( model_name=self._model_name, @@ -856,7 +925,7 @@ def _infer(self, inputs: _IOType, parameters, headers): **self._get_infer_extra_args(), ) except tritonclient.utils.InferenceServerException as e: - # tritonclient.grpc raises execption with message containing "Deadline Exceeded" for timeout + # tritonclient.grpc raises exception with message containing "Deadline Exceeded" for timeout if "Deadline Exceeded" in e.message(): raise PyTritonClientTimeoutError( f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}" @@ -865,11 +934,15 @@ def _infer(self, inputs: _IOType, parameters, headers): raise PyTritonClientInferenceServerError( f"Error occurred during inference request. Message: {e.message()}" ) from e - except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout + except ( + socket.timeout + ) as e: # tritonclient.http raises socket.timeout for timeout message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" _LOGGER.error(message) raise PyTritonClientTimeoutError(message) from e - except OSError as e: # tritonclient.http raises socket.error for connection error + except ( + OSError + ) as e: # tritonclient.http raises socket.error for connection error message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" _LOGGER.error(message) raise PyTritonClientTimeoutError(message) from e @@ -884,7 +957,9 @@ def _response_callback(self, response, error): else: actual_response = response.get_response() # Check if the object is not None - triton_final_response = actual_response.parameters.get("triton_final_response") + triton_final_response = actual_response.parameters.get( + "triton_final_response" + ) if triton_final_response and triton_final_response.bool_param: self._queue.put(None) else: @@ -904,7 +979,9 @@ def _create_response_iterator(self): if hasattr(item, "message"): message = f"Error occurred during inference request. Message: {item.message()}" else: - message = f"Error occurred during inference request. Message: {item}" + message = ( + f"Error occurred during inference request. 
Message: {item}" + ) _LOGGER.error(message) raise PyTritonClientInferenceServerError(message) from item @@ -916,7 +993,9 @@ def _create_response_iterator(self): def _debatch_result(self, result): if self.is_batching_supported: - result = ({name: data[0] for name, data in result_.items()} for result_ in result) + result = ( + {name: data[0] for name, data in result_.items()} for result_ in result + ) return result def _get_infer_extra_args(self): @@ -925,6 +1004,7 @@ def _get_infer_extra_args(self): # kwargs["enable_empty_final_response"] = True return kwargs + class InferenceServerClientBase: def __init__(self): self._plugin = None @@ -984,23 +1064,22 @@ def unregister_plugin(self): # # Change url to 'http://localhost:8000' for utilizing HTTP client # client = Client(url='grpc://loacalhost:8001') -# +# # input_tensor_as_numpy = np.array(...) -# -# # Infer should be async similar to the exising Python APIs +# +# # Infer should be async similar to the existing Python APIs # responses = client.model('simple').infer(inputs={'input': input_tensor_as_numpy}) -# +# # for response in responses: # numpy_array = np.asarray(response.outputs['output']) -# +# # client.close() - + class Client(InferenceServerClientBase): def __init__(self, url: str) -> None: self._client_url = url super().__init__() - + def model(self, name: str) -> ModelClient: - return ModelClient(url=self._client_url, model_name=name) - + return ModelClient(url=self._client_url, model_name=name) diff --git a/src/python/library/tritonclient/hl/constants.py b/src/python/library/tritonclient/hl/constants.py index 49f8723c8..0b25a3626 100644 --- a/src/python/library/tritonclient/hl/constants.py +++ b/src/python/library/tritonclient/hl/constants.py @@ -26,6 +26,10 @@ DEFAULT_TRITON_STARTUP_TIMEOUT_S = 120 CREATE_TRITON_CLIENT_TIMEOUT_S = 10 -__DEFAULT_PYTRITON_HOME = os.path.join(os.getenv("XDG_CACHE_HOME", "$HOME/.cache"), "pytriton") -__PYTRITON_HOME = os.path.expanduser(os.path.expandvars(os.getenv("PYTRITON_HOME", __DEFAULT_PYTRITON_HOME))) +__DEFAULT_PYTRITON_HOME = os.path.join( + os.getenv("XDG_CACHE_HOME", "$HOME/.cache"), "pytriton" +) +__PYTRITON_HOME = os.path.expanduser( + os.path.expandvars(os.getenv("PYTRITON_HOME", __DEFAULT_PYTRITON_HOME)) +) PYTRITON_HOME = pathlib.Path(__PYTRITON_HOME).resolve().absolute() diff --git a/src/python/library/tritonclient/hl/lw/__init__.py b/src/python/library/tritonclient/hl/lw/__init__.py old mode 100755 new mode 100644 diff --git a/src/python/library/tritonclient/hl/lw/grpc.py b/src/python/library/tritonclient/hl/lw/grpc.py old mode 100755 new mode 100644 index 10448326e..25c804d7e --- a/src/python/library/tritonclient/hl/lw/grpc.py +++ b/src/python/library/tritonclient/hl/lw/grpc.py @@ -14,13 +14,17 @@ """New Tritonclient GRPC not implemented yet.""" + class InferenceServerClient: - """Client to perform grpc communication with the Triton. - """ + """Client to perform grpc communication with the Triton.""" + pass + class InferResult: """Result from the inference server.""" + pass + from tritonclient.hl.lw.http import InferResponse diff --git a/src/python/library/tritonclient/hl/lw/http.py b/src/python/library/tritonclient/hl/lw/http.py old mode 100755 new mode 100644 index 542cee74b..08c46fbb1 --- a/src/python/library/tritonclient/hl/lw/http.py +++ b/src/python/library/tritonclient/hl/lw/http.py @@ -12,34 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import requests -import numpy as np -import threading import json import logging +import threading +import numpy as np +import requests from tritonclient.hl.lw.infer_input import InferInput _LOGGER = logging.getLogger(__name__) -class InferOutput(): + +class InferOutput: """Output from the inference server.""" + def __init__(self, name): """Initialize the output.""" self.name = name -class InferResponse(): +class InferResponse: """Response from the inference server.""" class Parameter: """Parameter for the response.""" + def __init__(self, parameter): """Initialize the parameter.""" self.bool_param = parameter class Parameters: """Parameters for the response.""" + def __init__(self, parameters): """Initialize the parameters.""" self.parameters = parameters @@ -53,20 +57,23 @@ def __init__(self, outputs=None, parameters=None): if outputs is None: outputs = {"outputs": []} self._rest_outputs = outputs - self.outputs = [InferOutput(response['name']) for response in outputs['outputs']] + self.outputs = [ + InferOutput(response["name"]) for response in outputs["outputs"] + ] self.parameters = InferResponse.Parameters(parameters) def get_response(self): """Get the response.""" return self - + def as_numpy(self, name): """Get the response as numpy.""" - for response in self._rest_outputs['outputs']: - if response['name'] == name: - return np.array(response['data']) + for response in self._rest_outputs["outputs"]: + if response["name"] == name: + return np.array(response["data"]) return None + def _array_to_first_element(arr): """ Convert a NumPy array to its first element if it contains only one element. @@ -86,23 +93,23 @@ def _array_to_first_element(arr): else: raise ValueError("Array contains more than one element") + def _process_chunk(chunk): """Process the chunk of data received.""" # Decode the byte string to a regular string - chunk_str = chunk.decode('utf-8') - + chunk_str = chunk.decode("utf-8") + # Strip the "data: " prefix - if chunk_str.startswith('data: '): - chunk_str = chunk_str[len('data: '):] - + if chunk_str.startswith("data: "): + chunk_str = chunk_str[len("data: ") :] + # Load the JSON string into a Python dictionary chunk_json = json.loads(chunk_str) return chunk_json -class InferenceServerClient(): - """Client to perform http communication with the Triton. - """ +class InferenceServerClient: + """Client to perform http communication with the Triton.""" def __init__(self, url, **kwargs): self.url = url @@ -111,58 +118,69 @@ def __init__(self, url, **kwargs): self._event = threading.Event() self._exception = None - def is_server_ready(self): - """Check if the server is ready. - """ + """Check if the server is ready.""" response = requests.get("http://" + self.url + "/v2/health/live") return response.status_code == 200 - + def is_server_live(self): - """Check if the server is ready. - """ + """Check if the server is ready.""" response = requests.get("http://" + self.url + "/v2/health/ready") return response.status_code == 200 - + def is_model_ready(self, model_name, model_version): - """Check if the model is ready. 
- """ + """Check if the model is ready.""" model_version = model_version if model_version else "1" - request = "http://" + self.url + "/v2/models/{}/versions/{}/ready".format(model_name, model_version) + request = ( + "http://" + + self.url + + "/v2/models/{}/versions/{}/ready".format(model_name, model_version) + ) response = requests.get(request) return response.status_code == 200 - + def get_model_config(self, model_name, model_version): - """Get the model configuration. - """ + """Get the model configuration.""" model_version = model_version if model_version else "1" - request = "http://" + self.url + "/v2/models/{}/versions/{}/config".format(model_name, model_version) + request = ( + "http://" + + self.url + + "/v2/models/{}/versions/{}/config".format(model_name, model_version) + ) response = requests.get(request) return response.json() - - def infer(self, model_name, model_version, inputs, headers, outputs, request_id, parameters): - """Perform inference. - """ + def infer( + self, + model_name, + model_version, + inputs, + headers, + outputs, + request_id, + parameters, + ): + """Perform inference.""" model_version = model_version if model_version else "1" - request = "http://" + self.url + "/v2/models/{}/versions/{}/infer".format(model_name, model_version) + request = ( + "http://" + + self.url + + "/v2/models/{}/versions/{}/infer".format(model_name, model_version) + ) inputs_for_json = [input_value.to_dict() for input_value in inputs] - ## TODO: Support setting outputs, request_id, parameters and headers response = requests.post(request, json={"inputs": inputs_for_json}) return InferResponse(response.json()) - + def start_stream(self, callback): - """Start the stream. - """ + """Start the stream.""" self._callback = callback - - + def process_chunk(self, chunk, error): """Process the chunk of data received.""" if self._callback: @@ -173,10 +191,14 @@ def stream_request(self, model_name, model_version, inputs_for_stream): headers = {"Content-Type": "application/json"} try: - with requests.post(request_url, json=inputs_for_stream, headers=headers, stream=True) as response: + with requests.post( + request_url, json=inputs_for_stream, headers=headers, stream=True + ) as response: if response.status_code != 200: _LOGGER.debug(f"Request failed with status code: {response}") - self._exception = Exception(f"Request failed with status code: {response.status_code}") + self._exception = Exception( + f"Request failed with status code: {response.status_code}" + ) self._event.set() return @@ -184,33 +206,55 @@ def stream_request(self, model_name, model_version, inputs_for_stream): try: for line in response.iter_lines(): - if line: # Filter out keep-alive new lines + if line: # Filter out keep-alive new lines response_json = _process_chunk(line) outputs = [] for key, value in response_json.items(): if key not in ["model_name", "model_version"]: outputs.append({"name": key, "data": [value]}) outputs_struct = {"outputs": outputs} - response = InferResponse(outputs_struct, parameters={"triton_final_response": False}) + response = InferResponse( + outputs_struct, + parameters={"triton_final_response": False}, + ) self.process_chunk(response, None) except Exception as e: - _LOGGER.debug(f"Some error occurred while processing the response: {e}") + _LOGGER.debug( + f"Some error occurred while processing the response: {e}" + ) self.process_chunk(None, e) - response = InferResponse(outputs=None, parameters={"triton_final_response": True}) + response = InferResponse( + outputs=None, 
parameters={"triton_final_response": True} + ) self.process_chunk(response, None) except Exception as e: _LOGGER.debug(f"Some error occurred while processing the response: {e}") self._exception = e self._event.set() - def async_stream_infer(self, model_name, model_version, inputs, outputs, request_id, enable_empty_final_response, **kwargs): + def async_stream_infer( + self, + model_name, + model_version, + inputs, + outputs, + request_id, + enable_empty_final_response, + **kwargs, + ): """Perform inference.""" model_version = model_version if model_version else "1" - inputs_for_stream = {input_value.name(): _array_to_first_element(input_value._np_data) for input_value in inputs} + inputs_for_stream = { + input_value.name(): _array_to_first_element(input_value._np_data) + for input_value in inputs + } self._event.clear() self._exception = None - self._stream_thread = threading.Thread(target=self.stream_request, args=(model_name, model_version, inputs_for_stream)) + self._stream_thread = threading.Thread( + target=self.stream_request, + args=(model_name, model_version, inputs_for_stream), + ) self._stream_thread.start() self._event.wait() # Block until the first 200 response or error is returned @@ -221,13 +265,12 @@ def close(self): """Close the stream and join the thread.""" if self._stream is not None: self._stream.join() - + def close(self): - """Close the client. - """ + """Close the client.""" pass - -class InferRequestedOutput(): + +class InferRequestedOutput: def __init__(self, name): self.name = name diff --git a/src/python/library/tritonclient/hl/lw/infer_input.py b/src/python/library/tritonclient/hl/lw/infer_input.py old mode 100755 new mode 100644 index 2c04d4d37..54c4f887e --- a/src/python/library/tritonclient/hl/lw/infer_input.py +++ b/src/python/library/tritonclient/hl/lw/infer_input.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import requests import numpy as np - +import requests + + def raise_error(message): """Raise an InferenceServerException with the specified message.""" raise Exception(message) + def triton_to_np_dtype(dtype): """Converts a Triton dtype to a numpy dtype.""" if dtype == "BOOL": @@ -49,6 +51,7 @@ def triton_to_np_dtype(dtype): return np.object_ return None + def serialize_byte_tensor(input_tensor): """ Serializes a bytes tensor into a flat numpy array of length prepended @@ -104,6 +107,7 @@ def serialize_byte_tensor(input_tensor): flattened_array = np.ascontiguousarray(flattened_array, dtype=np.object_) return flattened_array + def np_to_triton_dtype(np_dtype): """Converts a numpy dtype to a Triton dtype.""" if np_dtype == bool: @@ -224,7 +228,7 @@ def set_data_from_numpy(self, input_tensor): """ self._np_data = input_tensor - binary_data=False + binary_data = False if not isinstance(input_tensor, (np.ndarray,)): raise_error("input_tensor must be a numpy array") @@ -312,7 +316,6 @@ def set_data_from_numpy(self, input_tensor): self._parameters["binary_data_size"] = len(self._raw_data) return self - def _get_binary_data(self): """Returns the raw binary data if available @@ -350,8 +353,8 @@ def to_dict(self): "datatype": self.datatype(), "data": self._get_tensor()["data"], } - -class InferRequestedOutput(): + +class InferRequestedOutput: def __init__(self, name): self.name = name diff --git a/src/python/library/tritonclient/hl/lw/utils.py b/src/python/library/tritonclient/hl/lw/utils.py old mode 100755 new mode 100644 index 0a15b1911..f469c2965 --- a/src/python/library/tritonclient/hl/lw/utils.py +++ b/src/python/library/tritonclient/hl/lw/utils.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. + class InferenceServerException(Exception): - pass \ No newline at end of file + pass diff --git a/src/python/library/tritonclient/hl/model_config/model_config.py b/src/python/library/tritonclient/hl/model_config/model_config.py index 9d446478e..c42386bbe 100644 --- a/src/python/library/tritonclient/hl/model_config/model_config.py +++ b/src/python/library/tritonclient/hl/model_config/model_config.py @@ -21,7 +21,7 @@ import dataclasses -#from tritonclient.hl.model_config import DynamicBatcher +# from tritonclient.hl.model_config import DynamicBatcher @dataclasses.dataclass @@ -38,6 +38,6 @@ class ModelConfig: batching: bool = True max_batch_size: int = 4 - #batcher: DynamicBatcher = dataclasses.field(default_factory=DynamicBatcher) + # batcher: DynamicBatcher = dataclasses.field(default_factory=DynamicBatcher) response_cache: bool = False decoupled: bool = False diff --git a/src/python/library/tritonclient/hl/model_config/parser.py b/src/python/library/tritonclient/hl/model_config/parser.py index d5e669734..c7a465df8 100644 --- a/src/python/library/tritonclient/hl/model_config/parser.py +++ b/src/python/library/tritonclient/hl/model_config/parser.py @@ -33,13 +33,19 @@ class ModelConfig. 
import numpy as np from google.protobuf import json_format, text_format # pytype: disable=pyi-error - from tritonclient.hl.lw.infer_input import triton_to_np_dtype -#from pytriton.exceptions import PyTritonModelConfigError - from .common import QueuePolicy, TimeoutAction -from .triton_model_config import DeviceKind, DynamicBatcher, ResponseCache, TensorSpec, TritonModelConfig +from .triton_model_config import ( + DeviceKind, + DynamicBatcher, + ResponseCache, + TensorSpec, + TritonModelConfig, +) + +# from pytriton.exceptions import PyTritonModelConfigError + # try: # import tritonclient.hl.lw.grpc as grpc_client @@ -68,7 +74,9 @@ def from_dict(cls, model_config_dict: Dict) -> TritonModelConfig: Returns: A ModelConfig object with data parsed from the dictionary """ - LOGGER.debug(f"Parsing Triton config model from dict: \n{json.dumps(model_config_dict, indent=4)}") + LOGGER.debug( + f"Parsing Triton config model from dict: \n{json.dumps(model_config_dict, indent=4)}" + ) if model_config_dict.get("max_batch_size", 0) > 0: batching = True @@ -82,34 +90,46 @@ def from_dict(cls, model_config_dict: Dict) -> TritonModelConfig: batcher = None instance_group = { - DeviceKind(entry["kind"]): entry.get("count") for entry in model_config_dict.get("instance_group", []) + DeviceKind(entry["kind"]): entry.get("count") + for entry in model_config_dict.get("instance_group", []) } - decoupled = model_config_dict.get("model_transaction_policy", {}).get("decoupled", False) + decoupled = model_config_dict.get("model_transaction_policy", {}).get( + "decoupled", False + ) backend_parameters_config = model_config_dict.get("parameters", []) if isinstance(backend_parameters_config, list): # If the backend_parameters_config is a list of strings, use them as keys with empty values - LOGGER.debug(f"backend_parameters_config is a list of strings: {backend_parameters_config}") + LOGGER.debug( + f"backend_parameters_config is a list of strings: {backend_parameters_config}" + ) backend_parameters = {name: "" for name in backend_parameters_config} elif isinstance(backend_parameters_config, dict): # If the backend_parameters_config is a dictionary, use the key and "string_value" fields as key-value pairs - LOGGER.debug(f"backend_parameters_config is a dictionary: {backend_parameters_config}") + LOGGER.debug( + f"backend_parameters_config is a dictionary: {backend_parameters_config}" + ) backend_parameters = { - name: backend_parameters_config[name]["string_value"] for name in backend_parameters_config + name: backend_parameters_config[name]["string_value"] + for name in backend_parameters_config } else: # Otherwise, raise an error LOGGER.error( f"Invalid type {type(backend_parameters_config)} for backend_parameters_config: {backend_parameters_config}" ) - raise TypeError(f"Invalid type for backend_parameters_config: {type(backend_parameters_config)}") + raise TypeError( + f"Invalid type for backend_parameters_config: {type(backend_parameters_config)}" + ) inputs = [ - cls.rewrite_io_spec(item, "input", idx) for idx, item in enumerate(model_config_dict.get("input", [])) + cls.rewrite_io_spec(item, "input", idx) + for idx, item in enumerate(model_config_dict.get("input", [])) ] or None outputs = [ - cls.rewrite_io_spec(item, "output", idx) for idx, item in enumerate(model_config_dict.get("output", [])) + cls.rewrite_io_spec(item, "output", idx) + for idx, item in enumerate(model_config_dict.get("output", [])) ] or None response_cache_config = model_config_dict.get("response_cache") @@ -147,9 +167,13 @@ def from_file(cls, *, 
config_path: pathlib.Path) -> TritonModelConfig: with config_path.open("r") as config_file: payload = config_file.read() - model_config_proto = text_format.Parse(payload, model_config_pb2.ModelConfig()) + model_config_proto = text_format.Parse( + payload, model_config_pb2.ModelConfig() + ) - model_config_dict = json_format.MessageToDict(model_config_proto, preserving_proto_field_name=True) + model_config_dict = json_format.MessageToDict( + model_config_proto, preserving_proto_field_name=True + ) return ModelConfigParser.from_dict(model_config_dict=model_config_dict) @classmethod @@ -166,11 +190,15 @@ def rewrite_io_spec(cls, item: Dict, io_type: str, idx: int) -> TensorSpec: """ name = item.get("name") if not name: - raise PyTritonModelConfigError(f"Name for {io_type} at index {idx} not provided.") + raise PyTritonModelConfigError( + f"Name for {io_type} at index {idx} not provided." + ) data_type = item.get("data_type") if not data_type: - raise PyTritonModelConfigError(f"Data type for {io_type} with name `{name}` not defined.") + raise PyTritonModelConfigError( + f"Data type for {io_type} with name `{name}` not defined." + ) data_type_val = data_type.split("_") if len(data_type_val) != 2: @@ -185,18 +213,24 @@ def rewrite_io_spec(cls, item: Dict, io_type: str, idx: int) -> TensorSpec: else: dtype = triton_to_np_dtype(data_type) if dtype is None: - raise PyTritonModelConfigError(f"Unsupported data type `{data_type}` for {io_type} with name `{name}`") + raise PyTritonModelConfigError( + f"Unsupported data type `{data_type}` for {io_type} with name `{name}`" + ) dtype = np.dtype("bool") if dtype is bool else dtype dims = item.get("dims", []) if not dims: - raise PyTritonModelConfigError(f"Dimension for {io_type} with name `{name}` not defined.") + raise PyTritonModelConfigError( + f"Dimension for {io_type} with name `{name}` not defined." 
+ ) shape = tuple(int(s) for s in dims) optional = item.get("optional", False) - return TensorSpec(name=item["name"], shape=shape, dtype=dtype, optional=optional) + return TensorSpec( + name=item["name"], shape=shape, dtype=dtype, optional=optional + ) @classmethod def _parse_dynamic_batching(cls, dynamic_batching_config: Dict) -> DynamicBatcher: @@ -209,36 +243,62 @@ def _parse_dynamic_batching(cls, dynamic_batching_config: Dict) -> DynamicBatche DynamicBatcher object with configuration """ default_queue_policy = None - default_queue_policy_config = dynamic_batching_config.get("default_queue_policy") + default_queue_policy_config = dynamic_batching_config.get( + "default_queue_policy" + ) if default_queue_policy_config: default_queue_policy = QueuePolicy( timeout_action=TimeoutAction( - default_queue_policy_config.get("timeout_action", TimeoutAction.REJECT.value) + default_queue_policy_config.get( + "timeout_action", TimeoutAction.REJECT.value + ) + ), + default_timeout_microseconds=int( + default_queue_policy_config.get("default_timeout_microseconds", 0) + ), + allow_timeout_override=bool( + default_queue_policy_config.get("allow_timeout_override", False) + ), + max_queue_size=int( + default_queue_policy_config.get("max_queue_size", 0) ), - default_timeout_microseconds=int(default_queue_policy_config.get("default_timeout_microseconds", 0)), - allow_timeout_override=bool(default_queue_policy_config.get("allow_timeout_override", False)), - max_queue_size=int(default_queue_policy_config.get("max_queue_size", 0)), ) priority_queue_policy = None - priority_queue_policy_config = dynamic_batching_config.get("priority_queue_policy") + priority_queue_policy_config = dynamic_batching_config.get( + "priority_queue_policy" + ) if priority_queue_policy_config: priority_queue_policy = {} for priority, queue_policy_config in priority_queue_policy_config.items(): queue_policy = QueuePolicy( - timeout_action=TimeoutAction(queue_policy_config.get("timeout_action", TimeoutAction.REJECT.value)), - default_timeout_microseconds=int(queue_policy_config.get("default_timeout_microseconds", 0)), - allow_timeout_override=bool(queue_policy_config.get("allow_timeout_override", False)), + timeout_action=TimeoutAction( + queue_policy_config.get( + "timeout_action", TimeoutAction.REJECT.value + ) + ), + default_timeout_microseconds=int( + queue_policy_config.get("default_timeout_microseconds", 0) + ), + allow_timeout_override=bool( + queue_policy_config.get("allow_timeout_override", False) + ), max_queue_size=int(queue_policy_config.get("max_queue_size", 0)), ) priority_queue_policy[int(priority)] = queue_policy batcher = DynamicBatcher( preferred_batch_size=dynamic_batching_config.get("preferred_batch_size"), - max_queue_delay_microseconds=int(dynamic_batching_config.get("max_queue_delay_microseconds", 0)), - preserve_ordering=bool(dynamic_batching_config.get("preserve_ordering", False)), + max_queue_delay_microseconds=int( + dynamic_batching_config.get("max_queue_delay_microseconds", 0) + ), + preserve_ordering=bool( + dynamic_batching_config.get("preserve_ordering", False) + ), priority_levels=int(dynamic_batching_config.get("priority_levels", 0)), - default_priority_level=int(dynamic_batching_config.get("default_priority_level", 0)), + default_priority_level=int( + dynamic_batching_config.get("default_priority_level", 0) + ), default_queue_policy=default_queue_policy, priority_queue_policy=priority_queue_policy, ) diff --git a/src/python/library/tritonclient/hl/model_config/tensor.py 
b/src/python/library/tritonclient/hl/model_config/tensor.py index ded9050c6..32584d7dc 100644 --- a/src/python/library/tritonclient/hl/model_config/tensor.py +++ b/src/python/library/tritonclient/hl/model_config/tensor.py @@ -54,4 +54,6 @@ class Tensor: def __post_init__(self): """Override object values on post init or field override.""" if isinstance(self.dtype, np.dtype): - object.__setattr__(self, "dtype", self.dtype.type) # pytype: disable=attribute-error + object.__setattr__( + self, "dtype", self.dtype.type + ) # pytype: disable=attribute-error diff --git a/src/python/library/tritonclient/hl/model_config/triton_model_config.py b/src/python/library/tritonclient/hl/model_config/triton_model_config.py index 87aa276c3..7953e1dc3 100644 --- a/src/python/library/tritonclient/hl/model_config/triton_model_config.py +++ b/src/python/library/tritonclient/hl/model_config/triton_model_config.py @@ -55,7 +55,9 @@ class TritonModelConfig: max_batch_size: int = 4 batching: bool = True batcher: Optional[DynamicBatcher] = None - instance_group: Dict[DeviceKind, Optional[int]] = dataclasses.field(default_factory=lambda: {}) + instance_group: Dict[DeviceKind, Optional[int]] = dataclasses.field( + default_factory=lambda: {} + ) decoupled: bool = False backend_parameters: Dict[str, str] = dataclasses.field(default_factory=lambda: {}) inputs: Optional[Sequence[TensorSpec]] = None diff --git a/src/python/library/tritonclient/hl/triton_model_config.py b/src/python/library/tritonclient/hl/triton_model_config.py index 8949a409f..f143d46e1 100644 --- a/src/python/library/tritonclient/hl/triton_model_config.py +++ b/src/python/library/tritonclient/hl/triton_model_config.py @@ -17,7 +17,6 @@ from typing import Dict, Optional, Sequence, Type, Union import numpy as np - from tritonclient.hl.common import DeviceKind, DynamicBatcher @@ -55,7 +54,9 @@ class TritonModelConfig: max_batch_size: int = 4 batching: bool = True batcher: Optional[DynamicBatcher] = None - instance_group: Dict[DeviceKind, Optional[int]] = dataclasses.field(default_factory=lambda: {}) + instance_group: Dict[DeviceKind, Optional[int]] = dataclasses.field( + default_factory=lambda: {} + ) decoupled: bool = False backend_parameters: Dict[str, str] = dataclasses.field(default_factory=lambda: {}) inputs: Optional[Sequence[TensorSpec]] = None diff --git a/src/python/library/tritonclient/hl/utils.py b/src/python/library/tritonclient/hl/utils.py index 7b4f98cb4..ec71e5d98 100644 --- a/src/python/library/tritonclient/hl/utils.py +++ b/src/python/library/tritonclient/hl/utils.py @@ -21,18 +21,21 @@ import time import urllib import warnings -from typing import Optional, Any +from typing import Any, Optional import tritonclient.hl.lw.grpc import tritonclient.hl.lw.http -#import tritonclient.http.aio -from grpc import RpcError -from tritonclient.hl.lw.utils import InferenceServerException -from tritonclient.hl.exceptions import PyTritonClientInvalidUrlError, PyTritonClientTimeoutError -from tritonclient.hl.warnings import NotSupportedTimeoutWarning +# import tritonclient.http.aio +from grpc import RpcError from tritonclient.hl.constants import DEFAULT_GRPC_PORT, DEFAULT_HTTP_PORT +from tritonclient.hl.exceptions import ( + PyTritonClientInvalidUrlError, + PyTritonClientTimeoutError, +) +from tritonclient.hl.lw.utils import InferenceServerException from tritonclient.hl.model_config.parser import ModelConfigParser +from tritonclient.hl.warnings import NotSupportedTimeoutWarning _LOGGER = logging.getLogger(__name__) @@ -68,11 +71,15 @@ def 
parse_http_response(models): models_states = {} _LOGGER.debug("Parsing model repository index entries:") for model in models: - _LOGGER.debug(f" name={model.get('name')} version={model.get('version')} state={model.get('state')}") + _LOGGER.debug( + f" name={model.get('name')} version={model.get('version')} state={model.get('state')}" + ) if not model.get("version"): continue - model_state = ModelState(model["state"]) if model.get("state") else ModelState.LOADING + model_state = ( + ModelState(model["state"]) if model.get("state") else ModelState.LOADING + ) models_states[(model["name"], model["version"])] = model_state return models_states @@ -83,7 +90,9 @@ def parse_grpc_response(models): models_states = {} _LOGGER.debug("Parsing model repository index entries:") for model in models: - _LOGGER.debug(f" name={model.name} version={model.version} state={model.state}") + _LOGGER.debug( + f" name={model.name} version={model.version} state={model.state}" + ) if not model.version: continue @@ -120,12 +129,16 @@ def get_model_state( if model_version is None: requested_model_states = { - version: state for (name, version), state in models_states.items() if name == model_name + version: state + for (name, version), state in models_states.items() + if name == model_name } if not requested_model_states: return ModelState.UNAVAILABLE else: - requested_model_states = sorted(requested_model_states.items(), key=lambda item: int(item[0])) + requested_model_states = sorted( + requested_model_states.items(), key=lambda item: int(item[0]) + ) _latest_version, latest_version_state = requested_model_states[-1] return latest_version_state else: @@ -165,7 +178,9 @@ def get_model_config( PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout. PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. """ - wait_for_model_ready(client, model_name=model_name, model_version=model_version, timeout_s=timeout_s) + wait_for_model_ready( + client, model_name=model_name, model_version=model_version, timeout_s=timeout_s + ) model_version = model_version or "" @@ -216,7 +231,9 @@ def wait_for_server_ready( Raises: PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout. """ - timeout_s = timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_SERVER_READY_TIMEOUT_S + timeout_s = ( + timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_SERVER_READY_TIMEOUT_S + ) should_finish_before_s = time.time() + timeout_s _warn_on_too_big_network_timeout(client, timeout_s) @@ -225,7 +242,11 @@ def _is_server_ready(): return client.is_server_ready() and client.is_server_live() except InferenceServerException: return False - except (RpcError, ConnectionError, socket.gaierror): # GRPC and HTTP clients raises these errors + except ( + RpcError, + ConnectionError, + socket.gaierror, + ): # GRPC and HTTP clients raises these errors return False except Exception as e: _LOGGER.exception(f"Exception while checking server readiness: {e}") @@ -238,7 +259,9 @@ def _is_server_ready(): time.sleep(min(1.0, timeout_s)) is_server_ready = _is_server_ready() if not is_server_ready and time.time() >= should_finish_before_s: - raise PyTritonClientTimeoutError("Waiting for server to be ready timed out.") + raise PyTritonClientTimeoutError( + "Waiting for server to be ready timed out." 
+ ) def wait_for_model_ready( @@ -263,12 +286,16 @@ def wait_for_model_ready( """ model_version = model_version or "" model_version_msg = model_version or LATEST_MODEL_VERSION - timeout_s = timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S + timeout_s = ( + timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S + ) should_finish_before_s = time.time() + timeout_s wait_for_server_ready(client, timeout_s=timeout_s) timeout_s = max(0.0, should_finish_before_s - time.time()) - _LOGGER.debug(f"Waiting for model {model_name}/{model_version_msg} to be ready (timeout={timeout_s})") + _LOGGER.debug( + f"Waiting for model {model_name}/{model_version_msg} to be ready (timeout={timeout_s})" + ) is_model_ready = client.is_model_ready(model_name, model_version) while not is_model_ready: time.sleep(min(1.0, timeout_s)) @@ -296,11 +323,17 @@ def create_client_from_url(url: str, network_timeout_s: Optional[float] = None) PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid. """ url = TritonUrl.from_url(url) - triton_client_lib = {"grpc": tritonclient.grpc, "http": tritonclient.http}[url.scheme] + triton_client_lib = {"grpc": tritonclient.grpc, "http": tritonclient.http}[ + url.scheme + ] if url.scheme == "grpc": # by default grpc client has very large number of timeout, thus we want to make it equal to http client timeout - network_timeout_s = _DEFAULT_NETWORK_TIMEOUT_S if network_timeout_s is None else network_timeout_s + network_timeout_s = ( + _DEFAULT_NETWORK_TIMEOUT_S + if network_timeout_s is None + else network_timeout_s + ) warnings.warn( f"tritonclient.grpc doesn't support timeout for other commands than infer. Ignoring network_timeout: {network_timeout_s}.", NotSupportedTimeoutWarning, @@ -312,12 +345,19 @@ def create_client_from_url(url: str, network_timeout_s: Optional[float] = None) triton_client_init_kwargs.update( **{ "grpc": {}, - "http": {"connection_timeout": network_timeout_s, "network_timeout": network_timeout_s}, + "http": { + "connection_timeout": network_timeout_s, + "network_timeout": network_timeout_s, + }, }[url.scheme] ) - _LOGGER.debug(f"Creating InferenceServerClient for {url.with_scheme} with {triton_client_init_kwargs}") - return triton_client_lib.InferenceServerClient(url.without_scheme, **triton_client_init_kwargs) + _LOGGER.debug( + f"Creating InferenceServerClient for {url.with_scheme} with {triton_client_init_kwargs}" + ) + return triton_client_lib.InferenceServerClient( + url.without_scheme, **triton_client_init_kwargs + ) @dataclasses.dataclass @@ -351,24 +391,39 @@ def from_url(cls, url): TritonUrl object with scheme, hostname and port. """ if not isinstance(url, str): - raise PyTritonClientInvalidUrlError(f"Invalid url {url}. Url must be a string.") + raise PyTritonClientInvalidUrlError( + f"Invalid url {url}. Url must be a string." + ) try: parsed_url = urllib.parse.urlparse(url) # change in py3.9+ # https://github.com/python/cpython/commit/5a88d50ff013a64fbdb25b877c87644a9034c969 - if sys.version_info < (3, 9) and not parsed_url.scheme and "://" in parsed_url.path: - raise ValueError(f"Invalid url {url}. Only grpc and http are supported.") + if ( + sys.version_info < (3, 9) + and not parsed_url.scheme + and "://" in parsed_url.path + ): + raise ValueError( + f"Invalid url {url}. Only grpc and http are supported." 
+                )
             if (not parsed_url.scheme and "://" not in parsed_url.path) or (
-                sys.version_info >= (3, 9) and parsed_url.scheme and not parsed_url.netloc
+                sys.version_info >= (3, 9)
+                and parsed_url.scheme
+                and not parsed_url.netloc
             ):
                 _LOGGER.debug(f"Adding http scheme to {url}")
                 parsed_url = urllib.parse.urlparse(f"http://{url}")
             scheme = parsed_url.scheme.lower()
             if scheme not in ["grpc", "http"]:
-                raise ValueError(f"Invalid scheme {scheme}. Only grpc and http are supported.")
+                raise ValueError(
+                    f"Invalid scheme {scheme}. Only grpc and http are supported."
+                )
-            port = parsed_url.port or {"grpc": DEFAULT_GRPC_PORT, "http": DEFAULT_HTTP_PORT}[scheme]
+            port = (
+                parsed_url.port
+                or {"grpc": DEFAULT_GRPC_PORT, "http": DEFAULT_HTTP_PORT}[scheme]
+            )
         except ValueError as e:
             raise PyTritonClientInvalidUrlError(f"Invalid url {url}") from e
         return cls(scheme, parsed_url.hostname, port)

From c3136ef3f8e939fbb21aa06ebae2533cbb5ed147 Mon Sep 17 00:00:00 2001
From: Piotr Marcinkiewicz
Date: Mon, 17 Jun 2024 19:42:05 +0200
Subject: [PATCH 4/4] Fix CodeQL error for Exception and Client

---
 src/python/library/tests/test_client.py       | 2 +-
 src/python/library/tritonclient/hl/lw/http.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/python/library/tests/test_client.py b/src/python/library/tests/test_client.py
index 3403322b0..25fc458f6 100644
--- a/src/python/library/tests/test_client.py
+++ b/src/python/library/tests/test_client.py
@@ -31,7 +31,7 @@ class TestClient(unittest.TestCase):
 
     def test_client(self):
-        Client()
+        Client("localhost:8000")
 
 
 if __name__ == "__main__":
diff --git a/src/python/library/tritonclient/hl/lw/http.py b/src/python/library/tritonclient/hl/lw/http.py
index 08c46fbb1..3befd4d4b 100644
--- a/src/python/library/tritonclient/hl/lw/http.py
+++ b/src/python/library/tritonclient/hl/lw/http.py
@@ -258,8 +258,10 @@ def async_stream_infer(
         self._stream_thread.start()
         self._event.wait()  # Block until the first 200 response or error is returned
 
-        if self._exception:
+        if isinstance(self._exception, Exception):
             raise self._exception
+        elif self._exception is not None:
+            raise Exception("An unknown error occurred")
 
     def close(self):
         """Close the stream and join the thread."""
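For reference, below is a minimal usage sketch of the high-level `Client` wired up by this series. It assumes a Triton server reachable at `localhost:8000` serving a coupled model named `simple` with a single input tensor named `input`; the address, model name, and tensor name follow the commented example in `_client.py` and are illustrative only, not part of the patch.

```python
import numpy as np

from tritonclient._client import Client

# Use "grpc://localhost:8001" instead to reach the gRPC endpoint.
client = Client("http://localhost:8000")

# model() reads the model configuration and returns a ModelClient, or a
# DecoupledModelClient when the config reports decoupled=True; both work
# as context managers.
with client.model("simple") as model:
    # Inputs are numpy arrays keyed by input name; string tensors must be
    # pre-encoded, e.g. with np.char.encode(array, "utf-8").
    batch = np.array([[0.001]], dtype=np.float64)
    result = model.infer(inputs={"input": batch})
    # For a coupled model, result maps output names to numpy arrays.
    print(result)
```

For a decoupled model the same call should instead yield an iterator of per-response dictionaries, which needs to be consumed before the client is closed.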