Commit 2f8b88a

Multiprocessing support (#2815)
* add failing multiprocessing test
* add hook to reset global vars after fork
* parametrize multiprocessing test over different methods
* guard execution of register_at_fork with a hasattr check
* exempt runs-in-a-forked-process code from coverage
* update literal type
1 parent 8b77464 commit 2f8b88a

File tree

2 files changed: +59 −0 lines changed

src/zarr/core/sync.py

Lines changed: 21 additions & 0 deletions
@@ -3,6 +3,7 @@
 import asyncio
 import atexit
 import logging
+import os
 import threading
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import TYPE_CHECKING, TypeVar
@@ -89,6 +90,26 @@ def cleanup_resources() -> None:
 atexit.register(cleanup_resources)


+def reset_resources_after_fork() -> None:
+    """
+    Ensure that global resources are reset after a fork. Without this function,
+    forked processes will retain invalid references to the parent process's resources.
+    """
+    global loop, iothread, _executor
+    # These lines are excluded from coverage because this function only runs in a child process,
+    # which is not observed by the test coverage instrumentation. Despite the apparent lack of
+    # test coverage, this function should be adequately tested by any test that uses Zarr IO with
+    # multiprocessing.
+    loop[0] = None  # pragma: no cover
+    iothread[0] = None  # pragma: no cover
+    _executor = None  # pragma: no cover
+
+
+# this is only available on certain operating systems
+if hasattr(os, "register_at_fork"):
+    os.register_at_fork(after_in_child=reset_resources_after_fork)
+
+
 async def _runner(coro: Coroutine[Any, Any, T]) -> T | BaseException:
     """
     Await a coroutine and return the result of running it. If awaiting the coroutine raises an
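For background (not part of the diff): os.register_at_fork is a standard-library hook that registers callbacks CPython invokes around fork(), including in the child immediately after the fork; it only exists on platforms that support forking, which is why the call above is guarded with hasattr. A minimal, self-contained sketch of the same pattern, using a generic module-level resource as a stand-in for zarr's cached event loop and I/O thread:

import os
import threading

# Stand-in for a global resource that must not be shared across a fork,
# analogous to zarr's cached event loop, I/O thread, and executor.
_worker: threading.Thread | None = None

def _reset_after_fork() -> None:
    # Runs in the child process immediately after fork(); drop the inherited
    # reference so the child lazily creates its own worker when it needs one.
    global _worker
    _worker = None

if hasattr(os, "register_at_fork"):  # only defined on platforms that support fork()
    os.register_at_fork(after_in_child=_reset_after_fork)

Resetting the globals to None rather than rebuilding them keeps the hook cheap; as in sync.py, the child recreates the resources lazily on first use.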

tests/test_array.py

Lines changed: 38 additions & 0 deletions
@@ -1,8 +1,10 @@
 import dataclasses
 import json
 import math
+import multiprocessing as mp
 import pickle
 import re
+import sys
 from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Literal
 from unittest import mock
@@ -1382,3 +1384,39 @@ def test_roundtrip_numcodecs() -> None:
     metadata = root["test"].metadata.to_dict()
     expected = (*filters, BYTES_CODEC, *compressors)
     assert metadata["codecs"] == expected
+
+
+def _index_array(arr: Array, index: Any) -> Any:
+    return arr[index]
+
+
+@pytest.mark.parametrize(
+    "method",
+    [
+        pytest.param(
+            "fork",
+            marks=pytest.mark.skipif(
+                sys.platform in ("win32", "darwin"), reason="fork not supported on Windows or OSX"
+            ),
+        ),
+        "spawn",
+        pytest.param(
+            "forkserver",
+            marks=pytest.mark.skipif(
+                sys.platform == "win32", reason="forkserver not supported on Windows"
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize("store", ["local"], indirect=True)
+def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkserver"]) -> None:
+    """
+    Test that arrays can be pickled and indexed in child processes
+    """
+    data = np.arange(100)
+    arr = zarr.create_array(store=store, data=data)
+    ctx = mp.get_context(method)
+    pool = ctx.Pool()
+
+    results = pool.starmap(_index_array, [(arr, slice(len(data)))])
+    assert all(np.array_equal(r, data) for r in results)
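A hedged usage sketch (not from the commit) of what this test exercises from a user's perspective: after the fork hook is registered, a Zarr array can be pickled into worker processes and indexed there under any start method. The path and pool size below are made up for illustration, and the example assumes zarr.create_array accepts a local filesystem path as the store:

import multiprocessing as mp

import numpy as np
import zarr

def read_slice(arr: zarr.Array, sl: slice) -> np.ndarray:
    # Worker processes receive the array via pickle and read from the store directly.
    return arr[sl]

if __name__ == "__main__":
    data = np.arange(1_000)
    arr = zarr.create_array(store="example.zarr", data=data)  # hypothetical local path
    ctx = mp.get_context("spawn")  # "fork" and "forkserver" also work after this change
    with ctx.Pool(processes=4) as pool:
        parts = pool.starmap(read_slice, [(arr, slice(i, i + 250)) for i in range(0, 1_000, 250)])
    assert np.array_equal(np.concatenate(parts), data)

With the "fork" start method this is exactly the case the new after_in_child hook addresses: children no longer hold the parent's event loop and I/O thread references.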
