Skip to content

Commit 36114c4

Browse files
committed
Add client tests
1 parent 52bb23f commit 36114c4

File tree

3 files changed

+286
-130
lines changed

3 files changed

+286
-130
lines changed

Dockerfile.QA

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ RUN mkdir -p qa/common && \
149149
cp bin/triton_json_test qa/L0_json/. && \
150150
cp bin/backend_output_detail_test qa/L0_backend_output_detail/. && \
151151
cp -r deploy/mlflow-triton-plugin qa/L0_mlflow/. && \
152-
cp bin/input_byte_size_test qa/L0_input_validation/. && \
153-
cp -r docs/examples/model_repository/simple_identity qa/L0_input_validation/models
152+
cp -r docs/examples/model_repository/{simple,simple_identity,simple_string} qa/L0_input_validation/models && \
153+
cp bin/input_byte_size_test qa/L0_input_validation/.
154154

155155
RUN mkdir -p qa/pkgs && \
156156
cp python/triton*.whl qa/pkgs/. && \

qa/L0_input_validation/input_validation_test.py

Lines changed: 270 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@
3434
import infer_util as iu
3535
import numpy as np
3636
import tritonclient.grpc as tritongrpcclient
37-
from tritonclient.utils import InferenceServerException, np_to_triton_dtype
37+
import tritonclient.http as tritonhttpclient
38+
import tritonclient.utils as utils
39+
from tritonclient.utils import (
40+
InferenceServerException,
41+
np_to_triton_dtype,
42+
shared_memory,
43+
)
3844

3945

4046
class InputValTest(unittest.TestCase):
@@ -115,101 +121,283 @@ def test_input_validation_all_optional(self):
115121

116122

117123
class InputShapeTest(unittest.TestCase):
118-
def test_input_shape_validation(self):
119-
input_size = 8
120-
model_name = "pt_identity"
121-
triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")
124+
def test_client_input_shape_validation(self):
125+
model_name = "simple"
122126

