Skip to content

Commit c45c7d1

Browse files
committed
Add semaphore
1 parent 6ce2258 commit c45c7d1

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

src/zarr/storage/_obstore.py

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
)
1717
from zarr.core.buffer import Buffer
1818
from zarr.core.buffer.core import BufferPrototype
19+
from zarr.core.config import config
1920

2021
if TYPE_CHECKING:
2122
from collections.abc import AsyncGenerator, Coroutine, Iterable
@@ -299,6 +300,7 @@ async def _make_bounded_requests(
299300
path: str,
300301
requests: list[_BoundedRequest],
301302
prototype: BufferPrototype,
303+
semaphore: asyncio.Semaphore,
302304
) -> list[_Response]:
303305
"""Make all bounded requests for a specific file.
304306
@@ -310,7 +312,8 @@ async def _make_bounded_requests(
310312

311313
starts = [r["start"] for r in requests]
312314
ends = [r["end"] for r in requests]
313-
responses = await obs.get_ranges_async(store, path=path, starts=starts, ends=ends)
315+
async with semaphore:
316+
responses = await obs.get_ranges_async(store, path=path, starts=starts, ends=ends)
314317

315318
buffer_responses: list[_Response] = []
316319
for request, response in zip(requests, responses, strict=True):
@@ -328,6 +331,7 @@ async def _make_other_request(
328331
store: _UpstreamObjectStore,
329332
request: _OtherRequest,
330333
prototype: BufferPrototype,
334+
semaphore: asyncio.Semaphore,
331335
) -> list[_Response]:
332336
"""Make suffix or offset requests.
333337
@@ -336,11 +340,13 @@ async def _make_other_request(
336340
"""
337341
import obstore as obs
338342

339-
if request["range"] is None:
340-
resp = await obs.get_async(store, request["path"])
341-
else:
342-
resp = await obs.get_async(store, request["path"], options={"range": request["range"]})
343-
buffer = await resp.bytes_async()
343+
async with semaphore:
344+
if request["range"] is None:
345+
resp = await obs.get_async(store, request["path"])
346+
else:
347+
resp = await obs.get_async(store, request["path"], options={"range": request["range"]})
348+
buffer = await resp.bytes_async()
349+
344350
return [
345351
{
346352
"original_request_index": request["original_request_index"],
@@ -401,17 +407,19 @@ async def _get_partial_values(
401407
else:
402408
raise ValueError(f"Unsupported range input: {byte_range}")
403409

410+
semaphore = asyncio.Semaphore(config.get("async.concurrency"))
411+
404412
futs: list[Coroutine[Any, Any, list[_Response]]] = []
405413
for path, bounded_ranges in per_file_bounded_requests.items():
406-
futs.append(_make_bounded_requests(store, path, bounded_ranges, prototype))
414+
futs.append(
415+
_make_bounded_requests(store, path, bounded_ranges, prototype, semaphore=semaphore)
416+
)
407417

408418
for request in other_requests:
409-
futs.append(_make_other_request(store, request, prototype)) # noqa: PERF401
419+
futs.append(_make_other_request(store, request, prototype, semaphore=semaphore)) # noqa: PERF401
410420

411421
buffers: list[Buffer | None] = [None] * len(key_ranges)
412422

413-
# TODO: this gather a list of list of Response; not sure if there's a way to
414-
# unpack these lists inside of an `asyncio.gather`?
415423
for responses in await asyncio.gather(*futs):
416424
for resp in responses:
417425
buffers[resp["original_request_index"]] = resp["buffer"]

0 commit comments

Comments
 (0)