Skip to content

Commit 9691102

Browse files
committed
add concurrency limit
1 parent cba42f3 commit 9691102

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

src/zarr/core/group.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,8 @@ async def getitem(
692692
metadata = _build_metadata_v2(zarray, zattrs)
693693
return _build_node_v2(metadata=metadata, store_path=store_path)
694694
else:
695+
# this is just for mypy
696+
assert zgroup is not None
695697
metadata = _build_metadata_v2(zgroup, zattrs)
696698
return _build_node_v2(metadata=metadata, store_path=store_path)
697699
else:
@@ -1328,7 +1330,11 @@ async def _members(
13281330
)
13291331

13301332
raise ValueError(msg)
1331-
async for member in _iter_members_deep(self, max_depth=max_depth, skip_keys=skip_keys):
1333+
# enforce a concurrency limit by passing a semaphore to all the recursive functions
1334+
semaphore = asyncio.Semaphore(config.get("async.concurrency"))
1335+
async for member in _iter_members_deep(
1336+
self, max_depth=max_depth, skip_keys=skip_keys, semaphore=semaphore
1337+
):
13321338
yield member
13331339

13341340
async def keys(self) -> AsyncGenerator[str, None]:
@@ -2638,8 +2644,20 @@ async def members_recursive(
26382644
return members_flat
26392645

26402646

2647+
async def _getitem_semaphore(
2648+
node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None
2649+
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup:
2650+
if semaphore is not None:
2651+
async with semaphore:
2652+
return await node.getitem(key)
2653+
else:
2654+
return await node.getitem(key)
2655+
2656+
26412657
async def iter_members(
2642-
node: AsyncGroup, skip_keys: tuple[str, ...]
2658+
node: AsyncGroup,
2659+
skip_keys: tuple[str, ...],
2660+
semaphore: asyncio.Semaphore | None,
26432661
) -> AsyncGenerator[
26442662
tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
26452663
]:
@@ -2651,7 +2669,10 @@ async def iter_members(
26512669
keys = [key async for key in node.store.list_dir(node.path)]
26522670
keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys))
26532671

2654-
node_tasks = tuple(asyncio.create_task(node.getitem(key), name=key) for key in keys_filtered)
2672+
node_tasks = tuple(
2673+
asyncio.create_task(_getitem_semaphore(node, key, semaphore), name=key)
2674+
for key in keys_filtered
2675+
)
26552676

26562677
for fetched_node_coro in asyncio.as_completed(node_tasks):
26572678
try:
@@ -2674,7 +2695,11 @@ async def iter_members(
26742695

26752696

26762697
async def _iter_members_deep(
2677-
group: AsyncGroup, *, max_depth: int | None, skip_keys: tuple[str, ...]
2698+
group: AsyncGroup,
2699+
*,
2700+
max_depth: int | None,
2701+
skip_keys: tuple[str, ...],
2702+
semaphore: asyncio.Semaphore | None = None,
26782703
) -> AsyncGenerator[
26792704
tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
26802705
]:
@@ -2690,10 +2715,12 @@ async def _iter_members_deep(
26902715
new_depth = None
26912716
else:
26922717
new_depth = max_depth - 1
2693-
async for name, node in iter_members(group, skip_keys=skip_keys):
2718+
async for name, node in iter_members(group, skip_keys=skip_keys, semaphore=semaphore):
26942719
yield name, node
26952720
if isinstance(node, AsyncGroup) and do_recursion:
2696-
to_recurse[name] = _iter_members_deep(node, max_depth=new_depth, skip_keys=skip_keys)
2721+
to_recurse[name] = _iter_members_deep(
2722+
node, max_depth=new_depth, skip_keys=skip_keys, semaphore=semaphore
2723+
)
26972724

26982725
for prefix, subgroup_iter in to_recurse.items():
26992726
async for name, node in subgroup_iter:

0 commit comments

Comments
 (0)