@@ -50,6 +50,10 @@ class SystemSharedMemoryTestBase(tu.TestResultCollector):
5050
5151 def setUp (self ):
5252 self ._setup_client ()
53+ self ._shm_handles = []
54+
55+ def tearDown (self ):
56+ self ._cleanup_shm_handles ()
5357
5458 def _setup_client (self ):
5559 self .protocol = os .environ .get ("CLIENT_TYPE" , "http" )
@@ -89,6 +93,7 @@ def _configure_server(
8993 Offset into the shared memory object to start the registered region.
9094
9195 """
96+ self ._cleanup_shm_handles ()
9297 shm_ip0_handle = shm .create_shared_memory_region (
9398 "input0_data" , "/input0_data" , create_byte_size
9499 )
@@ -101,6 +106,12 @@ def _configure_server(
101106 shm_op1_handle = shm .create_shared_memory_region (
102107 "output1_data" , "/output1_data" , create_byte_size
103108 )
109+ self ._shm_handles = [
110+ shm_ip0_handle ,
111+ shm_ip1_handle ,
112+ shm_op0_handle ,
113+ shm_op1_handle ,
114+ ]
104115 # Implicit assumption that input and output byte_sizes are 64 bytes for now
105116 input0_data = np .arange (start = 0 , stop = 16 , dtype = np .int32 )
106117 input1_data = np .ones (shape = 16 , dtype = np .int32 )
@@ -118,23 +129,21 @@ def _configure_server(
118129 self .triton_client .register_system_shared_memory (
119130 "output1_data" , "/output1_data" , register_byte_size , offset = register_offset
120131 )
121- return [shm_ip0_handle , shm_ip1_handle , shm_op0_handle , shm_op1_handle ]
122132
123- def _cleanup_server (self , shm_handles ):
124- for shm_handle in shm_handles :
133+ def _cleanup_shm_handles (self ):
134+ for shm_handle in self . _shm_handles :
125135 shm .destroy_shared_memory_region (shm_handle )
136+ self ._shm_handles = []
126137
127138
128139class SharedMemoryTest (SystemSharedMemoryTestBase ):
129140 def test_invalid_create_shm (self ):
130- # Raises error since tried to create invalid system shared memory region
131- try :
132- shm_op0_handle = shm .create_shared_memory_region (
133- "dummy_data" , "/dummy_data" , - 1
141+ with self .assertRaisesRegex (
142+ shm .SharedMemoryException , "unable to create the shared memory region"
143+ ):
144+ self ._shm_handles .append (
145+ shm .create_shared_memory_region ("dummy_data" , "/dummy_data" , - 1 )
134146 )
135- shm .destroy_shared_memory_region (shm_op0_handle )
136- except Exception as ex :
137- self .assertTrue (str (ex ) == "unable to initialize the size" )
138147
139148 def test_valid_create_set_register (self ):
140149 # Create a valid system shared memory region, fill data in it and register
@@ -195,14 +204,14 @@ def test_reregister_after_register(self):
195204 def test_unregister_after_inference (self ):
196205 # Unregister after inference
197206 error_msg = []
198- shm_handles = self ._configure_server ()
207+ self ._configure_server ()
199208 iu .shm_basic_infer (
200209 self ,
201210 self .triton_client ,
202- shm_handles [0 ],
203- shm_handles [1 ],
204- shm_handles [2 ],
205- shm_handles [3 ],
211+ self . _shm_handles [0 ],
212+ self . _shm_handles [1 ],
213+ self . _shm_handles [2 ],
214+ self . _shm_handles [3 ],
206215 error_msg ,
207216 protocol = self .protocol ,
208217 use_system_shared_memory = True ,
@@ -215,20 +224,20 @@ def test_unregister_after_inference(self):
215224 self .assertTrue (len (shm_status ) == 3 )
216225 else :
217226 self .assertTrue (len (shm_status .regions ) == 3 )
218- self ._cleanup_server ( shm_handles )
227+ self ._cleanup_shm_handles ( )
219228
220229 def test_register_after_inference (self ):
221230 # Register after inference
222231 error_msg = []
223- shm_handles = self ._configure_server ()
232+ self ._configure_server ()
224233
225234 iu .shm_basic_infer (
226235 self ,
227236 self .triton_client ,
228- shm_handles [0 ],
229- shm_handles [1 ],
230- shm_handles [2 ],
231- shm_handles [3 ],
237+ self . _shm_handles [0 ],
238+ self . _shm_handles [1 ],
239+ self . _shm_handles [2 ],
240+ self . _shm_handles [3 ],
232241 error_msg ,
233242 protocol = self .protocol ,
234243 use_system_shared_memory = True ,
@@ -247,13 +256,13 @@ def test_register_after_inference(self):
247256 self .assertTrue (len (shm_status ) == 5 )
248257 else :
249258 self .assertTrue (len (shm_status .regions ) == 5 )
250- shm_handles .append (shm_ip2_handle )
251- self ._cleanup_server ( shm_handles )
259+ self . _shm_handles .append (shm_ip2_handle )
260+ self ._cleanup_shm_handles ( )
252261
253262 def test_too_big_shm (self ):
254263 # Shared memory input region larger than needed - Throws error
255264 error_msg = []
256- shm_handles = self ._configure_server ()
265+ self ._configure_server ()
257266 shm_ip2_handle = shm .create_shared_memory_region (
258267 "input2_data" , "/input2_data" , 128
259268 )
@@ -264,10 +273,10 @@ def test_too_big_shm(self):
264273 iu .shm_basic_infer (
265274 self ,
266275 self .triton_client ,
267- shm_handles [0 ],
276+ self . _shm_handles [0 ],
268277 shm_ip2_handle ,
269- shm_handles [2 ],
270- shm_handles [3 ],
278+ self . _shm_handles [2 ],
279+ self . _shm_handles [3 ],
271280 error_msg ,
272281 big_shm_name = "input2_data" ,
273282 big_shm_size = 128 ,
@@ -279,33 +288,33 @@ def test_too_big_shm(self):
279288 "input byte size mismatch for input 'INPUT1' for model 'simple'. Expected 64, got 128" ,
280289 error_msg [- 1 ],
281290 )
282- shm_handles .append (shm_ip2_handle )
283- self ._cleanup_server ( shm_handles )
291+ self . _shm_handles .append (shm_ip2_handle )
292+ self ._cleanup_shm_handles ( )
284293
285294 def test_mixed_raw_shm (self ):
286295 # Mix of shared memory and RAW inputs
287296 error_msg = []
288- shm_handles = self ._configure_server ()
297+ self ._configure_server ()
289298 input1_data = np .ones (shape = 16 , dtype = np .int32 )
290299
291300 iu .shm_basic_infer (
292301 self ,
293302 self .triton_client ,
294- shm_handles [0 ],
303+ self . _shm_handles [0 ],
295304 [input1_data ],
296- shm_handles [2 ],
297- shm_handles [3 ],
305+ self . _shm_handles [2 ],
306+ self . _shm_handles [3 ],
298307 error_msg ,
299308 protocol = self .protocol ,
300309 use_system_shared_memory = True ,
301310 )
302311 if len (error_msg ) > 0 :
303312 raise Exception (error_msg [- 1 ])
304- self ._cleanup_server ( shm_handles )
313+ self ._cleanup_shm_handles ( )
305314
306315 def test_unregisterall (self ):
307316 # Unregister all shared memory blocks
308- shm_handles = self ._configure_server ()
317+ self ._configure_server ()
309318 status_before = self .triton_client .get_system_shared_memory_status ()
310319 if self .protocol == "http" :
311320 self .assertTrue (len (status_before ) == 4 )
@@ -317,12 +326,12 @@ def test_unregisterall(self):
317326 self .assertTrue (len (status_after ) == 0 )
318327 else :
319328 self .assertTrue (len (status_after .regions ) == 0 )
320- self ._cleanup_server ( shm_handles )
329+ self ._cleanup_shm_handles ( )
321330
322331 def test_infer_offset_out_of_bound (self ):
323332 # Shared memory offset outside output region - Throws error
324333 error_msg = []
325- shm_handles = self ._configure_server ()
334+ self ._configure_server ()
326335 if self .protocol == "http" :
327336 # -32 when placed in an int64 signed type, to get a negative offset
328337 # by overflowing
@@ -335,10 +344,10 @@ def test_infer_offset_out_of_bound(self):
335344 iu .shm_basic_infer (
336345 self ,
337346 self .triton_client ,
338- shm_handles [0 ],
339- shm_handles [1 ],
340- shm_handles [2 ],
341- shm_handles [3 ],
347+ self . _shm_handles [0 ],
348+ self . _shm_handles [1 ],
349+ self . _shm_handles [2 ],
350+ self . _shm_handles [3 ],
342351 error_msg ,
343352 shm_output_offset = offset ,
344353 protocol = self .protocol ,
@@ -347,22 +356,22 @@ def test_infer_offset_out_of_bound(self):
347356
348357 self .assertEqual (len (error_msg ), 1 )
349358 self .assertIn ("Invalid offset for shared memory region" , error_msg [0 ])
350- self ._cleanup_server ( shm_handles )
359+ self ._cleanup_shm_handles ( )
351360
352361 def test_infer_byte_size_out_of_bound (self ):
353362 # Shared memory byte_size outside output region - Throws error
354363 error_msg = []
355- shm_handles = self ._configure_server ()
364+ self ._configure_server ()
356365 offset = 60
357366 byte_size = self .DEFAULT_SHM_BYTE_SIZE
358367
359368 iu .shm_basic_infer (
360369 self ,
361370 self .triton_client ,
362- shm_handles [0 ],
363- shm_handles [1 ],
364- shm_handles [2 ],
365- shm_handles [3 ],
371+ self . _shm_handles [0 ],
372+ self . _shm_handles [1 ],
373+ self . _shm_handles [2 ],
374+ self . _shm_handles [3 ],
366375 error_msg ,
367376 shm_output_offset = offset ,
368377 shm_output_byte_size = byte_size ,
@@ -373,7 +382,7 @@ def test_infer_byte_size_out_of_bound(self):
373382 self .assertIn (
374383 "Invalid offset + byte size for shared memory region" , error_msg [0 ]
375384 )
376- self ._cleanup_server ( shm_handles )
385+ self ._cleanup_shm_handles ( )
377386
378387 def test_register_out_of_bound (self ):
379388 create_byte_size = self .DEFAULT_SHM_BYTE_SIZE
@@ -520,7 +529,7 @@ def _test_shm_not_found(self):
520529 def test_unregister_shm_during_inference_http (self ):
521530 try :
522531 self .triton_client .unregister_system_shared_memory ()
523- shm_handles = self ._configure_server ()
532+ self ._configure_server ()
524533
525534 inputs = [
526535 httpclient .InferInput ("INPUT0" , [1 , 16 ], "INT32" ),
@@ -554,12 +563,12 @@ def test_unregister_shm_during_inference_http(self):
554563 self ._test_shm_not_found ()
555564
556565 finally :
557- self ._cleanup_server ( shm_handles )
566+ self ._cleanup_shm_handles ( )
558567
559568 def test_unregister_shm_during_inference_grpc (self ):
560569 try :
561570 self .triton_client .unregister_system_shared_memory ()
562- shm_handles = self ._configure_server ()
571+ self ._configure_server ()
563572
564573 inputs = [
565574 grpcclient .InferInput ("INPUT0" , [1 , 16 ], "INT32" ),
@@ -608,7 +617,7 @@ def callback(user_data, result, error):
608617 self ._test_shm_not_found ()
609618
610619 finally :
611- self ._cleanup_server ( shm_handles )
620+ self ._cleanup_shm_handles ( )
612621
613622
614623if __name__ == "__main__" :
0 commit comments