@@ -34,6 +34,7 @@
 import infer_util as iu
 import numpy as np
 import tritonclient.grpc as tritongrpcclient
+import tritonclient.utils.shared_memory as shm
 from tritonclient.utils import InferenceServerException, np_to_triton_dtype
 
 
@@ -211,6 +212,77 @@ def get_input_array(input_size, np_dtype):
             err_str,
         )
 
+    def test_wrong_input_shape_tensor_size(self):
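+        # Verify that the server rejects a shape tensor whose registered
+        # shared-memory byte size is smaller than the tensor requires.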
+        def inference_helper(model_name, batch_size=1):
+            triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")
+            if batch_size > 1:
+                dummy_input_data = np.random.rand(batch_size, 32, 32).astype(np.float32)
+            else:
+                dummy_input_data = np.random.rand(32, 32).astype(np.float32)
+            shape_tensor_data = np.asarray([4, 4], dtype=np.int32)
+
+            # Pass an incorrect input byte size for the shape tensor.
+            # Use shared memory to bypass the shape check in the client library.
+            input_byte_size = (shape_tensor_data.size - 1) * np.dtype(np.int32).itemsize
+
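+            # Create a system shared-memory region of the undersized byte size,
+            # copy the shape tensor values into it, and register it with the
+            # server under that same undersized size.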
+            input_shm_handle = shm.create_shared_memory_region(
+                "INPUT0_SHM",
+                "/INPUT0_SHM",
+                input_byte_size,
+            )
+            shm.set_shared_memory_region(
+                input_shm_handle,
+                [
+                    shape_tensor_data,
+                ],
+            )
+            triton_client.register_system_shared_memory(
+                "INPUT0_SHM",
+                "/INPUT0_SHM",
+                input_byte_size,
+            )
+
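+            # DUMMY_INPUT0 carries regular tensor data; INPUT0 is the shape
+            # tensor and will reference the shared-memory region instead of
+            # inline data.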
+            inputs = [
+                tritongrpcclient.InferInput(
+                    "DUMMY_INPUT0",
+                    dummy_input_data.shape,
+                    np_to_triton_dtype(np.float32),
+                ),
+                tritongrpcclient.InferInput(
+                    "INPUT0",
+                    shape_tensor_data.shape,
+                    np_to_triton_dtype(np.int32),
+                ),
+            ]
+            inputs[0].set_data_from_numpy(dummy_input_data)
+            inputs[1].set_shared_memory("INPUT0_SHM", input_byte_size)
+
+            outputs = [
+                tritongrpcclient.InferRequestedOutput("DUMMY_OUTPUT0"),
+                tritongrpcclient.InferRequestedOutput("OUTPUT0"),
+            ]
+
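+            # The server-side check should reject the request with a byte size
+            # mismatch error; the shared-memory region is cleaned up either way.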
+            try:
+                # Perform inference
+                with self.assertRaises(InferenceServerException) as e:
+                    triton_client.infer(
+                        model_name=model_name, inputs=inputs, outputs=outputs
+                    )
+                err_str = str(e.exception)
+                correct_input_byte_size = (
+                    shape_tensor_data.size * np.dtype(np.int32).itemsize
+                )
+                self.assertIn(
+                    f"input byte size mismatch for input 'INPUT0' for model '{model_name}'. Expected {correct_input_byte_size}, got {input_byte_size}",
+                    err_str,
+                )
+            finally:
+                shm.destroy_shared_memory_region(input_shm_handle)
+                triton_client.unregister_system_shared_memory("INPUT0_SHM")
+
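+        # Exercise both the non-batching model and the batching model (batch_size 8).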
+        inference_helper(model_name="plan_nobatch_zero_1_float32_int32")
+        inference_helper(model_name="plan_zero_1_float32_int32", batch_size=8)
+
 
 if __name__ == "__main__":
     unittest.main()