Skip to content

Commit 8f87977

Browse files
committed
bolt concurrent members implementation onto async group
1 parent d33cb7d commit 8f87977

File tree

3 files changed

+195
-30
lines changed

3 files changed

+195
-30
lines changed

src/zarr/core/array.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -786,34 +786,30 @@ def path(self) -> str:
786786
return self.store_path.path
787787

788788
@property
789-
def name(self) -> str | None:
789+
def name(self) -> str:
790790
"""Array name following h5py convention.
791791
792792
Returns
793793
-------
794794
str
795795
The name of the array.
796796
"""
797-
if self.path:
798-
# follow h5py convention: add leading slash
799-
name = self.path
800-
if name[0] != "/":
801-
name = "/" + name
802-
return name
803-
return None
797+
# follow h5py convention: add leading slash
798+
name = self.path
799+
if not name.startswith('/'):
800+
name = "/" + name
801+
return name
804802

805803
@property
806-
def basename(self) -> str | None:
804+
def basename(self) -> str:
807805
"""Final component of name.
808806
809807
Returns
810808
-------
811809
str
812810
The basename or final component of the array name.
813811
"""
814-
if self.name is not None:
815-
return self.name.split("/")[-1]
816-
return None
812+
return self.name.split("/")[-1]
817813

818814
@property
819815
def cdata_shape(self) -> ChunkCoords:
@@ -1436,12 +1432,12 @@ def path(self) -> str:
14361432
return self._async_array.path
14371433

14381434
@property
1439-
def name(self) -> str | None:
1435+
def name(self) -> str:
14401436
"""Array name following h5py convention."""
14411437
return self._async_array.name
14421438

14431439
@property
1444-
def basename(self) -> str | None:
1440+
def basename(self) -> str:
14451441
"""Final component of name."""
14461442
return self._async_array.basename
14471443

src/zarr/core/group.py

