Skip to content

Commit f589b8a

Browse files
authored
chore: ensure proper clean up in shared memory related tests (#7729)
1 parent 1938df8 commit f589b8a

File tree

2 files changed

+347
-323
lines changed

2 files changed

+347
-323
lines changed

qa/L0_shared_memory/shared_memory_test.py

Lines changed: 61 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class SystemSharedMemoryTestBase(tu.TestResultCollector):
5050

5151
def setUp(self):
5252
self._setup_client()
53+
self._shm_handles = []
54+
55+
def tearDown(self):
56+
self._cleanup_shm_handles()
5357

5458
def _setup_client(self):
5559
self.protocol = os.environ.get("CLIENT_TYPE", "http")
@@ -89,6 +93,7 @@ def _configure_server(
8993
Offset into the shared memory object to start the registered region.
9094
9195
"""
96+
self._cleanup_shm_handles()
9297
shm_ip0_handle = shm.create_shared_memory_region(
9398
"input0_data", "/input0_data", create_byte_size
9499
)
@@ -101,6 +106,12 @@ def _configure_server(
101106
shm_op1_handle = shm.create_shared_memory_region(
102107
"output1_data", "/output1_data", create_byte_size
103108
)
109+
self._shm_handles = [
110+
shm_ip0_handle,
111+
shm_ip1_handle,
112+
shm_op0_handle,
113+
shm_op1_handle,
114+
]
104115
# Implicit assumption that input and output byte_sizes are 64 bytes for now
105116
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
106117
input1_data = np.ones(shape=16, dtype=np.int32)
@@ -118,23 +129,21 @@ def _configure_server(
118129
self.triton_client.register_system_shared_memory(
119130
"output1_data", "/output1_data", register_byte_size, offset=register_offset
120131
)
121-
return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle]
122132

123-
def _cleanup_server(self, shm_handles):
124-
for shm_handle in shm_handles:
133+
def _cleanup_shm_handles(self):
134+
for shm_handle in self._shm_handles:
125135
shm.destroy_shared_memory_region(shm_handle)
136+
self._shm_handles = []
126137

127138

128139
class SharedMemoryTest(SystemSharedMemoryTestBase):
129140
def test_invalid_create_shm(self):
130-
# Raises error since tried to create invalid system shared memory region
131-
try:
132-
shm_op0_handle = shm.create_shared_memory_region(
133-
"dummy_data", "/dummy_data", -1
141+
with self.assertRaisesRegex(
142+
shm.SharedMemoryException, "unable to create the shared memory region"
143+
):
144+
self._shm_handles.append(
145+
shm.create_shared_memory_region("dummy_data", "/dummy_data", -1)
134146
)
135-
shm.destroy_shared_memory_region(shm_op0_handle)
136-
except Exception as ex:
137-
self.assertTrue(str(ex) == "unable to initialize the size")
138147

139148
def test_valid_create_set_register(self):
140149
# Create a valid system shared memory region, fill data in it and register
@@ -195,14 +204,14 @@ def test_reregister_after_register(self):
195204
def test_unregister_after_inference(self):
196205
# Unregister after inference
197206
error_msg = []
198-
shm_handles = self._configure_server()
207+
self._configure_server()
199208
iu.shm_basic_infer(
200209
self,
201210
self.triton_client,
202-
shm_handles[0],
203-
shm_handles[1],
204-
shm_handles[2],
205-
shm_handles[3],
211+
self._shm_handles[0],
212+
self._shm_handles[1],
213+
self._shm_handles[2],
214+
self._shm_handles[3],
206215
error_msg,
207216
protocol=self.protocol,
208217
use_system_shared_memory=True,
@@ -215,20 +224,20 @@ def test_unregister_after_inference(self):
215224
self.assertTrue(len(shm_status) == 3)
216225
else:
217226
self.assertTrue(len(shm_status.regions) == 3)
218-
self._cleanup_server(shm_handles)
227+
self._cleanup_shm_handles()
219228

220229
def test_register_after_inference(self):
221230
# Register after inference
222231
error_msg = []
223-
shm_handles = self._configure_server()
232+
self._configure_server()
224233

225234
iu.shm_basic_infer(
226235
self,
227236
self.triton_client,
228-
shm_handles[0],
229-
shm_handles[1],
230-
shm_handles[2],
231-
shm_handles[3],
237+
self._shm_handles[0],
238+
self._shm_handles[1],
239+
self._shm_handles[2],
240+
self._shm_handles[3],
232241
error_msg,
233242
protocol=self.protocol,
234243
use_system_shared_memory=True,
@@ -247,13 +256,13 @@ def test_register_after_inference(self):
247256
self.assertTrue(len(shm_status) == 5)
248257
else:
249258
self.assertTrue(len(shm_status.regions) == 5)
250-
shm_handles.append(shm_ip2_handle)
251-
self._cleanup_server(shm_handles)
259+
self._shm_handles.append(shm_ip2_handle)
260+
self._cleanup_shm_handles()
252261

253262
def test_too_big_shm(self):
254263
# Shared memory input region larger than needed - Throws error
255264
error_msg = []
256-
shm_handles = self._configure_server()
265+
self._configure_server()
257266
shm_ip2_handle = shm.create_shared_memory_region(
258267
"input2_data", "/input2_data", 128
259268
)
@@ -264,10 +273,10 @@ def test_too_big_shm(self):
264273
iu.shm_basic_infer(
265274
self,
266275
self.triton_client,
267-
shm_handles[0],
276+
self._shm_handles[0],
268277
shm_ip2_handle,
269-
shm_handles[2],
270-
shm_handles[3],
278+
self._shm_handles[2],
279+
self._shm_handles[3],
271280
error_msg,
272281
big_shm_name="input2_data",
273282
big_shm_size=128,
@@ -279,33 +288,33 @@ def test_too_big_shm(self):
279288
"input byte size mismatch for input 'INPUT1' for model 'simple'. Expected 64, got 128",
280289
error_msg[-1],
281290
)
282-
shm_handles.append(shm_ip2_handle)
283-
self._cleanup_server(shm_handles)
291+
self._shm_handles.append(shm_ip2_handle)
292+
self._cleanup_shm_handles()
284293

285294
def test_mixed_raw_shm(self):
286295
# Mix of shared memory and RAW inputs
287296
error_msg = []
288-
shm_handles = self._configure_server()
297+
self._configure_server()
289298
input1_data = np.ones(shape=16, dtype=np.int32)
290299

291300
iu.shm_basic_infer(
292301
self,
293302
self.triton_client,
294-
shm_handles[0],
303+
self._shm_handles[0],
295304
[input1_data],
296-
shm_handles[2],
297-
shm_handles[3],
305+
self._shm_handles[2],
306+
self._shm_handles[3],
298307
error_msg,
299308
protocol=self.protocol,
300309
use_system_shared_memory=True,
301310
)
302311
if len(error_msg) > 0:
303312
raise Exception(error_msg[-1])
304-
self._cleanup_server(shm_handles)
313+
self._cleanup_shm_handles()
305314

306315
def test_unregisterall(self):
307316
# Unregister all shared memory blocks
308-
shm_handles = self._configure_server()
317+
self._configure_server()
309318
status_before = self.triton_client.get_system_shared_memory_status()
310319
if self.protocol == "http":
311320
self.assertTrue(len(status_before) == 4)
@@ -317,12 +326,12 @@ def test_unregisterall(self):
317326
self.assertTrue(len(status_after) == 0)
318327
else:
319328
self.assertTrue(len(status_after.regions) == 0)
320-
self._cleanup_server(shm_handles)
329+
self._cleanup_shm_handles()
321330

322331
def test_infer_offset_out_of_bound(self):
323332
# Shared memory offset outside output region - Throws error
324333
error_msg = []
325-
shm_handles = self._configure_server()
334+
self._configure_server()
326335
if self.protocol == "http":
327336
# -32 when placed in an int64 signed type, to get a negative offset
328337
# by overflowing
@@ -335,10 +344,10 @@ def test_infer_offset_out_of_bound(self):
335344
iu.shm_basic_infer(
336345
self,
337346
self.triton_client,
338-
shm_handles[0],
339-
shm_handles[1],
340-
shm_handles[2],
341-
shm_handles[3],
347+
self._shm_handles[0],
348+
self._shm_handles[1],
349+
self._shm_handles[2],
350+
self._shm_handles[3],
342351
error_msg,
343352
shm_output_offset=offset,
344353
protocol=self.protocol,
@@ -347,22 +356,22 @@ def test_infer_offset_out_of_bound(self):
347356

348357
self.assertEqual(len(error_msg), 1)
349358
self.assertIn("Invalid offset for shared memory region", error_msg[0])
350-
self._cleanup_server(shm_handles)
359+
self._cleanup_shm_handles()
351360

352361
def test_infer_byte_size_out_of_bound(self):
353362
# Shared memory byte_size outside output region - Throws error
354363
error_msg = []
355-
shm_handles = self._configure_server()
364+
self._configure_server()
356365
offset = 60
357366
byte_size = self.DEFAULT_SHM_BYTE_SIZE
358367

359368
iu.shm_basic_infer(
360369
self,
361370
self.triton_client,
362-
shm_handles[0],
363-
shm_handles[1],
364-
shm_handles[2],
365-
shm_handles[3],
371+
self._shm_handles[0],
372+
self._shm_handles[1],
373+
self._shm_handles[2],
374+
self._shm_handles[3],
366375
error_msg,
367376
shm_output_offset=offset,
368377
shm_output_byte_size=byte_size,
@@ -373,7 +382,7 @@ def test_infer_byte_size_out_of_bound(self):
373382
self.assertIn(
374383
"Invalid offset + byte size for shared memory region", error_msg[0]
375384
)
376-
self._cleanup_server(shm_handles)
385+
self._cleanup_shm_handles()
377386

378387
def test_register_out_of_bound(self):
379388
create_byte_size = self.DEFAULT_SHM_BYTE_SIZE
@@ -520,7 +529,7 @@ def _test_shm_not_found(self):
520529
def test_unregister_shm_during_inference_http(self):
521530
try:
522531
self.triton_client.unregister_system_shared_memory()
523-
shm_handles = self._configure_server()
532+
self._configure_server()
524533

525534
inputs = [
526535
httpclient.InferInput("INPUT0", [1, 16], "INT32"),
@@ -554,12 +563,12 @@ def test_unregister_shm_during_inference_http(self):
554563
self._test_shm_not_found()
555564

556565
finally:
557-
self._cleanup_server(shm_handles)
566+
self._cleanup_shm_handles()
558567

559568
def test_unregister_shm_during_inference_grpc(self):
560569
try:
561570
self.triton_client.unregister_system_shared_memory()
562-
shm_handles = self._configure_server()
571+
self._configure_server()
563572

564573
inputs = [
565574
grpcclient.InferInput("INPUT0", [1, 16], "INT32"),
@@ -608,7 +617,7 @@ def callback(user_data, result, error):
608617
self._test_shm_not_found()
609618

610619
finally:
611-
self._cleanup_server(shm_handles)
620+
self._cleanup_shm_handles()
612621

613622

614623
if __name__ == "__main__":

0 commit comments

Comments
 (0)