55##
66
77import sys
8+ from multiprocessing import shared_memory
89
910import numpy as np
10- import sysv_ipc
1111
12- from .utils import mpi_data_type , random_shm_key
12+ from .utils import (
13+ mpi_data_type ,
14+ random_shm_key ,
15+ remove_shm_from_resource_tracker ,
16+ )
17+
18+ # Monkey patch the resource tracker via remove_shm_from_resource_tracker().
19+ # Remove this call once the upstream CPython changes are merged.
20+ remove_shm_from_resource_tracker ()
1321
1422
1523class MPIShared (object ):
@@ -149,7 +157,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
149157 if self ._rank == 0 :
150158 # Get a random 64bit integer between the supported range of keys
151159 self ._shm_index = random_shm_key ()
152- # Name, just used for printing
160+ # Name passed to shared_memory.SharedMemory to create / attach the segment.
153161 self ._name = f"MPIShared_{ self ._shm_index } "
154162 if self ._comm is not None :
155163 self ._shm_index = self ._comm .bcast (self ._shm_index , root = 0 )
@@ -177,10 +185,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
177185 # First rank on each node creates the buffer
178186 if self ._noderank == 0 :
179187 try :
180- self ._shmem = sysv_ipc .SharedMemory (
181- self ._shm_index ,
182- flags = sysv_ipc .IPC_CREX ,
183- size = int (nbytes ),
188+ self ._shmem = shared_memory .SharedMemory (
189+ name = self ._name , create = True , size = int (nbytes ),
184190 )
185191 except Exception as e :
186192 msg = "Process {}: {}" .format (self ._rank , self ._name )
@@ -199,8 +205,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
199205 # Other ranks on the node attach
200206 if self ._noderank != 0 :
201207 try :
202- self ._shmem = sysv_ipc .SharedMemory (
203- self ._shm_index , flags = 0 , size = 0
208+ self ._shmem = shared_memory .SharedMemory (
209+ name = self ._name , create = False , size = int ( nbytes )
204210 )
205211 except Exception as e :
206212 msg = "Process {}: {}" .format (self ._rank , self ._name )
@@ -216,7 +222,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
216222 self ._flat = np .ndarray (
217223 self ._n ,
218224 dtype = self ._dtype ,
219- buffer = self ._shmem ,
225+ buffer = self ._shmem . buf ,
220226 )
221227
222228 # Initialize to zero.
@@ -230,19 +236,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
230236 if self ._nodecomm is not None :
231237 self ._nodecomm .barrier ()
232238
233- # Now the rank zero process will call remove() to mark the shared
234- # memory segment for removal. However, this will not actually
235- # be removed until all processes detach.
236- if self ._noderank == 0 :
237- try :
238- self ._shmem .remove ()
239- except sysv_ipc .ExistentialError :
240- msg = "Process {}: {}" .format (self ._rank , self ._name )
241- msg += " failed to remove shared memory"
242- msg += ": {}" .format (e )
243- print (msg , flush = True )
244- raise
245-
246239 def __del__ (self ):
247240 self .close ()
248241
@@ -370,7 +363,9 @@ def close(self):
370363 del self ._flat
371364 if hasattr (self , "_shmem" ):
372365 if self ._shmem is not None :
373- self ._shmem .detach ()
366+ self ._shmem .close ()
367+ if self ._noderank == 0 :
368+ self ._shmem .unlink ()
374369 del self ._shmem
375370 self ._shmem = None
376371 self ._flat = None
0 commit comments