|
34 | 34 | import infer_util as iu |
35 | 35 | import numpy as np |
36 | 36 | import tritonclient.grpc as tritongrpcclient |
| 37 | +import tritonclient.http as tritonhttpclient |
| 38 | +import tritonclient.utils as utils |
37 | 39 | import tritonclient.utils.shared_memory as shm |
38 | | -from tritonclient.utils import InferenceServerException, np_to_triton_dtype |
| 40 | +from tritonclient.utils import InferenceServerException |
39 | 41 |
|
40 | 42 |
|
41 | 43 | class InputValTest(unittest.TestCase): |
@@ -116,101 +118,113 @@ def test_input_validation_all_optional(self): |
116 | 118 |
|
117 | 119 |
|
118 | 120 | class InputShapeTest(unittest.TestCase): |
119 | | - def test_input_shape_validation(self): |
120 | | - input_size = 8 |
121 | | - model_name = "pt_identity" |
122 | | - triton_client = tritongrpcclient.InferenceServerClient("localhost:8001") |
| 121 | + def test_client_input_shape_validation(self): |
| 122 | + model_name = "simple" |
123 | 123 |
|
124 | | - # Pass |
125 | | - input_data = np.arange(input_size)[None].astype(np.float32) |
126 | | - inputs = [ |
127 | | - tritongrpcclient.InferInput( |
128 | | - "INPUT0", input_data.shape, np_to_triton_dtype(input_data.dtype) |
129 | | - ) |
130 | | - ] |
131 | | - inputs[0].set_data_from_numpy(input_data) |
132 | | - triton_client.infer(model_name=model_name, inputs=inputs) |
133 | | - |
134 | | - # Larger input byte size than expected |
135 | | - input_data = np.arange(input_size + 2)[None].astype(np.float32) |
136 | | - inputs = [ |
137 | | - tritongrpcclient.InferInput( |
138 | | - "INPUT0", input_data.shape, np_to_triton_dtype(input_data.dtype) |
139 | | - ) |
140 | | - ] |
141 | | - inputs[0].set_data_from_numpy(input_data) |
142 | | - # Compromised input shape |
143 | | - inputs[0].set_shape((1, input_size)) |
144 | | - with self.assertRaises(InferenceServerException) as e: |
145 | | - triton_client.infer( |
146 | | - model_name=model_name, |
147 | | - inputs=inputs, |
| 124 | + for client_type in ["http", "grpc"]: |
| 125 | + if client_type == "http": |
| 126 | + triton_client = tritonhttpclient.InferenceServerClient("localhost:8000") |
| 127 | + else: |
| 128 | + triton_client = tritongrpcclient.InferenceServerClient("localhost:8001") |
| 129 | + |
| 130 | + # Infer |
| 131 | + inputs = [] |
| 132 | + if client_type == "http": |
| 133 | + inputs.append(tritonhttpclient.InferInput("INPUT0", [1, 16], "INT32")) |
| 134 | + inputs.append(tritonhttpclient.InferInput("INPUT1", [1, 16], "INT32")) |
| 135 | + else: |
| 136 | + inputs.append(tritongrpcclient.InferInput("INPUT0", [1, 16], "INT32")) |
| 137 | + inputs.append(tritongrpcclient.InferInput("INPUT1", [1, 16], "INT32")) |
| 138 | + |
| 139 | + # Create the data for the two input tensors. Initialize the first |
| 140 | + # to unique integers and the second to all ones. |
| 141 | + input0_data = np.arange(start=0, stop=16, dtype=np.int32) |
| 142 | + input0_data = np.expand_dims(input0_data, axis=0) |
| 143 | + input1_data = np.ones(shape=(1, 16), dtype=np.int32) |
| 144 | + |
| 145 | + # Initialize the data |
| 146 | + inputs[0].set_data_from_numpy(input0_data) |
| 147 | + inputs[1].set_data_from_numpy(input1_data) |
| 148 | + |
| 149 | + # Compromised input shapes |
| 150 | + inputs[0].set_shape([2, 8]) |
| 151 | + inputs[1].set_shape([2, 8]) |
| 152 | + |
| 153 | + with self.assertRaises(InferenceServerException) as e: |
| 154 | + triton_client.infer(model_name=model_name, inputs=inputs) |
| 155 | + err_str = str(e.exception) |
| 156 | + self.assertIn( |
| 157 | + f"unexpected shape for input 'INPUT1' for model 'simple'. Expected [-1,16], got [2,8]", |
| 158 | + err_str, |
148 | 159 | ) |
149 | | - err_str = str(e.exception) |
150 | | - self.assertIn( |
151 | | - "input byte size mismatch for input 'INPUT0' for model 'pt_identity'. Expected 32, got 40", |
152 | | - err_str, |
153 | | - ) |
154 | 160 |
|
155 | | - def test_input_string_shape_validation(self): |
156 | | - input_size = 16 |
157 | | - model_name = "graphdef_object_int32_int32" |
158 | | - np_dtype_string = np.dtype(object) |
159 | | - triton_client = tritongrpcclient.InferenceServerClient("localhost:8001") |
| 161 | + # Compromised input shapes |
| 162 | + inputs[0].set_shape([1, 8]) |
| 163 | + inputs[1].set_shape([1, 8]) |
160 | 164 |
|
161 | | - def get_input_array(input_size, np_dtype): |
162 | | - rinput_dtype = iu._range_repr_dtype(np_dtype) |
163 | | - input_array = np.random.randint( |
164 | | - low=0, high=127, size=(1, input_size), dtype=rinput_dtype |
| 165 | + with self.assertRaises(InferenceServerException) as e: |
| 166 | + triton_client.infer(model_name=model_name, inputs=inputs) |
| 167 | + err_str = str(e.exception) |
| 168 | + self.assertIn( |
| 169 | + f"input 'INPUT0' got unexpected elements count 16, expected 8", |
| 170 | + err_str, |
165 | 171 | ) |
166 | 172 |
|
167 | | - # Convert to string type |
168 | | - inn = np.array( |
169 | | - [str(x) for x in input_array.reshape(input_array.size)], dtype=object |
170 | | - ) |
171 | | - input_array = inn.reshape(input_array.shape) |
| 173 | + def test_client_input_string_shape_validation(self): |
| 174 | + for client_type in ["http", "grpc"]: |
172 | 175 |
|
173 | | - inputs = [] |
174 | | - inputs.append( |
175 | | - tritongrpcclient.InferInput( |
176 | | - "INPUT0", input_array.shape, np_to_triton_dtype(np_dtype) |
177 | | - ) |
178 | | - ) |
179 | | - inputs.append( |
180 | | - tritongrpcclient.InferInput( |
181 | | - "INPUT1", input_array.shape, np_to_triton_dtype(np_dtype) |
182 | | - ) |
183 | | - ) |
| 176 | + def identity_inference(triton_client, np_array, binary_data): |
| 177 | + model_name = "simple_identity" |
184 | 178 |
|
185 | | - inputs[0].set_data_from_numpy(input_array) |
186 | | - inputs[1].set_data_from_numpy(input_array) |
187 | | - return inputs |
| 179 | + # Total elements no change |
| 180 | + inputs = [] |
| 181 | + if client_type == "http": |
| 182 | + inputs.append( |
| 183 | + tritonhttpclient.InferInput("INPUT0", np_array.shape, "BYTES") |
| 184 | + ) |
| 185 | + inputs[0].set_data_from_numpy(np_array, binary_data=binary_data) |
| 186 | + inputs[0].set_shape([2, 8]) |
| 187 | + else: |
| 188 | + inputs.append( |
| 189 | + tritongrpcclient.InferInput("INPUT0", np_array.shape, "BYTES") |
| 190 | + ) |
| 191 | + inputs[0].set_data_from_numpy(np_array) |
| 192 | + inputs[0].set_shape([2, 8]) |
| 193 | + triton_client.infer(model_name=model_name, inputs=inputs) |
188 | 194 |
|
189 | | - # Input size is less than expected |
190 | | - inputs = get_input_array(input_size - 2, np_dtype_string) |
191 | | - # Compromised input shape |
192 | | - inputs[0].set_shape((1, input_size)) |
193 | | - inputs[1].set_shape((1, input_size)) |
194 | | - with self.assertRaises(InferenceServerException) as e: |
195 | | - triton_client.infer(model_name=model_name, inputs=inputs) |
196 | | - err_str = str(e.exception) |
197 | | - self.assertIn( |
198 | | - f"expected {input_size} string elements for inference input 'INPUT1', got {input_size-2}", |
199 | | - err_str, |
200 | | - ) |
| 195 | + # Compromised input shape |
| 196 | + inputs[0].set_shape([1, 8]) |
201 | 197 |
|
202 | | - # Input size is greater than expected |
203 | | - inputs = get_input_array(input_size + 2, np_dtype_string) |
204 | | - # Compromised input shape |
205 | | - inputs[0].set_shape((1, input_size)) |
206 | | - inputs[1].set_shape((1, input_size)) |
207 | | - with self.assertRaises(InferenceServerException) as e: |
208 | | - triton_client.infer(model_name=model_name, inputs=inputs) |
209 | | - err_str = str(e.exception) |
210 | | - self.assertIn( |
211 | | - f"expected {input_size} string elements for inference input 'INPUT1', got {input_size+2}", |
212 | | - err_str, |
213 | | - ) |
| 198 | + with self.assertRaises(InferenceServerException) as e: |
| 199 | + triton_client.infer(model_name=model_name, inputs=inputs) |
| 200 | + err_str = str(e.exception) |
| 201 | + self.assertIn( |
| 202 | + f"input 'INPUT0' got unexpected elements count 16, expected 8", |
| 203 | + err_str, |
| 204 | + ) |
| 205 | + |
| 206 | + if client_type == "http": |
| 207 | + triton_client = tritonhttpclient.InferenceServerClient("localhost:8000") |
| 208 | + else: |
| 209 | + triton_client = tritongrpcclient.InferenceServerClient("localhost:8001") |
| 210 | + |
| 211 | + # Example using BYTES input tensor with utf-8 encoded string that |
| 212 | + # has an embedded null character. |
| 213 | + null_chars_array = np.array( |
| 214 | + ["he\x00llo".encode("utf-8") for i in range(16)], dtype=np.object_ |
| 215 | + ) |
| 216 | + null_char_data = null_chars_array.reshape([1, 16]) |
| 217 | + identity_inference(triton_client, null_char_data, True) # Using binary data |
| 218 | + identity_inference(triton_client, null_char_data, False) # Using JSON data |
| 219 | + |
| 220 | + # Example using BYTES input tensor with 16 elements, where each |
| 221 | + # element is a 4-byte binary blob with value 0x00010203. Can use |
| 222 | + # dtype=np.bytes_ in this case. |
| 223 | + bytes_data = [b"\x00\x01\x02\x03" for i in range(16)] |
| 224 | + np_bytes_data = np.array(bytes_data, dtype=np.bytes_) |
| 225 | + np_bytes_data = np_bytes_data.reshape([1, 16]) |
| 226 | + identity_inference(triton_client, np_bytes_data, True) # Using binary data |
| 227 | + identity_inference(triton_client, np_bytes_data, False) # Using JSON data |
214 | 228 |
|
215 | 229 | def test_wrong_input_shape_tensor_size(self): |
216 | 230 | def inference_helper(model_name, batch_size=1): |
@@ -246,12 +260,12 @@ def inference_helper(model_name, batch_size=1): |
246 | 260 | tritongrpcclient.InferInput( |
247 | 261 | "DUMMY_INPUT0", |
248 | 262 | dummy_input_data.shape, |
249 | | - np_to_triton_dtype(np.float32), |
| 263 | + utils.np_to_triton_dtype(np.float32), |
250 | 264 | ), |
251 | 265 | tritongrpcclient.InferInput( |
252 | 266 | "INPUT0", |
253 | 267 | shape_tensor_data.shape, |
254 | | - np_to_triton_dtype(np.int32), |
| 268 | + utils.np_to_triton_dtype(np.int32), |
255 | 269 | ), |
256 | 270 | ] |
257 | 271 | inputs[0].set_data_from_numpy(dummy_input_data) |
|
0 commit comments