1616)
1717from zarr .core .buffer import Buffer
1818from zarr .core .buffer .core import BufferPrototype
19+ from zarr .core .config import config
1920
2021if TYPE_CHECKING :
2122 from collections .abc import AsyncGenerator , Coroutine , Iterable
@@ -299,6 +300,7 @@ async def _make_bounded_requests(
299300 path : str ,
300301 requests : list [_BoundedRequest ],
301302 prototype : BufferPrototype ,
303+ semaphore : asyncio .Semaphore ,
302304) -> list [_Response ]:
303305 """Make all bounded requests for a specific file.
304306
@@ -310,7 +312,8 @@ async def _make_bounded_requests(
310312
311313 starts = [r ["start" ] for r in requests ]
312314 ends = [r ["end" ] for r in requests ]
313- responses = await obs .get_ranges_async (store , path = path , starts = starts , ends = ends )
315+ async with semaphore :
316+ responses = await obs .get_ranges_async (store , path = path , starts = starts , ends = ends )
314317
315318 buffer_responses : list [_Response ] = []
316319 for request , response in zip (requests , responses , strict = True ):
@@ -328,6 +331,7 @@ async def _make_other_request(
328331 store : _UpstreamObjectStore ,
329332 request : _OtherRequest ,
330333 prototype : BufferPrototype ,
334+ semaphore : asyncio .Semaphore ,
331335) -> list [_Response ]:
332336 """Make suffix or offset requests.
333337
@@ -336,11 +340,13 @@ async def _make_other_request(
336340 """
337341 import obstore as obs
338342
339- if request ["range" ] is None :
340- resp = await obs .get_async (store , request ["path" ])
341- else :
342- resp = await obs .get_async (store , request ["path" ], options = {"range" : request ["range" ]})
343- buffer = await resp .bytes_async ()
343+ async with semaphore :
344+ if request ["range" ] is None :
345+ resp = await obs .get_async (store , request ["path" ])
346+ else :
347+ resp = await obs .get_async (store , request ["path" ], options = {"range" : request ["range" ]})
348+ buffer = await resp .bytes_async ()
349+
344350 return [
345351 {
346352 "original_request_index" : request ["original_request_index" ],
@@ -401,17 +407,19 @@ async def _get_partial_values(
401407 else :
402408 raise ValueError (f"Unsupported range input: { byte_range } " )
403409
410+ semaphore = asyncio .Semaphore (config .get ("async.concurrency" ))
411+
404412 futs : list [Coroutine [Any , Any , list [_Response ]]] = []
405413 for path , bounded_ranges in per_file_bounded_requests .items ():
406- futs .append (_make_bounded_requests (store , path , bounded_ranges , prototype ))
414+ futs .append (
415+ _make_bounded_requests (store , path , bounded_ranges , prototype , semaphore = semaphore )
416+ )
407417
408418 for request in other_requests :
409- futs .append (_make_other_request (store , request , prototype )) # noqa: PERF401
419+ futs .append (_make_other_request (store , request , prototype , semaphore = semaphore )) # noqa: PERF401
410420
411421 buffers : list [Buffer | None ] = [None ] * len (key_ranges )
412422
413- # TODO: this gather a list of list of Response; not sure if there's a way to
414- # unpack these lists inside of an `asyncio.gather`?
415423 for responses in await asyncio .gather (* futs ):
416424 for resp in responses :
417425 buffers [resp ["original_request_index" ]] = resp ["buffer" ]
0 commit comments