6
6
import logging
7
7
import os
8
8
import time
9
+ import warnings
9
10
from concurrent .futures import ThreadPoolExecutor
10
11
from typing import Any , Dict , List , Optional , Union
11
12
@@ -166,10 +167,12 @@ def __init__(self, flag: TritonClientFlag):
166
167
self .__version__ = 1
167
168
168
169
self .flag = flag
170
+ self .default_model = (flag .model_name , flag .model_version )
169
171
self .model_specs = {}
170
172
self .is_async = self .flag .async_set
171
173
self .client_timeout = TRITON_CLIENT_TIMEOUT
172
174
self ._triton_client = None
175
+ self .triton_client
173
176
174
177
self .output_kwargs = {}
175
178
self .sent_count = 0
@@ -184,6 +187,18 @@ def triton_client(self):
184
187
self ._renew_triton_client (self ._triton_client )
185
188
return self ._triton_client
186
189
190
+ @property
191
+ def default_model_spec (self ):
192
+ return self .model_specs [self .default_model ]
193
+
194
+ @property
195
+ def input_name_list (self ):
196
+ warnings .warn (
197
+ "input_name_list is deprecated, please use 'default_model_spec.input_name' instead" , DeprecationWarning
198
+ )
199
+
200
+ return self .default_model_spec .input_name
201
+
187
202
def __del__ (self ):
188
203
# Not supporting streaming
189
204
# if self.flag.protocol is TritonProtocol.grpc and self.flag.streaming and hasattr(self, "triton_client"):
@@ -201,7 +216,7 @@ def _renew_triton_client(self, triton_client, model_name: str | None = None, mod
201
216
triton_client .is_server_ready ()
202
217
triton_client .is_model_ready (model_name , model_version )
203
218
204
- (max_batch_size , input_name_list , output_name_list , dtype_list , ) = get_triton_client (
219
+ (max_batch_size , input_name_list , output_name_list , dtype_list ) = get_triton_client (
205
220
triton_client , model_name = model_name , model_version = model_version , protocol = self .flag .protocol
206
221
)
207
222
0 commit comments