|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 |
|
3 | | -# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
4 | 4 | # |
5 | 5 | # Redistribution and use in source and binary forms, with or without |
6 | 6 | # modification, are permitted provided that the following conditions |
|
26 | 26 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
27 | 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
28 | 28 |
|
| 29 | +import random |
29 | 30 | import sys |
| 31 | +import time |
| 32 | +from functools import partial |
| 33 | + |
| 34 | +import numpy as np |
| 35 | +import tritonclient.grpc as grpcclient |
30 | 36 |
|
31 | 37 | sys.path.append("../common") |
32 | 38 | sys.path.append("../clients") |
|
40 | 46 | import tritonhttpclient |
41 | 47 |
|
42 | 48 |
|
| 49 | +# Utility function to Generate N requests with appropriate sequence flags |
| 50 | +class RequestGenerator: |
| 51 | + def __init__(self, init_value, num_requests) -> None: |
| 52 | + self.count = 0 |
| 53 | + self.init_value = init_value |
| 54 | + self.num_requests = num_requests |
| 55 | + |
| 56 | + def __enter__(self): |
| 57 | + return self |
| 58 | + |
| 59 | + def __iter__(self): |
| 60 | + return self |
| 61 | + |
| 62 | + def __next__(self) -> bytes: |
| 63 | + value = self.init_value + self.count |
| 64 | + if self.count == self.num_requests: |
| 65 | + raise StopIteration |
| 66 | + start = True if self.count == 0 else False |
| 67 | + end = True if self.count == self.num_requests - 1 else False |
| 68 | + self.count = self.count + 1 |
| 69 | + return start, end, self.count - 1, value |
| 70 | + |
| 71 | + |
43 | 72 | class EnsembleTest(tu.TestResultCollector): |
44 | 73 | def _get_infer_count_per_version(self, model_name): |
45 | 74 | triton_client = tritonhttpclient.InferenceServerClient( |
@@ -102,6 +131,52 @@ def test_ensemble_add_sub_one_output(self): |
102 | 131 | elif infer_count[1] == 0: |
103 | 132 | self.assertTrue(False, "unexpeced zero infer count for 'simple' version 2") |
104 | 133 |
|
| 134 | + def test_ensemble_sequence_flags(self): |
| 135 | + request_generator = RequestGenerator(0, 3) |
| 136 | + # 3 request made expect the START of 1st req to be true and |
| 137 | + # END of last request to be true |
| 138 | + expected_flags = [[True, False], [False, False], [False, True]] |
| 139 | + response_flags = [] |
| 140 | + |
| 141 | + def callback(start_time, result, error): |
| 142 | + response = result.get_response() |
| 143 | + arr = [] |
| 144 | + arr.append(response.parameters["sequence_start"].bool_param) |
| 145 | + arr.append(response.parameters["sequence_end"].bool_param) |
| 146 | + response_flags.append(arr) |
| 147 | + |
| 148 | + start_time = time.time() |
| 149 | + triton_client = grpcclient.InferenceServerClient("localhost:8001") |
| 150 | + triton_client.start_stream(callback=partial(callback, start_time)) |
| 151 | + correlation_id = random.randint(1, 2**31 - 1) |
| 152 | + # create input tensors |
| 153 | + input0_data = np.random.randint(0, 100, size=(1, 16), dtype=np.int32) |
| 154 | + input1_data = np.random.randint(0, 100, size=(1, 16), dtype=np.int32) |
| 155 | + |
| 156 | + inputs = [ |
| 157 | + grpcclient.InferInput("INPUT0", input0_data.shape, "INT32"), |
| 158 | + grpcclient.InferInput("INPUT1", input1_data.shape, "INT32"), |
| 159 | + ] |
| 160 | + |
| 161 | + inputs[0].set_data_from_numpy(input0_data) |
| 162 | + inputs[1].set_data_from_numpy(input1_data) |
| 163 | + |
| 164 | + # create output tensors |
| 165 | + outputs = [grpcclient.InferRequestedOutput("OUTPUT0")] |
| 166 | + for sequence_start, sequence_end, count, input_value in request_generator: |
| 167 | + triton_client.async_stream_infer( |
| 168 | + model_name="ensemble_add_sub_int32_int32_int32", |
| 169 | + inputs=inputs, |
| 170 | + outputs=outputs, |
| 171 | + request_id=f"{correlation_id}_{count}", |
| 172 | + sequence_id=correlation_id, |
| 173 | + sequence_start=sequence_start, |
| 174 | + sequence_end=sequence_end, |
| 175 | + ) |
| 176 | + time.sleep(2) |
| 177 | + if expected_flags != response_flags: |
| 178 | + self.assertTrue(False, "unexpeced sequence flags mismatch error") |
| 179 | + |
105 | 180 |
|
106 | 181 | if __name__ == "__main__": |
107 | 182 | logging.basicConfig(stream=sys.stderr) |
|
0 commit comments