Skip to content

Commit b2a7f19

Browse files
authored
fix: Fix L0_input_validation (#7800)
1 parent f44f3dd commit b2a7f19

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

qa/L0_input_validation/input_validation_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,19 +221,23 @@ def inference_helper(model_name, batch_size=1):
221221
dummy_input_data = np.random.rand(32, 32).astype(np.float32)
222222
shape_tensor_data = np.asarray([4, 4], dtype=np.int32)
223223

224-
# Pass incorrect input byte size date for shape tensor
224+
# Pass an incorrect input byte size for the shape tensor
225225
# Use shared memory to bypass the shape check in client library
226226
input_byte_size = (shape_tensor_data.size - 1) * np.dtype(np.int32).itemsize
227227

228+
# Create a shared memory region with the incorrect byte size (input_byte_size)
228229
input_shm_handle = shm.create_shared_memory_region(
229230
"INPUT0_SHM",
230231
"/INPUT0_SHM",
231232
input_byte_size,
232233
)
234+
235+
# Write the shape tensor data into the shared memory region
236+
# Slice the data to match the incorrect byte size (input_byte_size)
233237
shm.set_shared_memory_region(
234238
input_shm_handle,
235239
[
236-
shape_tensor_data,
240+
shape_tensor_data[: input_byte_size // np.dtype(np.int32).itemsize],
237241
],
238242
)
239243
triton_client.register_system_shared_memory(

0 commit comments

Comments
 (0)