@@ -692,6 +692,8 @@ async def getitem(
692692 metadata = _build_metadata_v2 (zarray , zattrs )
693693 return _build_node_v2 (metadata = metadata , store_path = store_path )
694694 else :
695+ # the assert below only narrows zgroup to non-None for mypy
696+ assert zgroup is not None
695697 metadata = _build_metadata_v2 (zgroup , zattrs )
696698 return _build_node_v2 (metadata = metadata , store_path = store_path )
697699 else :
@@ -1328,7 +1330,11 @@ async def _members(
13281330 )
13291331
13301332 raise ValueError (msg )
1331- async for member in _iter_members_deep (self , max_depth = max_depth , skip_keys = skip_keys ):
1333+ # enforce a concurrency limit by passing a semaphore to all the recursive functions
1334+ semaphore = asyncio .Semaphore (config .get ("async.concurrency" ))
1335+ async for member in _iter_members_deep (
1336+ self , max_depth = max_depth , skip_keys = skip_keys , semaphore = semaphore
1337+ ):
13321338 yield member
13331339
13341340 async def keys (self ) -> AsyncGenerator [str , None ]:
@@ -2638,8 +2644,20 @@ async def members_recursive(
26382644 return members_flat
26392645
26402646
async def _getitem_semaphore(
    node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup:
    """
    Await ``node.getitem(key)``, optionally while holding a semaphore.

    Parameters
    ----------
    node : AsyncGroup
        The group whose ``getitem`` method is awaited.
    key : str
        The key passed through to ``node.getitem``.
    semaphore : asyncio.Semaphore or None
        When given, it is acquired for the duration of the ``getitem`` call
        and released afterwards. When ``None``, ``getitem`` is awaited
        directly with no limit applied here.

    Returns
    -------
    AsyncArray or AsyncGroup
        Whatever ``node.getitem(key)`` resolves to.
    """
    if semaphore is None:
        # No limiter supplied: call straight through.
        return await node.getitem(key)

    # Hold the semaphore only while the lookup is in flight.
    async with semaphore:
        return await node.getitem(key)
2655+
2656+
26412657async def iter_members (
2642- node : AsyncGroup , skip_keys : tuple [str , ...]
2658+ node : AsyncGroup ,
2659+ skip_keys : tuple [str , ...],
2660+ semaphore : asyncio .Semaphore | None ,
26432661) -> AsyncGenerator [
26442662 tuple [str , AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ] | AsyncGroup ], None
26452663]:
@@ -2651,7 +2669,10 @@ async def iter_members(
26512669 keys = [key async for key in node .store .list_dir (node .path )]
26522670 keys_filtered = tuple (filter (lambda v : v not in skip_keys , keys ))
26532671
2654- node_tasks = tuple (asyncio .create_task (node .getitem (key ), name = key ) for key in keys_filtered )
2672+ node_tasks = tuple (
2673+ asyncio .create_task (_getitem_semaphore (node , key , semaphore ), name = key )
2674+ for key in keys_filtered
2675+ )
26552676
26562677 for fetched_node_coro in asyncio .as_completed (node_tasks ):
26572678 try :
@@ -2674,7 +2695,11 @@ async def iter_members(
26742695
26752696
26762697async def _iter_members_deep (
2677- group : AsyncGroup , * , max_depth : int | None , skip_keys : tuple [str , ...]
2698+ group : AsyncGroup ,
2699+ * ,
2700+ max_depth : int | None ,
2701+ skip_keys : tuple [str , ...],
2702+ semaphore : asyncio .Semaphore | None = None ,
26782703) -> AsyncGenerator [
26792704 tuple [str , AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ] | AsyncGroup ], None
26802705]:
@@ -2690,10 +2715,12 @@ async def _iter_members_deep(
26902715 new_depth = None
26912716 else :
26922717 new_depth = max_depth - 1
2693- async for name , node in iter_members (group , skip_keys = skip_keys ):
2718+ async for name , node in iter_members (group , skip_keys = skip_keys , semaphore = semaphore ):
26942719 yield name , node
26952720 if isinstance (node , AsyncGroup ) and do_recursion :
2696- to_recurse [name ] = _iter_members_deep (node , max_depth = new_depth , skip_keys = skip_keys )
2721+ to_recurse [name ] = _iter_members_deep (
2722+ node , max_depth = new_depth , skip_keys = skip_keys , semaphore = semaphore
2723+ )
26972724
26982725 for prefix , subgroup_iter in to_recurse .items ():
26992726 async for name , node in subgroup_iter :
0 commit comments