Skip to content

Commit 02cd641

Browse files
authored
Merge pull request #24 from tskisner/mmap
Use Python SharedMemory as the backend
2 parents 2ca8e38 + 95e19fa commit 02cd641

File tree

4 files changed

+54
-30
lines changed

4 files changed

+54
-30
lines changed

pshmem/shmem.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,19 @@
55
##
66

77
import sys
8+
from multiprocessing import shared_memory
89

910
import numpy as np
10-
import sysv_ipc
1111

12-
from .utils import mpi_data_type, random_shm_key
12+
from .utils import (
13+
mpi_data_type,
14+
random_shm_key,
15+
remove_shm_from_resource_tracker,
16+
)
17+
18+
# Monkey patch resource_tracker. Remove once upstream CPython
19+
# changes are merged.
20+
remove_shm_from_resource_tracker()
1321

1422

1523
class MPIShared(object):
@@ -149,7 +157,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
149157
if self._rank == 0:
150158
# Get a random 64bit integer between the supported range of keys
151159
self._shm_index = random_shm_key()
152-
# Name, just used for printing
160+
# Name, used as global tag.
153161
self._name = f"MPIShared_{self._shm_index}"
154162
if self._comm is not None:
155163
self._shm_index = self._comm.bcast(self._shm_index, root=0)
@@ -177,10 +185,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
177185
# First rank on each node creates the buffer
178186
if self._noderank == 0:
179187
try:
180-
self._shmem = sysv_ipc.SharedMemory(
181-
self._shm_index,
182-
flags=sysv_ipc.IPC_CREX,
183-
size=int(nbytes),
188+
self._shmem = shared_memory.SharedMemory(
189+
name=self._name, create=True, size=int(nbytes),
184190
)
185191
except Exception as e:
186192
msg = "Process {}: {}".format(self._rank, self._name)
@@ -199,8 +205,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
199205
# Other ranks on the node attach
200206
if self._noderank != 0:
201207
try:
202-
self._shmem = sysv_ipc.SharedMemory(
203-
self._shm_index, flags=0, size=0
208+
self._shmem = shared_memory.SharedMemory(
209+
name=self._name, create=False, size=int(nbytes)
204210
)
205211
except Exception as e:
206212
msg = "Process {}: {}".format(self._rank, self._name)
@@ -216,7 +222,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
216222
self._flat = np.ndarray(
217223
self._n,
218224
dtype=self._dtype,
219-
buffer=self._shmem,
225+
buffer=self._shmem.buf,
220226
)
221227

222228
# Initialize to zero.
@@ -230,19 +236,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
230236
if self._nodecomm is not None:
231237
self._nodecomm.barrier()
232238

233-
# Now the rank zero process will call remove() to mark the shared
234-
# memory segment for removal. However, this will not actually
235-
# be removed until all processes detach.
236-
if self._noderank == 0:
237-
try:
238-
self._shmem.remove()
239-
except sysv_ipc.ExistentialError:
240-
msg = "Process {}: {}".format(self._rank, self._name)
241-
msg += " failed to remove shared memory"
242-
msg += ": {}".format(e)
243-
print(msg, flush=True)
244-
raise
245-
246239
def __del__(self):
247240
self.close()
248241

@@ -370,7 +363,9 @@ def close(self):
370363
del self._flat
371364
if hasattr(self, "_shmem"):
372365
if self._shmem is not None:
373-
self._shmem.detach()
366+
self._shmem.close()
367+
if self._noderank == 0:
368+
self._shmem.unlink()
374369
del self._shmem
375370
self._shmem = None
376371
self._flat = None

pshmem/test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,14 @@ def test_zero(self):
432432
# dims = (200, 1000000)
433433
# dt = np.float64
434434
# shm = MPIShared(dims, dt, self.comm)
435+
# if self.comm is None or self.comm.rank == 0:
436+
# temp = np.ones(dims, dtype=dt)
437+
# else:
438+
# temp = None
439+
# shm.set(temp, fromrank=0)
440+
# del temp
435441
# import time
436442
# time.sleep(60)
437-
# shm.close()
438443
# del shm
439444
# return
440445

pshmem/utils.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
##
66

77
import random
8+
import sys
9+
# Import for monkey patching resource tracker
10+
from multiprocessing import resource_tracker
811

912
import numpy as np
10-
import sysv_ipc
1113

1214

1315
def mpi_data_type(comm, dt):
@@ -48,7 +50,7 @@ def mpi_data_type(comm, dt):
4850

4951

5052
def random_shm_key():
51-
"""Get a random 64bit integer in the range supported by shmget()
53+
"""Get a random positive integer for using in shared memory naming.
5254
5355
The python random library is used, and seeded with the default source
5456
(either system time or os.urandom).
@@ -57,8 +59,30 @@ def random_shm_key():
5759
(int): The random integer.
5860
5961
"""
60-
min_val = sysv_ipc.KEY_MIN
61-
max_val = sysv_ipc.KEY_MAX
62+
min_val = 0
63+
max_val = sys.maxsize
6264
# Seed with default source of randomness
6365
random.seed(a=None)
6466
return random.randint(min_val, max_val)
67+
68+
69+
def remove_shm_from_resource_tracker():
70+
"""Monkey-patch multiprocessing.resource_tracker so SharedMemory won't be tracked
71+
72+
More details at: https://bugs.python.org/issue38119
73+
"""
74+
75+
def fix_register(name, rtype):
76+
if rtype == "shared_memory":
77+
return
78+
return resource_tracker._resource_tracker.register(self, name, rtype)
79+
resource_tracker.register = fix_register
80+
81+
def fix_unregister(name, rtype):
82+
if rtype == "shared_memory":
83+
return
84+
return resource_tracker._resource_tracker.unregister(self, name, rtype)
85+
resource_tracker.unregister = fix_unregister
86+
87+
if "shared_memory" in resource_tracker._CLEANUP_FUNCS:
88+
del resource_tracker._CLEANUP_FUNCS["shared_memory"]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def readme():
2525
scripts=None,
2626
license="BSD",
2727
python_requires=">=3.8.0",
28-
install_requires=["numpy", "sysv_ipc"],
28+
install_requires=["numpy"],
2929
extras_require={"mpi": ["mpi4py>=3.0"]},
3030
cmdclass=versioneer.get_cmdclass(),
3131
classifiers=[

0 commit comments

Comments
 (0)