123-
# Pass
124-
input_data = np.arange(input_size)[None].astype(np.float32)
125-
inputs = [
126-
tritongrpcclient.InferInput(
127-
"INPUT0", input_data.shape, np_to_triton_dtype(input_data.dtype)
128-
)
129-
]
130-
inputs[0].set_data_from_numpy(input_data)
131-
triton_client.infer(model_name=model_name, inputs=inputs)
132-
133-
# Larger input byte size than expected
134-
input_data = np.arange(input_size + 2)[None].astype(np.float32)
135-
inputs = [
136-
tritongrpcclient.InferInput(
137-
"INPUT0", input_data.shape, np_to_triton_dtype(input_data.dtype)
138-
)
139-
]
140-
inputs[0].set_data_from_numpy(input_data)
141-
# Compromised input shape
142-
inputs[0].set_shape((1, input_size))
143-
with self.assertRaises(InferenceServerException) as e:
144-
triton_client.infer(
145-
model_name=model_name,
146-
inputs=inputs,
127+
for client_type in ["http", "grpc"]:
128+
if client_type == "http":
129+
triton_client = tritonhttpclient.InferenceServerClient("localhost:8000")
130+
else:
131+
triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")
132+
133+
# Infer
134+
inputs = []
135+
if client_type == "http":
136+
inputs.append(tritonhttpclient.InferInput("INPUT0", [1, 16], "INT32"))
137+
inputs.append(tritonhttpclient.InferInput("INPUT1", [1, 16], "INT32"))
138+
else:
139+
inputs.append(tritongrpcclient.InferInput("INPUT0", [1, 16], "INT32"))
140+
inputs.append(tritongrpcclient.InferInput("INPUT1", [1, 16], "INT32"))
141+
142+
# Create the data for the two input tensors. Initialize the first
143+
# to unique integers and the second to all ones.
144+
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
145+
input0_data = np.expand_dims(input0_data, axis=0)
146+
input1_data = np.ones(shape=(1, 16), dtype=np.int32)
147+
148+
# Initialize the data
149+
inputs[0].set_data_from_numpy(input0_data)
150+
inputs[1].set_data_from_numpy(input1_data)
151+
152+
# Compromised input shapes
153+
inputs[0].set_shape([2, 8])
154+
inputs[1].set_shape([2, 8])
155+
156+
with self.assertRaises(InferenceServerException) as e:
157+
triton_client.infer(model_name=model_name, inputs=inputs)
158+
err_str = str(e.exception)
159+
self.assertIn(
160+
f"unexpected shape for input 'INPUT1' for model 'simple'. Expected [-1,16], got [2,8]",
161+
err_str,
147162
)
148-
err_str = str(e.exception)
149-
self.assertIn(
150-
"input byte size mismatch for input 'INPUT0' for model 'pt_identity'. Expected 32, got 40",
151-
err_str,
152-
)
153163

154-
def test_input_string_shape_validation(self):
155-
input_size = 16
156-
model_name = "graphdef_object_int32_int32"
157-
np_dtype_string = np.dtype(object)
158-
triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")
164+
# Compromised input shapes
165+
inputs[0].set_shape([1, 8])
166+
inputs[1].set_shape([1, 8])
159167

160-
def get_input_array(input_size, np_dtype):
161-
rinput_dtype = iu._range_repr_dtype(np_dtype)
162-
input_array = np.random.randint(
163-
low=0, high=127, size=(1, input_size), dtype=rinput_dtype
168+
with self.assertRaises(InferenceServerException) as e:
169+
triton_client.infer(model_name=model_name, inputs=inputs)
170+
err_str = str(e.exception)
171+
self.assertIn(
172+
f"'INPUT0' got unexpected elements count 16, expected 8",
173+
err_str,
164174
)
165175

166-
# Convert to string type
167-
inn = np.array(
168-
[str(x) for x in input_array.reshape(input_array.size)], dtype=object
176+
def test_client_input_string_shape_validation(self):
177+
for client_type in ["http", "grpc"]:
178+
179+
def identity_inference(triton_client, np_array, binary_data):
180+
model_name = "simple_identity"
181+
182+
# Total element count is unchanged
183+
inputs = []
184+
if client_type == "http":
185+
inputs.append(
186+
tritonhttpclient.InferInput("INPUT0", np_array.shape, "BYTES")
187+
)
188+
inputs[0].set_data_from_numpy(np_array, binary_data=binary_data)
189+
inputs[0].set_shape([2, 8])
190+
else:
191+
inputs.append(
192+
tritongrpcclient.InferInput("INPUT0", np_array.shape, "BYTES")
193+
)
194+
inputs[0].set_data_from_numpy(np_array)
195+
inputs[0].set_shape([2, 8])
196+
triton_client.infer(model_name=model_name, inputs=inputs)
197+
198+
# Compromised input shape
199+
inputs[0].set_shape([1, 8])
200+
201+
with self.assertRaises(InferenceServerException) as e:
202+
triton_client.infer(model_name=model_name, inputs=inputs)
203+
err_str = str(e.exception)
204+
self.assertIn(
205+
f"'INPUT0' got unexpected elements count 16, expected 8",
206+
err_str,
207+
)
208+
209+
if client_type == "http":
210+
triton_client = tritonhttpclient.InferenceServerClient("localhost:8000")
211+
else:
212+
triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")
213+
214+
# Example using BYTES input tensor with utf-8 encoded string that
215+
# has an embedded null character.
216+
null_chars_array = np.array(
217+
["he\x00llo".encode("utf-8") for i in range(16)], dtype=np.object_
169218
)
170-
input_array = inn.reshape(input_array.shape)
219+
null_char_data = null_chars_array.reshape([1, 16])
220+
identity_inference(triton_client, null_char_data, True) # Using binary data
221+
identity_inference(triton_client, null_char_data, False) # Using JSON data
222+
223+
# Example using BYTES input tensor with 16 elements, where each
224+
# element is a 4-byte binary blob with value 0x00010203. Can use
225+
# dtype=np.bytes_ in this case.
226+
bytes_data = [b"\x00\x01\x02\x03" for i in range(16)]
227+
np_bytes_data = np.array(bytes_data, dtype=np.bytes_)
228+
np_bytes_data = np_bytes_data.reshape([1, 16])
229+
identity_inference(triton_client, np_bytes_data, True) # Using binary data
230+
identity_inference(triton_client, np_bytes_data, False) # Using JSON data
231+
232+
def test_client_input_shm_size_validation(self):
233+
# We use a simple model that takes 2 input tensors of 16 integers
234+
# each and returns 2 output tensors of 16 integers each. One
235+
# output tensor is the element-wise sum of the inputs and one
236+
# output is the element-wise difference.
237+
model_name = "simple"
238+
239+
for client_type in ["http", "grpc"]:
240+
if client_type == "http":
241+
triton_client = tritonhttpclient.InferenceServerClient("localhost:8000")
242+
else:
243+
triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")
244+
# Make sure no shared memory regions are registered with the
245+
# server.
246+
triton_client.unregister_system_shared_memory()
247+
triton_client.unregister_cuda_shared_memory()
248+
249+
# Create the data for the two input tensors. Initialize the first
250+
# to unique integers and the second to all ones.
251+
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
252+
input1_data = np.ones(shape=16, dtype=np.int32)
253+
254+
input_byte_size = input0_data.size * input0_data.itemsize
171255

256+
# Create shared memory region for input and store shared memory handle
257+
shm_ip_handle = shared_memory.create_shared_memory_region(
258+
"input_data", "/input_simple", input_byte_size * 2
259+
)
260+
261+
# Put input data values into shared memory
262+
shared_memory.set_shared_memory_region(shm_ip_handle, [input0_data])
263+
shared_memory.set_shared_memory_region(
264+
shm_ip_handle, [input1_data], offset=input_byte_size
265+
)
266+
267+
# Register shared memory region for inputs with Triton Server
268+
triton_client.register_system_shared_memory(
269+
"input_data", "/input_simple", input_byte_size * 2
270+
)
271+
272+
# Set the parameters to use data from shared memory
172273
inputs = []
173-
inputs.append(
174-
tritongrpcclient.InferInput(
175-
"INPUT0", input_array.shape, np_to_triton_dtype(np_dtype)
176-
)
274+
if client_type == "http":
275+
inputs.append(tritonhttpclient.InferInput("INPUT0", [1, 16], "INT32"))
276+
inputs.append(tritonhttpclient.InferInput("INPUT1", [1, 16], "INT32"))
277+
else:
278+
inputs.append(tritongrpcclient.InferInput("INPUT0", [1, 16], "INT32"))
279+
inputs.append(tritongrpcclient.InferInput("INPUT1", [1, 16], "INT32"))
280+
inputs[-2].set_shared_memory("input_data", input_byte_size + 4)
281+
inputs[-1].set_shared_memory(
282+
"input_data", input_byte_size, offset=input_byte_size
177283
)
178-
inputs.append(
179-
tritongrpcclient.InferInput(
180-
"INPUT1", input_array.shape, np_to_triton_dtype(np_dtype)
181-
)
284+
285+
with self.assertRaises(InferenceServerException) as e:
286+
triton_client.infer(model_name=model_name, inputs=inputs)
287+
err_str = str(e.exception)
288+
self.assertIn(
289+
f"'INPUT0' got unexpected byte size {input_byte_size+4}, expected {input_byte_size}",
290+
err_str,
182291
)
183292

184-
inputs[0].set_data_from_numpy(input_array)
185-
inputs[1].set_data_from_numpy(input_array)
186-
return inputs
293+
# Set the parameters to use data from shared memory
294+
inputs[-2].set_shared_memory("input_data", input_byte_size)
295+
inputs[-1].set_shared_memory(
296+
"input_data", input_byte_size - 4, offset=input_byte_size
297+
)
187298

188-
# Input size is less than expected
189-
inputs = get_input_array(input_size - 2, np_dtype_string)
190-
# Compromised input shape
191-
inputs[0].set_shape((1, input_size))
192-
inputs[1].set_shape((1, input_size))
193-
with self.assertRaises(InferenceServerException) as e:
194-
triton_client.infer(model_name=model_name, inputs=inputs)
195-
err_str = str(e.exception)
196-
self.assertIn(
197-
f"expected {input_size} string elements for inference input 'INPUT1', got {input_size-2}",
198-
err_str,
199-
)
299+
with self.assertRaises(InferenceServerException) as e:
300+
triton_client.infer(model_name=model_name, inputs=inputs)
301+
err_str = str(e.exception)
302+
self.assertIn(
303+
f"'INPUT1' got unexpected byte size {input_byte_size-4}, expected {input_byte_size}",
304+
err_str,
305+
)
200306

201-
# Input size is greater than expected
202-
inputs = get_input_array(input_size + 2, np_dtype_string)
203-
# Compromised input shape
204-
inputs[0].set_shape((1, input_size))
205-
inputs[1].set_shape((1, input_size))
206-
with self.assertRaises(InferenceServerException) as e:
207-
triton_client.infer(model_name=model_name, inputs=inputs)
208-
err_str = str(e.exception)
209-
self.assertIn(
210-
f"expected {input_size} string elements for inference input 'INPUT1', got {input_size+2}",
211-
err_str,
212-
)
307+
print(triton_client.get_system_shared_memory_status())
308+
triton_client.unregister_system_shared_memory()
309+
assert len(shared_memory.mapped_shared_memory_regions()) == 1
310+
shared_memory.destroy_shared_memory_region(shm_ip_handle)
311+
assert len(shared_memory.mapped_shared_memory_regions()) == 0
312+
313+
def test_client_input_string_shm_size_validation(self):
314+
# We use a simple model that takes 2 input tensors of 16 strings
315+
# each and returns 2 output tensors of 16 strings each. The input
316+
# strings must represent integers. One output tensor is the
317+
# element-wise sum of the inputs and one output is the element-wise
318+
# difference.
319+
model_name = "simple_string"
320+
321+
for client_type in ["http", "grpc"]:
322+
if client_type == "http":
323+
triton_client = tritonhttpclient.InferenceServerClient("localhost:8000")
324+
else:
325+
triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")
326+
327+
# Make sure no shared memory regions are registered with the
328+
# server.
329+
triton_client.unregister_system_shared_memory()
330+
triton_client.unregister_cuda_shared_memory()
331+
332+
# Create the data for the two input tensors. Initialize the first
333+
# to unique integers and the second to all ones.
334+
in0 = np.arange(start=0, stop=16, dtype=np.int32)
335+
in0n = np.array(
336+
[str(x).encode("utf-8") for x in in0.flatten()], dtype=object
337+
)
338+
input0_data = in0n.reshape(in0.shape)
339+
in1 = np.ones(shape=16, dtype=np.int32)
340+
in1n = np.array(
341+
[str(x).encode("utf-8") for x in in1.flatten()], dtype=object
342+
)
343+
input1_data = in1n.reshape(in1.shape)
344+
345+
input0_data_serialized = utils.serialize_byte_tensor(input0_data)
346+
input1_data_serialized = utils.serialize_byte_tensor(input1_data)
347+
input0_byte_size = utils.serialized_byte_size(input0_data_serialized)
348+
input1_byte_size = utils.serialized_byte_size(input1_data_serialized)
349+
350+
# Create Input0 and Input1 in Shared Memory and store shared memory handles
351+
shm_ip0_handle = shared_memory.create_shared_memory_region(
352+
"input0_data", "/input0_simple", input0_byte_size
353+
)
354+
shm_ip1_handle = shared_memory.create_shared_memory_region(
355+
"input1_data", "/input1_simple", input1_byte_size
356+
)
357+
358+
# Put input data values into shared memory
359+
shared_memory.set_shared_memory_region(
360+
shm_ip0_handle, [input0_data_serialized]
361+
)
362+
shared_memory.set_shared_memory_region(
363+
shm_ip1_handle, [input1_data_serialized]
364+
)
365+
366+
# Register Input0 and Input1 shared memory with Triton Server
367+
triton_client.register_system_shared_memory(
368+
"input0_data", "/input0_simple", input0_byte_size
369+
)
370+
triton_client.register_system_shared_memory(
371+
"input1_data", "/input1_simple", input1_byte_size
372+
)
373+
374+
# Set the parameters to use data from shared memory
375+
inputs = []
376+
if client_type == "http":
377+
inputs.append(tritonhttpclient.InferInput("INPUT0", [1, 16], "BYTES"))
378+
inputs.append(tritonhttpclient.InferInput("INPUT1", [1, 16], "BYTES"))
379+
else:
380+
inputs.append(tritongrpcclient.InferInput("INPUT0", [1, 16], "BYTES"))
381+
inputs.append(tritongrpcclient.InferInput("INPUT1", [1, 16], "BYTES"))
382+
inputs[-2].set_shared_memory("input0_data", input0_byte_size + 4)
383+
inputs[-1].set_shared_memory("input1_data", input1_byte_size)
384+
385+
with self.assertRaises(InferenceServerException) as e:
386+
triton_client.infer(model_name=model_name, inputs=inputs)
387+
err_str = str(e.exception)
388+
389+
# BYTES inputs in shared memory will skip the check at the client
390+
self.assertIn(
391+
f"Invalid offset + byte size for shared memory region: 'input0_data'",
392+
err_str,
393+
)
394+
395+
print(triton_client.get_system_shared_memory_status())
396+
triton_client.unregister_system_shared_memory()
397+
assert len(shared_memory.mapped_shared_memory_regions()) == 2
398+
shared_memory.destroy_shared_memory_region(shm_ip0_handle)
399+
shared_memory.destroy_shared_memory_region(shm_ip1_handle)
400+
assert len(shared_memory.mapped_shared_memory_regions()) == 0
213401

214402

215403
if __name__ == "__main__":

0 commit comments

Comments
 (0)