|
1 | 1 | import dataclasses |
2 | 2 | import json |
3 | 3 | import math |
| 4 | +import multiprocessing as mp |
4 | 5 | import pickle |
5 | 6 | import re |
| 7 | +import sys |
6 | 8 | from itertools import accumulate |
7 | 9 | from typing import TYPE_CHECKING, Any, Literal |
8 | 10 | from unittest import mock |
@@ -1388,16 +1390,33 @@ def _index_array(arr: Array, index: Any) -> Any: |
1388 | 1390 | return arr[index] |
1389 | 1391 |
|
1390 | 1392 |
|
| 1393 | +@pytest.mark.parametrize( |
| 1394 | + "method", |
| 1395 | + [ |
| 1396 | + pytest.param( |
| 1397 | + "fork", |
| 1398 | + marks=pytest.mark.skipif( |
| 1399 | + sys.platform in ("win32", "darwin"), reason="fork not supported on Windows or OSX" |
| 1400 | + ), |
| 1401 | + ), |
| 1402 | + "spawn", |
| 1403 | + pytest.param( |
| 1404 | + "forkserver", |
| 1405 | + marks=pytest.mark.skipif( |
| 1406 | + sys.platform == "win32", reason="forkserver not supported on Windows" |
| 1407 | + ), |
| 1408 | + ), |
| 1409 | + ], |
| 1410 | +) |
1391 | 1411 | @pytest.mark.parametrize("store", ["local"], indirect=True) |
1392 | | -def test_multiprocessing(store: Store) -> None: |
| 1412 | +def test_multiprocessing(store: Store, method: Literal["fork", "spawn"]) -> None: |
1393 | 1413 | """ |
1394 | 1414 | Test that arrays can be pickled and indexed in child processes |
1395 | 1415 | """ |
1396 | 1416 | data = np.arange(100) |
1397 | 1417 | arr = zarr.create_array(store=store, data=data) |
1398 | | - from multiprocessing import Pool |
1399 | | - |
1400 | | - pool = Pool() |
| 1418 | + ctx = mp.get_context(method) |
| 1419 | + pool = ctx.Pool() |
1401 | 1420 |
|
1402 | | - results = pool.starmap(_index_array, [(arr, slice(len(data)))] * 3) |
| 1421 | + results = pool.starmap(_index_array, [(arr, slice(len(data)))]) |
1403 | 1422 | assert all(np.array_equal(r, data) for r in results) |
0 commit comments