Lines changed: 123 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from zarr.abc.store import Store, set_or_delete
2020
from zarr.core.array import Array, AsyncArray, _build_parents
2121
from zarr.core.attributes import Attributes
22-
from zarr.core.buffer import default_buffer_prototype
22+
from zarr.core.buffer import default_buffer_prototype, Buffer
2323
from zarr.core.common import (
2424
JSON,
2525
ZARR_JSON,
@@ -1151,18 +1151,18 @@ async def members(
11511151
"""
11521152
if max_depth is not None and max_depth < 0:
11531153
raise ValueError(f"max_depth must be None or >= 0. Got '{max_depth}' instead")
1154-
async for item in self._members(max_depth=max_depth, current_depth=0):
1154+
async for item in self._members(max_depth=max_depth):
11551155
yield item
11561156

1157-
async def _members(
1157+
async def _members_old(
11581158
self, max_depth: int | None, current_depth: int
11591159
) -> AsyncGenerator[
11601160
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
11611161
None,
11621162
]:
11631163
if self.metadata.consolidated_metadata is not None:
11641164
# we should be able to do members without any additional I/O
1165-
members = self._members_consolidated(max_depth, current_depth)
1165+
members = self._members_consolidated(max_depth)
11661166
for member in members:
11671167
yield member
11681168
return
@@ -1202,8 +1202,7 @@ async def _members(
12021202
# implies an AsyncGroup, not an AsyncArray
12031203
assert isinstance(obj, AsyncGroup)
12041204
async for child_key, val in obj._members(
1205-
max_depth=max_depth, current_depth=current_depth + 1
1206-
):
1205+
max_depth=max_depth):
12071206
yield f"{key}/{child_key}", val
12081207
except KeyError:
12091208
# keyerror is raised when `key` names an object (in the object storage sense),
@@ -1216,12 +1215,14 @@ async def _members(
12161215
)
12171216

12181217
def _members_consolidated(
1219-
self, max_depth: int | None, current_depth: int, prefix: str = ""
1218+
self, max_depth: int | None, prefix: str = ""
12201219
) -> Generator[
12211220
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
12221221
None,
12231222
]:
12241223
consolidated_metadata = self.metadata.consolidated_metadata
1224+
1225+
do_recursion = max_depth is None or max_depth > 0
12251226

12261227
# we kind of just want the top-level keys.
12271228
if consolidated_metadata is not None:
@@ -1232,10 +1233,43 @@ def _members_consolidated(
12321233
key = f"{prefix}/{key}".lstrip("/")
12331234
yield key, obj
12341235

1235-
if ((max_depth is None) or (current_depth < max_depth)) and isinstance(
1236+
if do_recursion and isinstance(
12361237
obj, AsyncGroup
12371238
):
1238-
yield from obj._members_consolidated(max_depth, current_depth + 1, prefix=key)
1239+
if max_depth is None:
1240+
new_depth = None
1241+
else:
1242+
new_depth = max_depth - 1
1243+
yield from obj._members_consolidated(new_depth, prefix=key)
1244+
1245+
async def _members(
1246+
self,
1247+
max_depth: int | None) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]:
1248+
skip_keys: tuple[str, ...]
1249+
if self.metadata.zarr_format == 2:
1250+
skip_keys = ('.zattrs', '.zgroup','.zarray', '.zmetadata')
1251+
elif self.metadata.zarr_format == 3:
1252+
skip_keys = ('zarr.json',)
1253+
else:
1254+
raise ValueError(f"Unknown Zarr format: {self.metadata.zarr_format}")
1255+
1256+
if self.metadata.consolidated_metadata is not None:
1257+
members = self._members_consolidated(max_depth=max_depth)
1258+
for member in members:
1259+
yield member
1260+
return
1261+
1262+
if not self.store_path.store.supports_listing:
1263+
msg = (
1264+
f"The store associated with this group ({type(self.store_path.store)}) "
1265+
"does not support listing, "
1266+
"specifically via the `list_dir` method. "
1267+
"This function requires a store that supports listing."
1268+
)
1269+
1270+
raise ValueError(msg)
1271+
async for member in iter_members_deep(self, max_depth=max_depth, prefix=self.basename, skip_keys=skip_keys):
1272+
yield member
12391273

12401274
async def keys(self) -> AsyncGenerator[str, None]:
12411275
async for key, _ in self.members():
@@ -1848,10 +1882,13 @@ def array(
18481882
)
18491883

18501884

1851-
async def members_v3(
1885+
async def members_recursive(
18521886
store: Store,
18531887
path: str,
18541888
) -> Any:
1889+
"""
1890+
Recursively fetch all members of a group.
1891+
"""
18551892
metadata_keys = ("zarr.json",)
18561893

18571894
members_flat: tuple[tuple[str, ArrayV3Metadata | GroupMetadata], ...] = ()
@@ -1879,18 +1916,88 @@ async def members_v3(
18791916
resolved_metadata = resolve_metadata_v3(blob.to_bytes())
18801917
members_flat += ((key_body, resolved_metadata),)
18811918
if isinstance(resolved_metadata, GroupMetadata):
1882-
to_recurse.append(members_v3(store, key_body))
1883-
1884-
# for r in to_recurse:
1885-
# members_flat += await r
1919+
to_recurse.append(
1920+
members_recursive(store, key_body))
18861921

18871922
subgroups = await asyncio.gather(*to_recurse)
18881923
members_flat += tuple(subgroup for subgroup in subgroups)
18891924

1890-
# recurse for groups
1891-
18921925
return members_flat
18931926

1927+
async def iter_members(
1928+
node: AsyncGroup,
1929+
skip_keys: tuple[str, ...]
1930+
) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]:
1931+
"""
1932+
Iterate over the arrays and groups contained in a group.
1933+
"""
1934+
1935+
# retrieve keys from storage
1936+
keys = [key async for key in node.store.list_dir(node.path)]
1937+
keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys))
1938+
1939+
node_tasks = tuple(asyncio.create_task(
1940+
node.getitem(key), name=key) for key in keys_filtered)
1941+
1942+
for fetched_node_coro in asyncio.as_completed(node_tasks):
1943+
try:
1944+
fetched_node = await fetched_node_coro
1945+
except KeyError as e:
1946+
# keyerror is raised when `key` names an object (in the object storage sense),
1947+
# as opposed to a prefix, in the store under the prefix associated with this group
1948+
# in which case `key` cannot be the name of a sub-array or sub-group.
1949+
warnings.warn(
1950+
f"Object at {e.args[0]} is not recognized as a component of a Zarr hierarchy.",
1951+
UserWarning,
1952+
stacklevel=1,
1953+
)
1954+
continue
1955+
match fetched_node:
1956+
case AsyncArray() | AsyncGroup():
1957+
yield fetched_node.basename, fetched_node
1958+
case _:
1959+
raise ValueError(f"Unexpected type: {type(fetched_node)}")
1960+
1961+
async def iter_members_deep(
1962+
group: AsyncGroup,
1963+
*,
1964+
prefix: str,
1965+
max_depth: int | None,
1966+
skip_keys: tuple[str, ...]
1967+
) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]:
1968+
"""
1969+
Iterate over the arrays and groups contained in a group, and optionally the
1970+
arrays and groups contained in those groups.
1971+
"""
1972+
1973+
to_recurse = []
1974+
do_recursion = max_depth is None or max_depth > 0
1975+
if max_depth is None:
1976+
new_depth = None
1977+
else:
1978+
new_depth = max_depth - 1
1979+
1980+
async for name, node in iter_members(group, skip_keys=skip_keys):
1981+
yield f'{prefix}/{name}'.lstrip('/'), node
1982+
if isinstance(node, AsyncGroup) and do_recursion:
1983+
to_recurse.append(iter_members_deep(
1984+
node,
1985+
max_depth=new_depth,
1986+
prefix=f'{prefix}/{name}',
1987+
skip_keys=skip_keys))
1988+
1989+
for subgroup in to_recurse:
1990+
async for name, node in subgroup:
1991+
yield name, node
1992+
1993+
1994+
def resolve_metadata_v2(blobs: tuple[str | bytes | bytearray, str | bytes | bytearray]) -> ArrayV2Metadata | GroupMetadata:
1995+
zarr_metadata = json.loads(blobs[0])
1996+
attrs = json.loads(blobs[1])
1997+
if 'shape' in zarr_metadata:
1998+
return ArrayV2Metadata.from_dict(zarr_metadata | {'attrs': attrs})
1999+
else:
2000+
return GroupMetadata.from_dict(zarr_metadata | {'attrs': attrs})
18942001

18952002
def resolve_metadata_v3(blob: str | bytes | bytearray) -> ArrayV3Metadata | GroupMetadata:
18962003
zarr_json = json.loads(blob)

tests/test_group.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,68 @@ def test_group_members(store: Store, zarr_format: ZarrFormat, consolidated_metad
207207
with pytest.raises(ValueError, match="max_depth"):
208208
members_observed = group.members(max_depth=-1)
209209

210+
def test_group_members_2(store: Store, zarr_format: ZarrFormat) -> None:
211+
"""
212+
Test that `Group.members` returns correct values, i.e. the arrays and groups
213+
(explicit and implicit) contained in that group.
214+
"""
215+
# group/
216+
# subgroup/
217+
# subsubgroup/
218+
# subsubsubgroup
219+
# subarray
220+
221+
path = "group"
222+
group = Group.from_store(
223+
store=store,
224+
zarr_format=zarr_format,
225+
)
226+
members_expected: dict[str, Array | Group] = {}
227+
228+
members_expected["subgroup"] = group.create_group("subgroup")
229+
# make a sub-sub-subgroup, to ensure that the children calculation doesn't go
230+
# too deep in the hierarchy
231+
subsubgroup = members_expected["subgroup"].create_group("subsubgroup")
232+
subsubsubgroup = subsubgroup.create_group("subsubsubgroup")
233+
234+
members_expected["subarray"] = group.create_array(
235+
"subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True
236+
)
237+
238+
# add an extra object to the domain of the group.
239+
# the list of children should ignore this object.
240+
sync(
241+
store.set(
242+
f"{path}/extra_object-1",
243+
default_buffer_prototype().buffer.from_bytes(b"000000"),
244+
)
245+
)
246+
# add an extra object under a directory-like prefix in the domain of the group.
247+
# this creates a directory with a random key in it
248+
# this should not show up as a member
249+
sync(
250+
store.set(
251+
f"{path}/extra_directory/extra_object-2",
252+
default_buffer_prototype().buffer.from_bytes(b"000000"),
253+
)
254+
)
255+
256+
# this warning shows up when extra objects show up in the hierarchy
257+
warn_context = pytest.warns(
258+
UserWarning, match=r"Object at .* is not recognized as a component of a Zarr hierarchy."
259+
)
260+
261+
with warn_context:
262+
members_observed = group.members()
263+
# members are not guaranteed to be ordered, so sort before comparing
264+
assert sorted(dict(members_observed)) == sorted(members_expected)
265+
266+
# partial
267+
with warn_context:
268+
members_observed = group.members(max_depth=1)
269+
members_expected["subgroup/subsubgroup"] = subsubgroup
270+
# members are not guaranteed to be ordered, so sort before comparing
271+
assert sorted(dict(members_observed)) == sorted(members_expected)
210272

211273
def test_group(store: Store, zarr_format: ZarrFormat) -> None:
212274
"""

0 commit comments

Comments
 (0)