Commit f071bc5

Support parameters on config.pbtxt

1 parent 0d343d7

File tree

7 files changed (+75, -20 lines)

model_repository/sample/1/model.py

Lines changed: 4 additions & 1 deletion
@@ -13,10 +13,13 @@ def initialize(self, args):
             pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs
         ]
 
+        parameters = self.model_config["parameters"]
+
     def execute(self, requests):
         responses = [None for _ in requests]
         for idx, request in enumerate(requests):
-            in_tensor = [item.as_numpy() for item in request.inputs()]
+            current_add_value = int(json.loads(request.parameters()).get("add", 0))
+            in_tensor = [item.as_numpy() + current_add_value for item in request.inputs()]
             out_tensor = [
                 pb_utils.Tensor(output_name, x.astype(output_dtype))
                 for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list)
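In the Triton Python backend, request.parameters() returns the request-level parameters as a JSON-encoded string, which is why the model runs it through json.loads before reading the "add" key (this also assumes json is imported at the top of model.py). A minimal sketch of the same parsing logic in isolation, with a made-up payload:

import json

# request.parameters() yields a JSON string like this; "3" is a made-up value.
raw_parameters = '{"add": "3"}'
# Missing keys fall back to 0, matching the committed logic.
current_add_value = int(json.loads(raw_parameters).get("add", 0))
assert current_add_value == 3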

model_repository/sample/config.pbtxt

Lines changed: 8 additions & 0 deletions
@@ -1,6 +1,14 @@
 name: "sample"
 backend: "python"
 max_batch_size: 0
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
+
 input [
   {
     name: "model_in"

model_repository/sample_autobatching/config.pbtxt

Lines changed: 8 additions & 0 deletions
@@ -1,6 +1,14 @@
 name: "sample_autobatching"
 backend: "python"
 max_batch_size: 2
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
+
 input [
   {
     name: "model_in"

model_repository/sample_multiple/config.pbtxt

Lines changed: 8 additions & 0 deletions
@@ -1,6 +1,14 @@
 name: "sample_multiple"
 backend: "python"
 max_batch_size: 2
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
+
 input [
   {
     name: "model_in0"

tests/test_model_call.py

Lines changed: 19 additions & 11 deletions
@@ -16,28 +16,36 @@ def protocol_and_port(request):
     return request.param
 
 
-def test_switching(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"Testing {protocol}")
+def get_client(protocol, port):
+    print(f"Testing {protocol}", flush=True)
+    return InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+
 
-    client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+def test_switching(protocol_and_port):
+    client = get_client(*protocol_and_port)
 
     sample = np.random.rand(1, 100).astype(np.float32)
     result = client(sample)
-    print(f"Result: {np.isclose(result, sample).all()}")
+    assert np.isclose(result, sample).all()
 
     sample_batched = np.random.rand(100, 100).astype(np.float32)
     client(sample_batched, model_name="sample_autobatching")
-    print(f"Result: {np.isclose(result, sample).all()}")
+    assert np.isclose(result, sample).all()
 
 
 def test_with_input_name(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"Testing {protocol}")
-
-    client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+    client = get_client(*protocol_and_port)
 
     sample = np.random.rand(100, 100).astype(np.float32)
     result = client({client.default_model_spec.model_input[0].name: sample})
+    assert np.isclose(result, sample).all()
+
+
+def test_with_parameters(protocol_and_port):
+    client = get_client(*protocol_and_port)
+
+    sample = np.random.rand(1, 100).astype(np.float32)
+    ADD_VALUE = 1
+    result = client({client.default_model_spec.model_input[0].name: sample}, parameters={"add": f"{ADD_VALUE}"})
 
-    print(f"Result: {np.isclose(result, sample).all()}")
+    assert np.isclose(result[0], sample[0] + ADD_VALUE).all()
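The new test_with_parameters case doubles as usage documentation: request-level parameters are passed to the client call as a dict of strings. A hedged sketch of the same call outside pytest, assuming a Triton server reachable at a placeholder localhost:8000 address:

import numpy as np
from tritony import InferenceClient

# Placeholder endpoint and protocol; match them to your running server.
client = InferenceClient.create_with("sample", "localhost:8000", protocol="http")

sample = np.random.rand(1, 100).astype(np.float32)
# "add" overrides the "0" default declared in config.pbtxt for this request only.
result = client({client.default_model_spec.model_input[0].name: sample}, parameters={"add": "2"})
assert np.isclose(result[0], sample[0] + 2).all()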

tritony/tools.py

Lines changed: 27 additions & 7 deletions
@@ -72,6 +72,7 @@ async def send_request_async(
     done_event,
     triton_client: Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient],
     model_spec: TritonModelSpec,
+    parameters: dict | None = None,
 ):
     ret = []
     while True:
@@ -86,7 +87,7 @@ async def send_request_async(
         try:
             a_pred = await request_async(
                 inference_client.flag.protocol,
-                inference_client.build_triton_input(batch_data, model_spec),
+                inference_client.build_triton_input(batch_data, model_spec, parameters=parameters),
                 triton_client,
                 timeout=inference_client.client_timeout,
                 compression=inference_client.flag.compression_algorithm,
@@ -232,6 +233,7 @@ def _get_request_id(self):
     def __call__(
         self,
         sequences_or_dict: Union[List[Any], Dict[str, List[Any]]],
+        parameters: dict | None = None,
         model_name: str | None = None,
         model_version: str | None = None,
     ):
@@ -254,9 +256,14 @@ def __call__(
             or (model_input.optional is True and model_input.name in sequences_or_dict)  # check optional
         ]
 
-        return self._call_async(sequences_list, model_spec=model_spec)
+        return self._call_async(sequences_list, model_spec=model_spec, parameters=parameters)
 
-    def build_triton_input(self, _input_list: List[np.array], model_spec: TritonModelSpec):
+    def build_triton_input(
+        self,
+        _input_list: List[np.array],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ):
         if self.flag.protocol is TritonProtocol.grpc:
             client = grpcclient
         else:
@@ -278,19 +285,30 @@ def build_triton_input(self, _input_list: List[np.array], model_spec: TritonMode
             request_id=str(request_id),
             model_version=model_spec.model_version,
             outputs=infer_requested_output,
+            parameters=parameters,
         )
 
         return request_input
 
-    def _call_async(self, data: List[np.ndarray], model_spec: TritonModelSpec) -> Optional[np.ndarray]:
-        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec))
+    def _call_async(
+        self,
+        data: List[np.ndarray],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ) -> Optional[np.ndarray]:
+        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters))
 
         if isinstance(async_result, Exception):
             raise async_result
 
         return async_result
 
-    async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModelSpec):
+    async def _call_async_item(
+        self,
+        data: List[np.ndarray],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ):
         current_grpc_async_tasks = []
 
         try:
@@ -301,7 +319,9 @@ async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModel
             current_grpc_async_tasks.append(generator)
 
             predict_tasks = [
-                asyncio.create_task(send_request_async(self, data_queue, done_event, self.triton_client, model_spec))
+                asyncio.create_task(
+                    send_request_async(self, data_queue, done_event, self.triton_client, model_spec, parameters)
+                )
                 for idx in range(ASYNC_TASKS)
            ]
             current_grpc_async_tasks.extend(predict_tasks)
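End to end, the parameters dict now threads from __call__ through _call_async, _call_async_item, and send_request_async into build_triton_input, where it lands on the underlying tritonclient request. For comparison, a hedged sketch of the equivalent direct tritonclient call that tritony wraps here; it assumes a recent tritonclient release whose infer() accepts a parameters argument, and the output tensor name "model_out" is a guess, not taken from this commit:

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient("localhost:8000")  # placeholder address

data = np.random.rand(1, 100).astype(np.float32)
infer_input = httpclient.InferInput("model_in", data.shape, "FP32")
infer_input.set_data_from_numpy(data)

# Request-level parameters ride along with the inference call,
# mirroring what build_triton_input now assembles internally.
response = client.infer("sample", [infer_input], parameters={"add": "2"})
print(response.as_numpy("model_out"))  # "model_out" is a hypothetical output name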

tritony/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.0.11"
+__version__ = "0.0.12rc0"
