Skip to content

Commit 07f572e

Browse files
committed
Add TritonModelInput with optional
1 parent d45dd31 commit 07f572e

File tree

4 files changed

+64
-27
lines changed

4 files changed

+64
-27
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ repos:
2828
hooks:
2929
- id: flake8
3030
types: [python]
31-
args: ["--max-line-length", "120", "--ignore", "F811,F841,E203,E402,E712,W503"]
31+
args: ["--max-line-length", "120", "--ignore", "F811,F841,E203,E402,E712,W503,E501"]
3232
- repo: https://github.com/shellcheck-py/shellcheck-py
3333
rev: v0.9.0.5
3434
hooks:

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,8 @@ classifiers =
4343
zip_safe = False
4444
include_package_data = True
4545
packages = find:
46+
package_dir =
47+
=.
4648
install_requires =
4749
tritonclient[all]>=2.21.0
4850
protobuf>=3.5.0

tritony/helpers.py

Lines changed: 51 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from attrs import define
1111
from tritonclient import grpc as grpcclient
1212
from tritonclient import http as httpclient
13+
from tritonclient.grpc import model_config_pb2
1314

1415

1516
class TritonProtocol(Enum):
@@ -31,13 +32,32 @@ def dict_to_attr(obj: dict[str, Any]) -> SimpleNamespace:
3132
return json.loads(json.dumps(obj), object_hook=lambda d: SimpleNamespace(**d))
3233

3334

35+
@define
36+
class TritonModelInput:
37+
"""
38+
Most of the fields are mapped to model_config_pb2.ModelInput(https://github.com/triton-inference-server/common/blob/a2de06f4c80b2c7b15469fa4d36e5f6445382bad/protobuf/model_config.proto#L317)
39+
40+
Commented fields are not used.
41+
"""
42+
43+
name: str
44+
dtype: str # data_type mapping to https://github.com/triton-inference-server/client/blob/d257c0e5c3de6e15d6ef289ff2b96cecd0a69b5f/src/python/library/tritonclient/utils/__init__.py#L163-L190
45+
46+
format: int = 0
47+
dims: list[int] = [] # dims
48+
49+
# reshape: list[int] = []
50+
# is_shape_tensor: bool = False
51+
# allow_ragged_batch: bool = False
52+
optional: bool = False
53+
54+
3455
@define
3556
class TritonModelSpec:
3657
name: str
3758

3859
max_batch_size: int
39-
input_name: list[str]
40-
input_dtype: list[str]
60+
model_input: list[TritonModelInput]
4161

4262
output_name: list[str]
4363

