Commit 190a591

test: Add test for sequence flags in ensemble streaming inference (#7344)
1 parent 679a7c7 commit 190a591

File tree

qa/L0_simple_ensemble/ensemble_test.py
qa/L0_simple_ensemble/test.sh

2 files changed: +101 −4 lines changed

2 files changed

+101
-4
lines changed

qa/L0_simple_ensemble/ensemble_test.py

Lines changed: 76 additions & 1 deletion
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -26,7 +26,13 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import random
 import sys
+import time
+from functools import partial
+
+import numpy as np
+import tritonclient.grpc as grpcclient
 
 sys.path.append("../common")
 sys.path.append("../clients")
@@ -40,6 +46,29 @@
 import tritonhttpclient
 
 
+# Utility function to Generate N requests with appropriate sequence flags
+class RequestGenerator:
+    def __init__(self, init_value, num_requests) -> None:
+        self.count = 0
+        self.init_value = init_value
+        self.num_requests = num_requests
+
+    def __enter__(self):
+        return self
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> bytes:
+        value = self.init_value + self.count
+        if self.count == self.num_requests:
+            raise StopIteration
+        start = True if self.count == 0 else False
+        end = True if self.count == self.num_requests - 1 else False
+        self.count = self.count + 1
+        return start, end, self.count - 1, value
+
+
 class EnsembleTest(tu.TestResultCollector):
     def _get_infer_count_per_version(self, model_name):
         triton_client = tritonhttpclient.InferenceServerClient(
@@ -102,6 +131,52 @@ def test_ensemble_add_sub_one_output(self):
         elif infer_count[1] == 0:
             self.assertTrue(False, "unexpeced zero infer count for 'simple' version 2")
 
+    def test_ensemble_sequence_flags(self):
+        request_generator = RequestGenerator(0, 3)
+        # 3 request made expect the START of 1st req to be true and
+        # END of last request to be true
+        expected_flags = [[True, False], [False, False], [False, True]]
+        response_flags = []
+
+        def callback(start_time, result, error):
+            response = result.get_response()
+            arr = []
+            arr.append(response.parameters["sequence_start"].bool_param)
+            arr.append(response.parameters["sequence_end"].bool_param)
+            response_flags.append(arr)
+
+        start_time = time.time()
+        triton_client = grpcclient.InferenceServerClient("localhost:8001")
+        triton_client.start_stream(callback=partial(callback, start_time))
+        correlation_id = random.randint(1, 2**31 - 1)
+        # create input tensors
+        input0_data = np.random.randint(0, 100, size=(1, 16), dtype=np.int32)
+        input1_data = np.random.randint(0, 100, size=(1, 16), dtype=np.int32)
+
+        inputs = [
+            grpcclient.InferInput("INPUT0", input0_data.shape, "INT32"),
+            grpcclient.InferInput("INPUT1", input1_data.shape, "INT32"),
+        ]
+
+        inputs[0].set_data_from_numpy(input0_data)
+        inputs[1].set_data_from_numpy(input1_data)
+
+        # create output tensors
+        outputs = [grpcclient.InferRequestedOutput("OUTPUT0")]
+        for sequence_start, sequence_end, count, input_value in request_generator:
+            triton_client.async_stream_infer(
+                model_name="ensemble_add_sub_int32_int32_int32",
+                inputs=inputs,
+                outputs=outputs,
+                request_id=f"{correlation_id}_{count}",
+                sequence_id=correlation_id,
+                sequence_start=sequence_start,
+                sequence_end=sequence_end,
+            )
+        time.sleep(2)
+        if expected_flags != response_flags:
+            self.assertTrue(False, "unexpeced sequence flags mismatch error")
+
 
 if __name__ == "__main__":
     logging.basicConfig(stream=sys.stderr)
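
For reference, a minimal sketch (not part of the commit) of what RequestGenerator yields when iterated with the same arguments the new test uses, RequestGenerator(0, 3):

# Illustrative only: iterate a RequestGenerator(0, 3) as the new test does
# and print the (start, end, count, value) tuple produced for each request.
gen = RequestGenerator(0, 3)
for start, end, count, value in gen:
    print(start, end, count, value)
# Prints:
# True False 0 0
# False False 1 1
# False True 2 2

Only the first request carries sequence_start=True and only the last carries sequence_end=True, which is exactly the pattern the test encodes in expected_flags and checks against the streamed responses.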

qa/L0_simple_ensemble/test.sh

Lines changed: 25 additions & 3 deletions
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright 2019-2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -69,6 +69,30 @@ set -e
 kill $SERVER_PID
 wait $SERVER_PID
 
+# Run ensemble model with sequence flags and verify response sequence
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    cat $SERVER_LOG
+    exit 1
+fi
+
+set +e
+python $SIMPLE_TEST_PY EnsembleTest.test_ensemble_sequence_flags >>$CLIENT_LOG 2>&1
+if [ $? -ne 0 ]; then
+    RET=1
+else
+    check_test_results $TEST_RESULT_FILE 1
+    if [ $? -ne 0 ]; then
+        cat $CLIENT_LOG
+        echo -e "\n***\n*** Test Result Verification Failed\n***"
+        RET=1
+    fi
+fi
+set -e
+
+kill $SERVER_PID
+wait $SERVER_PID
 
 # Run ensemble model with only one output requested
 run_server
@@ -78,8 +102,6 @@ if [ "$SERVER_PID" == "0" ]; then
     exit 1
 fi
 
-RET=0
-
 set +e
 python $SIMPLE_TEST_PY EnsembleTest.test_ensemble_add_sub_one_output >>$CLIENT_LOG 2>&1
 if [ $? -ne 0 ]; then
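
The added shell block drives the new case through the suite's existing run_server / check_test_results flow. As a rough equivalent, here is a hedged sketch (not part of the commit) of invoking just the new case with Python's unittest runner; it assumes a tritonserver instance is already serving the ensemble models on localhost:8001 and that it is run from qa/L0_simple_ensemble so ensemble_test.py and its ../common and ../clients helpers are importable:

# Hypothetical standalone run of the new test case, outside test.sh.
# Assumes the server and models referenced in the diffs above are already up.
import unittest

from ensemble_test import EnsembleTest

suite = unittest.TestSuite()
suite.addTest(EnsembleTest("test_ensemble_sequence_flags"))
unittest.TextTestRunner(verbosity=2).run(suite)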
