Skip to content

Commit 7cbc500

Browse files
committed
Maybe fixup
1 parent 1cdfd6d commit 7cbc500

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

src/zarr/abc/store.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
77

88
from zarr.core.buffer.core import default_buffer_prototype
9+
from zarr.core.common import concurrent_map
10+
from zarr.core.config import config
911

1012
if TYPE_CHECKING:
1113
from collections.abc import AsyncGenerator, Iterable
@@ -453,9 +455,14 @@ async def getsize_prefix(self, prefix: str) -> int:
453455
-----
454456
``getsize_prefix`` is just provided as a potentially faster alternative to
455457
listing all the keys under a prefix and calling :meth:`Store.getsize` on each.
458+
459+
In general, ``prefix`` should be the path of an Array or Group in the Store.
460+
Implementations may differ on the behavior when some other ``prefix``
461+
is provided.
456462
"""
457-
keys = [x async for x in self.list_prefix(prefix)]
458-
sizes = await gather(*[self.getsize(key) for key in keys])
463+
keys = ((x,) async for x in self.list_prefix(prefix))
464+
limit = config.get("async.concurrency")
465+
sizes = await concurrent_map(keys, self.getsize, limit=limit)
459466
return sum(sizes)
460467

461468

src/zarr/core/common.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import functools
55
import operator
6-
from collections.abc import Iterable, Mapping
6+
from collections.abc import AsyncIterable, Iterable, Mapping
77
from enum import Enum
88
from itertools import starmap
99
from typing import (
@@ -50,10 +50,15 @@ def product(tup: ChunkCoords) -> int:
5050

5151

5252
async def concurrent_map(
53-
items: Iterable[T], func: Callable[..., Awaitable[V]], limit: int | None = None
53+
items: Iterable[T] | AsyncIterable[T],
54+
func: Callable[..., Awaitable[V]],
55+
limit: int | None = None,
5456
) -> list[V]:
5557
if limit is None:
56-
return await asyncio.gather(*list(starmap(func, items)))
58+
if isinstance(items, AsyncIterable):
59+
return await asyncio.gather(*list(starmap(func, [x async for x in items])))
60+
else:
61+
return await asyncio.gather(*list(starmap(func, items)))
5762

5863
else:
5964
sem = asyncio.Semaphore(limit)
@@ -62,7 +67,10 @@ async def run(item: tuple[Any]) -> V:
6267
async with sem:
6368
return await func(*item)
6469

65-
return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items])
70+
if isinstance(items, AsyncIterable):
71+
return await asyncio.gather(*[asyncio.ensure_future(run(item)) async for item in items])
72+
else:
73+
return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items])
6674

6775

6876
E = TypeVar("E", bound=Enum)

src/zarr/testing/store.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,19 @@ async def test_getsize(self, store: S) -> None:
368368
await self.set(store, key, data)
369369

370370
result = await store.getsize(key)
371-
assert result == 10
371+
assert isinstance(result, int)
372+
assert result > 0
372373

373374
async def test_getsize_raises(self, store: S) -> None:
374375
with pytest.raises(FileNotFoundError):
375376
await store.getsize("not-a-real-key")
377+
378+
async def test_getsize_prefix(self, store: S) -> None:
379+
prefix = "array/c/"
380+
for i in range(10):
381+
data = self.buffer_cls.from_bytes(b"0" * 10)
382+
await self.set(store, f"{prefix}/{i}", data)
383+
384+
result = await store.getsize_prefix(prefix)
385+
assert isinstance(result, int)
386+
assert result > 0

0 commit comments

Comments
 (0)