@@ -72,6 +72,7 @@ async def send_request_async(
     done_event,
     triton_client: Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient],
     model_spec: TritonModelSpec,
+    parameters: dict | None = None,
 ):
     ret = []
     while True:
@@ -86,7 +87,7 @@ async def send_request_async(
         try:
             a_pred = await request_async(
                 inference_client.flag.protocol,
-                inference_client.build_triton_input(batch_data, model_spec),
+                inference_client.build_triton_input(batch_data, model_spec, parameters=parameters),
                 triton_client,
                 timeout=inference_client.client_timeout,
                 compression=inference_client.flag.compression_algorithm,
@@ -232,6 +233,7 @@ def _get_request_id(self):
     def __call__(
         self,
         sequences_or_dict: Union[List[Any], Dict[str, List[Any]]],
+        parameters: dict | None = None,
         model_name: str | None = None,
         model_version: str | None = None,
     ):
@@ -254,9 +256,14 @@ def __call__(
             or (model_input.optional is True and model_input.name in sequences_or_dict)  # check optional
         ]

-        return self._call_async(sequences_list, model_spec=model_spec)
+        return self._call_async(sequences_list, model_spec=model_spec, parameters=parameters)

-    def build_triton_input(self, _input_list: List[np.array], model_spec: TritonModelSpec):
+    def build_triton_input(
+        self,
+        _input_list: List[np.array],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ):
         if self.flag.protocol is TritonProtocol.grpc:
             client = grpcclient
         else:
@@ -278,19 +285,30 @@ def build_triton_input(self, _input_list: List[np.array], model_spec: TritonModelSpec):
             request_id=str(request_id),
             model_version=model_spec.model_version,
             outputs=infer_requested_output,
+            parameters=parameters,
         )

         return request_input

-    def _call_async(self, data: List[np.ndarray], model_spec: TritonModelSpec) -> Optional[np.ndarray]:
-        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec))
+    def _call_async(
+        self,
+        data: List[np.ndarray],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ) -> Optional[np.ndarray]:
+        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters))

         if isinstance(async_result, Exception):
             raise async_result

         return async_result

-    async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModelSpec):
+    async def _call_async_item(
+        self,
+        data: List[np.ndarray],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ):
         current_grpc_async_tasks = []

         try:
@@ -301,7 +319,9 @@ async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModelSpec):
             current_grpc_async_tasks.append(generator)

             predict_tasks = [
-                asyncio.create_task(send_request_async(self, data_queue, done_event, self.triton_client, model_spec))
+                asyncio.create_task(
+                    send_request_async(self, data_queue, done_event, self.triton_client, model_spec, parameters)
+                )
                 for idx in range(ASYNC_TASKS)
             ]
             current_grpc_async_tasks.extend(predict_tasks)
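
A minimal usage sketch of the new argument, not part of the diff: the `parameters` dict accepted by `__call__` is threaded through `_call_async` / `_call_async_item` and `send_request_async` into `build_triton_input`, where it is attached to the Triton inference request. The client class name, factory method, and input name below are assumptions for illustration only.

# Hypothetical usage sketch; `InferenceClient`, `create_with`, and the input
# name "model_in" are assumed names, shown only to illustrate how the new
# `parameters` argument flows down to the Triton request.
import numpy as np

client = InferenceClient.create_with("sample_model", "localhost:8001")  # assumed constructor
result = client(
    {"model_in": np.zeros((1, 100), dtype=np.float32)},  # assumed input name and shape
    parameters={"request_tag": "abc-123"},  # forwarded as Triton request parameters
)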