@@ -91,7 +111,7 @@ def get_triton_client(
91111
model_name: str,
92112
model_version: str,
93113
protocol: TritonProtocol,
94-
):
114+
) -> (int, list[TritonModelInput], list[str]):
95115
"""
96116
(required in)
97117
:param triton_client:
@@ -107,23 +127,43 @@ def get_triton_client(
107127

108128
args = dict(model_name=model_name, model_version=model_version)
109129

110-
model_metadata = triton_client.get_model_metadata(**args)
111130
model_config = triton_client.get_model_config(**args)
112131
if protocol is TritonProtocol.http:
113-
model_metadata = dict_to_attr(model_metadata)
114132
model_config = dict_to_attr(model_config)
115133
elif protocol is TritonProtocol.grpc:
116134
model_config = model_config.config
117135

118-
max_batch_size, input_name_list, output_name_list, dtype_list = parse_model(model_metadata, model_config)
136+
max_batch_size, input_list, output_name_list = parse_model(model_config)
137+
138+
return max_batch_size, input_list, output_name_list
139+
119140

120-
return max_batch_size, input_name_list, output_name_list, dtype_list
141+
def parse_model_input(
142+
model_input: model_config_pb2.ModelInput | SimpleNamespace,
143+
) -> TritonModelInput:
144+
"""
145+
https://github.com/triton-inference-server/common/blob/r23.08/protobuf/model_config.proto#L317-L412
146+
"""
147+
RAW_DTYPE = model_input.data_type
148+
if isinstance(model_input.data_type, int):
149+
RAW_DTYPE = model_config_pb2.DataType.Name(RAW_DTYPE)
150+
RAW_DTYPE = RAW_DTYPE.strip("TYPE_")
151+
152+
if RAW_DTYPE == "STRING":
153+
RAW_DTYPE = "BYTES" # https://github.com/triton-inference-server/client/blob/d257c0e5c3de6e15d6ef289ff2b96cecd0a69b5f/src/python/library/tritonclient/utils/__init__.py#L188-L189
154+
return TritonModelInput(
155+
name=model_input.name,
156+
dims=model_input.dims,
157+
dtype=RAW_DTYPE,
158+
optional=model_input.optional,
159+
)
121160

122161

123-
def parse_model(model_metadata, model_config):
162+
def parse_model(
163+
model_config: model_config_pb2.ModelConfig | SimpleNamespace,
164+
) -> (int, list[TritonModelInput], list[str]):
124165
return (
125166
model_config.max_batch_size,
126-
[input_metadata.name for input_metadata in model_metadata.inputs],
127-
[output_metadata.name for output_metadata in model_metadata.outputs],
128-
[input_metadata.datatype for input_metadata in model_metadata.inputs],
167+
[parse_model_input(model_config_input) for model_config_input in model_config.input],
168+
[model_config_output.name for model_config_output in model_config.output],
129169
)

tritony/tools.py

Lines changed: 10 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
import logging
77
import os
88
import time
9-
import warnings
109
from concurrent.futures import ThreadPoolExecutor
1110
from typing import Any, Dict, List, Optional, Union
1211

@@ -198,14 +197,6 @@ def triton_client(self):
198197
def default_model_spec(self):
199198
return self.model_specs[self.default_model]
200199

201-
@property
202-
def input_name_list(self):
203-
warnings.warn(
204-
"input_name_list is deprecated, please use 'default_model_spec.input_name' instead", DeprecationWarning
205-
)
206-
207-
return self.default_model_spec.input_name
208-
209200
def __del__(self):
210201
# Not supporting streaming
211202
# if self.flag.protocol is TritonProtocol.grpc and self.flag.streaming and hasattr(self, "triton_client"):
@@ -223,15 +214,14 @@ def _renew_triton_client(self, triton_client, model_name: str | None = None, mod
223214
triton_client.is_server_ready()
224215
triton_client.is_model_ready(model_name, model_version)
225216

226-
(max_batch_size, input_name_list, output_name_list, dtype_list) = get_triton_client(
217+
(max_batch_size, input_list, output_name_list) = get_triton_client(
227218
triton_client, model_name=model_name, model_version=model_version, protocol=self.flag.protocol
228219
)
229220

230221
self.model_specs[(model_name, model_version)] = TritonModelSpec(
231222
name=model_name,
232223
max_batch_size=max_batch_size,
233-
input_name=input_name_list,
234-
input_dtype=dtype_list,
224+
model_input=input_list,
235225
output_name=output_name_list,
236226
)
237227

@@ -257,7 +247,12 @@ def __call__(
257247
if type(sequences_or_dict) in [list, np.ndarray]:
258248
sequences_list = [sequences_or_dict]
259249
elif type(sequences_or_dict) is dict:
260-
sequences_list = [sequences_or_dict[input_name] for input_name in model_spec.input_name]
250+
sequences_list = [
251+
sequences_or_dict[model_input.name]
252+
for model_input in model_spec.model_input
253+
if model_input.optional is False # check required
254+
or (model_input.optional is True and model_input.name in sequences_or_dict) # check optional
255+
]
261256

262257
return self._call_async(sequences_list, model_spec=model_spec)
263258

@@ -267,8 +262,8 @@ def build_triton_input(self, _input_list: List[np.array], model_spec: TritonMode
267262
else:
268263
client = httpclient
269264
infer_input_list = []
270-
for _input, _input_name, _dtype in zip(_input_list, model_spec.input_name, model_spec.input_dtype):
271-
infer_input = client.InferInput(_input_name, _input.shape, _dtype)
265+
for _input, _model_input in zip(_input_list, model_spec.model_input):
266+
infer_input = client.InferInput(_model_input.name, _input.shape, _model_input.dtype)
272267
infer_input.set_data_from_numpy(_input)
273268
infer_input_list.append(infer_input)
274269

0 commit comments

Comments (0)