Skip to content

Commit 2c0c822

Browse files
committed
[fix] support old interface with deprecation warn
1 parent ce5dc2e commit 2c0c822

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

tests/test_model_call.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,19 @@ def test_swithcing(protocol_and_port):
2929
sample_batched = np.random.rand(100, 100).astype(np.float32)
3030
client(sample_batched, model_name="sample_autobatching")
3131
print(f"Result: {np.isclose(result, sample).all()}")
32+
33+
34+
def test_with_input_name(protocol_and_port):
35+
protocol, port = protocol_and_port
36+
print(f"Testing {protocol}")
37+
38+
client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
39+
40+
sample = np.random.rand(1, 100).astype(np.float32)
41+
result = client({client.input_name_list[0]: sample})
42+
print(f"Result: {np.isclose(result, sample).all()}")
43+
44+
sample = np.random.rand(100, 100).astype(np.float32)
45+
result = client({client.default_model_spec.input_name[0]: sample})
46+
47+
print(f"Result: {np.isclose(result, sample).all()}")

tritony/tools.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import os
88
import time
9+
import warnings
910
from concurrent.futures import ThreadPoolExecutor
1011
from typing import Any, Dict, List, Optional, Union
1112

@@ -166,10 +167,12 @@ def __init__(self, flag: TritonClientFlag):
166167
self.__version__ = 1
167168

168169
self.flag = flag
170+
self.default_model = (flag.model_name, flag.model_version)
169171
self.model_specs = {}
170172
self.is_async = self.flag.async_set
171173
self.client_timeout = TRITON_CLIENT_TIMEOUT
172174
self._triton_client = None
175+
self.triton_client
173176

174177
self.output_kwargs = {}
175178
self.sent_count = 0
@@ -184,6 +187,18 @@ def triton_client(self):
184187
self._renew_triton_client(self._triton_client)
185188
return self._triton_client
186189

190+
@property
191+
def default_model_spec(self):
192+
return self.model_specs[self.default_model]
193+
194+
@property
195+
def input_name_list(self):
196+
warnings.warn(
197+
"input_name_list is deprecated, please use 'default_model_spec.input_name' instead", DeprecationWarning
198+
)
199+
200+
return self.default_model_spec.input_name
201+
187202
def __del__(self):
188203
# Not supporting streaming
189204
# if self.flag.protocol is TritonProtocol.grpc and self.flag.streaming and hasattr(self, "triton_client"):
@@ -201,7 +216,7 @@ def _renew_triton_client(self, triton_client, model_name: str | None = None, mod
201216
triton_client.is_server_ready()
202217
triton_client.is_model_ready(model_name, model_version)
203218

204-
(max_batch_size, input_name_list, output_name_list, dtype_list,) = get_triton_client(
219+
(max_batch_size, input_name_list, output_name_list, dtype_list) = get_triton_client(
205220
triton_client, model_name=model_name, model_version=model_version, protocol=self.flag.protocol
206221
)
207222

0 commit comments

Comments
 (0)