Skip to content

Commit d10d805

Browse files
committed
fix key/name handling in recursion
1 parent 70a4ff5 commit d10d805

File tree

2 files changed

+43
-35
lines changed

2 files changed

+43
-35
lines changed

src/zarr/api/asynchronous.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,12 @@ async def consolidate_metadata(
186186
group.store_path.store._check_writable()
187187

188188
members_metadata = {k: v.metadata async for k, v in group.members(max_depth=None)}
189-
190189
# While consolidating, we want to be explicit about when child groups
191190
# are empty by inserting an empty dict for consolidated_metadata.metadata
192191
for k, v in members_metadata.items():
193192
if isinstance(v, GroupMetadata) and v.consolidated_metadata is None:
194193
v = dataclasses.replace(v, consolidated_metadata=ConsolidatedMetadata(metadata={}))
195194
members_metadata[k] = v
196-
197195
ConsolidatedMetadata._flat_to_nested(members_metadata)
198196

199197
consolidated_metadata = ConsolidatedMetadata(metadata=members_metadata)

src/zarr/core/group.py

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ async def getitem(
650650
"""
651651
store_path = self.store_path / key
652652
logger.debug("key=%s, store_path=%s", key, store_path)
653+
metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata
653654

654655
# Consolidated metadata lets us avoid some I/O operations so try that first.
655656
if self.metadata.consolidated_metadata is not None:
@@ -666,8 +667,8 @@ async def getitem(
666667
raise KeyError(key)
667668
else:
668669
zarr_json = json.loads(zarr_json_bytes.to_bytes())
669-
metadata = build_metadata_v3(zarr_json)
670-
return build_node_v3(metadata, store_path)
670+
metadata = _build_metadata_v3(zarr_json)
671+
return _build_node_v3(metadata, store_path)
671672

672673
elif self.metadata.zarr_format == 2:
673674
# Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
@@ -683,21 +684,16 @@ async def getitem(
683684

684685
# unpack the zarray, if this is None then we must be opening a group
685686
zarray = json.loads(zarray_bytes.to_bytes()) if zarray_bytes else None
687+
zgroup = json.loads(zgroup_bytes.to_bytes()) if zgroup_bytes else None
686688
# unpack the zattrs, this can be None if no attrs were written
687689
zattrs = json.loads(zattrs_bytes.to_bytes()) if zattrs_bytes is not None else {}
688690

689691
if zarray is not None:
690-
# TODO: update this once the V2 array support is part of the primary array class
691-
zarr_json = {**zarray, "attributes": zattrs}
692-
return AsyncArray.from_dict(store_path, zarr_json)
692+
metadata = _build_metadata_v2(zarray, zattrs)
693+
return _build_node_v2(metadata=metadata, store_path=store_path)
693694
else:
694-
zgroup = (
695-
json.loads(zgroup_bytes.to_bytes())
696-
if zgroup_bytes is not None
697-
else {"zarr_format": self.metadata.zarr_format}
698-
)
699-
zarr_json = {**zgroup, "attributes": zattrs}
700-
return type(self).from_dict(store_path, zarr_json)
695+
metadata = _build_metadata_v2(zgroup, zattrs)
696+
return _build_node_v2(metadata=metadata, store_path=store_path)
701697
else:
702698
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
703699

@@ -1332,9 +1328,7 @@ async def _members(
13321328
)
13331329

13341330
raise ValueError(msg)
1335-
async for member in iter_members_deep(
1336-
self, max_depth=max_depth, prefix=self.basename, skip_keys=skip_keys
1337-
):
1331+
async for member in _iter_members_deep(self, max_depth=max_depth, skip_keys=skip_keys):
13381332
yield member
13391333

13401334
async def keys(self) -> AsyncGenerator[str, None]:
@@ -2633,7 +2627,7 @@ async def members_recursive(
26332627
key_body = "/".join(key.split("/")[:-1])
26342628

26352629
if blob is not None:
2636-
resolved_metadata = build_metadata_v3(json.loads(blob.to_bytes()))
2630+
resolved_metadata = _build_metadata_v3(json.loads(blob.to_bytes()))
26372631
members_flat += ((key_body, resolved_metadata),)
26382632
if isinstance(resolved_metadata, GroupMetadata):
26392633
to_recurse.append(members_recursive(store, key_body))
@@ -2679,8 +2673,8 @@ async def iter_members(
26792673
raise ValueError(f"Unexpected type: {type(fetched_node)}")
26802674

26812675

2682-
async def iter_members_deep(
2683-
group: AsyncGroup, *, prefix: str, max_depth: int | None, skip_keys: tuple[str, ...]
2676+
async def _iter_members_deep(
2677+
group: AsyncGroup, *, max_depth: int | None, skip_keys: tuple[str, ...]
26842678
) -> AsyncGenerator[
26852679
tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
26862680
]:
@@ -2689,28 +2683,25 @@ async def iter_members_deep(
26892683
arrays and groups contained in those groups.
26902684
"""
26912685

2692-
to_recurse = []
2686+
to_recurse = {}
26932687
do_recursion = max_depth is None or max_depth > 0
2688+
26942689
if max_depth is None:
26952690
new_depth = None
26962691
else:
26972692
new_depth = max_depth - 1
2698-
26992693
async for name, node in iter_members(group, skip_keys=skip_keys):
2700-
yield f"{prefix}/{name}".lstrip("/"), node
2694+
yield name, node
27012695
if isinstance(node, AsyncGroup) and do_recursion:
2702-
to_recurse.append(
2703-
iter_members_deep(
2704-
node, max_depth=new_depth, prefix=f"{prefix}/{name}", skip_keys=skip_keys
2705-
)
2706-
)
2696+
to_recurse[name] = _iter_members_deep(node, max_depth=new_depth, skip_keys=skip_keys)
27072697

2708-
for subgroup in to_recurse:
2709-
async for name, node in subgroup:
2710-
yield name, node
2698+
for prefix, subgroup_iter in to_recurse.items():
2699+
async for name, node in subgroup_iter:
2700+
key = f"{prefix}/{name}".lstrip("/")
2701+
yield key, node
27112702

27122703

2713-
def resolve_metadata_v2(
2704+
def _resolve_metadata_v2(
27142705
blobs: tuple[str | bytes | bytearray, str | bytes | bytearray],
27152706
) -> ArrayV2Metadata | GroupMetadata:
27162707
zarr_metadata = json.loads(blobs[0])
@@ -2721,7 +2712,7 @@ def resolve_metadata_v2(
27212712
return GroupMetadata.from_dict(zarr_metadata | {"attrs": attrs})
27222713

27232714

2724-
def build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata:
2715+
def _build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata:
27252716
"""
27262717
Take a dict and convert it into the correct metadata type.
27272718
"""
@@ -2736,17 +2727,20 @@ def build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetad
27362727
raise ValueError("invalid value for `node_type` key in metadata document")
27372728

27382729

2739-
def build_metadata_v2(
2730+
def _build_metadata_v2(
27402731
zarr_json: dict[str, Any], attrs_json: dict[str, Any]
27412732
) -> ArrayV2Metadata | GroupMetadata:
2733+
"""
2734+
Take a dict and convert it into the correct metadata type.
2735+
"""
27422736
match zarr_json:
27432737
case {"shape": _}:
27442738
return ArrayV2Metadata.from_dict(zarr_json | {"attributes": attrs_json})
27452739
case _:
27462740
return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json})
27472741

27482742

2749-
def build_node_v3(
2743+
def _build_node_v3(
27502744
metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath
27512745
) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
27522746
"""
@@ -2759,3 +2753,19 @@ def build_node_v3(
27592753
return AsyncGroup(metadata, store_path=store_path)
27602754
case _:
27612755
raise ValueError(f"Unexpected metadata type: {type(metadata)}")
2756+
2757+
2758+
def _build_node_v2(
2759+
metadata: ArrayV2Metadata | GroupMetadata, store_path: StorePath
2760+
) -> AsyncArray[ArrayV2Metadata] | AsyncGroup:
2761+
"""
2762+
Take a metadata object and return a node (AsyncArray or AsyncGroup).
2763+
"""
2764+
2765+
match metadata:
2766+
case ArrayV2Metadata():
2767+
return AsyncArray(metadata, store_path=store_path)
2768+
case GroupMetadata():
2769+
return AsyncGroup(metadata, store_path=store_path)
2770+
case _:
2771+
raise ValueError(f"Unexpected metadata type: {type(metadata)}")

0 commit comments

Comments
 (0)