Skip to content

Commit 1c66455

Browse files
committed
(fix) make pool test work
1 parent 2598cc7 commit 1c66455

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

numcodecs/tests/test_shuffle.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from multiprocessing import Pool
2-
from multiprocessing.pool import ThreadPool
1+
from __future__ import annotations
2+
from multiprocessing.pool import ThreadPool, Pool
3+
from typing import Literal
34

45
import numpy as np
56
import pytest
@@ -87,14 +88,17 @@ def _decode_worker(enc: bytes) -> np.ndarray:
8788
return compressor.decode(enc)
8889

8990

90-
@pytest.mark.parametrize('pool', [Pool, ThreadPool])
91-
def test_multiprocessing(pool: type[Pool | ThreadPool]) -> None:
91+
@pytest.mark.parametrize('pool_type', ['processes', 'threads'])
92+
def test_multiprocessing(pool_type: Literal['processes', 'threads']) -> None:
9293
data = np.arange(1000000)
9394
enc = _encode_worker(data)
9495

95-
pool = pool(5)
96-
97-
# test with process pool and thread pool
96+
if pool_type == 'processes':
97+
pool = Pool(5)
98+
elif pool_type == 'threads':
99+
pool = ThreadPool(5)
100+
else:
101+
raise ValueError(f"invalid pool_type: {pool_type}")
98102

99103
# test encoding
100104
enc_results = pool.map(_encode_worker, [data] * 5)

0 commit comments

Comments
 (0)