
Commit 502ad5e

use metadata / node builders for v3 node creation
1 parent 87e0b83 commit 502ad5e

1 file changed: +84 -114 lines changed

src/zarr/core/group.py

Lines changed: 84 additions & 114 deletions
@@ -19,7 +19,7 @@
 from zarr.abc.store import Store, set_or_delete
 from zarr.core.array import Array, AsyncArray, _build_parents
 from zarr.core.attributes import Attributes
-from zarr.core.buffer import default_buffer_prototype, Buffer
+from zarr.core.buffer import Buffer, default_buffer_prototype
 from zarr.core.common import (
     JSON,
     ZARR_JSON,
@@ -645,12 +645,10 @@ async def getitem(
                 raise KeyError(key)
             else:
                 zarr_json = json.loads(zarr_json_bytes.to_bytes())
-                if zarr_json["node_type"] == "group":
-                    return type(self).from_dict(store_path, zarr_json)
-                elif zarr_json["node_type"] == "array":
-                    return AsyncArray.from_dict(store_path, zarr_json)
-                else:
-                    raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
+                metadata = build_metadata_v3(zarr_json)
+                node = build_node_v3(metadata, store_path)
+                return node
+
         elif self.metadata.zarr_format == 2:
             # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
             # This guarantees that we will always make at least one extra request to the store
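The rewritten branch of `getitem` splits v3 node creation into two steps: parse the `zarr.json` document into a typed metadata object, then wrap that metadata in a node bound to a store location. A minimal standalone sketch of the same two-step pattern, using hypothetical placeholder classes rather than the real `ArrayV3Metadata` / `GroupMetadata` / `AsyncArray` types:

import json
from dataclasses import dataclass
from typing import Any

# Hypothetical stand-ins for the real metadata and node classes, for illustration only.
@dataclass
class GroupMeta:
    attributes: dict[str, Any]

@dataclass
class ArrayMeta:
    shape: tuple[int, ...]

@dataclass
class Node:
    metadata: GroupMeta | ArrayMeta
    path: str

def build_metadata(zarr_json: dict[str, Any]) -> GroupMeta | ArrayMeta:
    # step 1: turn the parsed zarr.json document into a typed metadata object
    match zarr_json:
        case {"node_type": "array", "shape": shape}:
            return ArrayMeta(shape=tuple(shape))
        case {"node_type": "group"}:
            return GroupMeta(attributes=zarr_json.get("attributes", {}))
        case _:
            raise ValueError("invalid or missing `node_type` in metadata document")

def build_node(metadata: GroupMeta | ArrayMeta, path: str) -> Node:
    # step 2: wrap the metadata in a node bound to a location in the store
    return Node(metadata=metadata, path=path)

doc = json.loads('{"node_type": "group", "attributes": {"title": "root"}}')
print(build_node(build_metadata(doc), path="my/group"))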
@@ -1154,74 +1152,14 @@ async def members(
         async for item in self._members(max_depth=max_depth):
             yield item
 
-    async def _members_old(
-        self, max_depth: int | None, current_depth: int
-    ) -> AsyncGenerator[
-        tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
-        None,
-    ]:
-        if self.metadata.consolidated_metadata is not None:
-            # we should be able to do members without any additional I/O
-            members = self._members_consolidated(max_depth)
-            for member in members:
-                yield member
-            return
-
-        if not self.store_path.store.supports_listing:
-            msg = (
-                f"The store associated with this group ({type(self.store_path.store)}) "
-                "does not support listing, "
-                "specifically via the `list_dir` method. "
-                "This function requires a store that supports listing."
-            )
-
-            raise ValueError(msg)
-        # would be nice to make these special keys accessible programmatically,
-        # and scoped to specific zarr versions
-        # especially true for `.zmetadata` which is configurable
-        _skip_keys = ("zarr.json", ".zgroup", ".zattrs", ".zmetadata")
-
-        # hmm lots of I/O and logic interleaved here.
-        # We *could* have an async gen over self.metadata.consolidated_metadata.metadata.keys()
-        # and plug in here. `getitem` will skip I/O.
-        # Kinda a shame to have all the asyncio task overhead though, when it isn't needed.
-
-        async for key in self.store_path.store.list_dir(self.store_path.path):
-            if key in _skip_keys:
-                continue
-            try:
-                obj = await self.getitem(key)
-                yield (key, obj)
-
-                if (
-                    ((max_depth is None) or (current_depth < max_depth))
-                    and hasattr(obj.metadata, "node_type")
-                    and obj.metadata.node_type == "group"
-                ):
-                    # the assert is just for mypy to know that `obj.metadata.node_type`
-                    # implies an AsyncGroup, not an AsyncArray
-                    assert isinstance(obj, AsyncGroup)
-                    async for child_key, val in obj._members(
-                        max_depth=max_depth):
-                        yield f"{key}/{child_key}", val
-            except KeyError:
-                # keyerror is raised when `key` names an object (in the object storage sense),
-                # as opposed to a prefix, in the store under the prefix associated with this group
-                # in which case `key` cannot be the name of a sub-array or sub-group.
-                warnings.warn(
-                    f"Object at {key} is not recognized as a component of a Zarr hierarchy.",
-                    UserWarning,
-                    stacklevel=1,
-                )
-
     def _members_consolidated(
         self, max_depth: int | None, prefix: str = ""
     ) -> Generator[
         tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
         None,
     ]:
         consolidated_metadata = self.metadata.consolidated_metadata
-
+
         do_recursion = max_depth is None or max_depth > 0
 
         # we kind of just want the top-level keys.
@@ -1233,23 +1171,23 @@ def _members_consolidated(
                 key = f"{prefix}/{key}".lstrip("/")
                 yield key, obj
 
-                if do_recursion and isinstance(
-                    obj, AsyncGroup
-                ):
+                if do_recursion and isinstance(obj, AsyncGroup):
                     if max_depth is None:
-                        new_depth = None
+                        new_depth = None
                     else:
                         new_depth = max_depth - 1
                     yield from obj._members_consolidated(new_depth, prefix=key)
-
+
     async def _members(
-        self,
-        max_depth: int | None) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]:
+        self, max_depth: int | None
+    ) -> AsyncGenerator[
+        tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
+    ]:
         skip_keys: tuple[str, ...]
         if self.metadata.zarr_format == 2:
-            skip_keys = ('.zattrs', '.zgroup','.zarray', '.zmetadata')
+            skip_keys = (".zattrs", ".zgroup", ".zarray", ".zmetadata")
         elif self.metadata.zarr_format == 3:
-            skip_keys = ('zarr.json',)
+            skip_keys = ("zarr.json",)
         else:
             raise ValueError(f"Unknown Zarr format: {self.metadata.zarr_format}")
 
@@ -1268,7 +1206,9 @@ async def _members(
             )
 
             raise ValueError(msg)
-        async for member in iter_members_deep(self, max_depth=max_depth, prefix=self.basename, skip_keys=skip_keys):
+        async for member in iter_members_deep(
+            self, max_depth=max_depth, prefix=self.basename, skip_keys=skip_keys
+        ):
             yield member
 
     async def keys(self) -> AsyncGenerator[str, None]:
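`_members` is now reduced to choosing the per-format metadata keys to skip and delegating the actual traversal to `iter_members_deep`. A small standalone sketch of that key-filtering step, with a hard-coded listing standing in for `store.list_dir`:

def skip_keys_for(zarr_format: int) -> tuple[str, ...]:
    # mirrors the per-format skip lists in `_members`
    if zarr_format == 2:
        return (".zattrs", ".zgroup", ".zarray", ".zmetadata")
    elif zarr_format == 3:
        return ("zarr.json",)
    raise ValueError(f"Unknown Zarr format: {zarr_format}")

# hypothetical directory listing standing in for `store.list_dir(path)`
listing = ["zarr.json", "temperature", "precipitation"]
skip = skip_keys_for(3)
children = [key for key in listing if key not in skip]
print(children)  # ['temperature', 'precipitation']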
@@ -1913,31 +1853,31 @@ async def members_recursive(
         key_body = "/".join(key.split("/")[:-1])
 
         if blob is not None:
-            resolved_metadata = resolve_metadata_v3(blob.to_bytes())
+            resolved_metadata = build_metadata_v3(blob.to_bytes())
             members_flat += ((key_body, resolved_metadata),)
             if isinstance(resolved_metadata, GroupMetadata):
-                to_recurse.append(
-                    members_recursive(store, key_body))
+                to_recurse.append(members_recursive(store, key_body))
 
     subgroups = await asyncio.gather(*to_recurse)
     members_flat += tuple(subgroup for subgroup in subgroups)
 
     return members_flat
 
+
 async def iter_members(
-    node: AsyncGroup,
-    skip_keys: tuple[str, ...]
-) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]:
+    node: AsyncGroup, skip_keys: tuple[str, ...]
+) -> AsyncGenerator[
+    tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
+]:
     """
     Iterate over the arrays and groups contained in a group.
     """
-
+
     # retrieve keys from storage
    keys = [key async for key in node.store.list_dir(node.path)]
     keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys))
 
-    node_tasks = tuple(asyncio.create_task(
-        node.getitem(key), name=key) for key in keys_filtered)
+    node_tasks = tuple(asyncio.create_task(node.getitem(key), name=key) for key in keys_filtered)
 
     for fetched_node_coro in asyncio.as_completed(node_tasks):
         try:
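`iter_members` schedules one `getitem` task per candidate key and yields each child as soon as its fetch completes. A self-contained sketch of that `create_task` / `as_completed` fan-out, with a fake fetch coroutine in place of real store I/O:

import asyncio
import random

async def fake_getitem(key: str) -> str:
    # stand-in for `node.getitem(key)`: pretend to fetch metadata from a store
    await asyncio.sleep(random.uniform(0.01, 0.05))
    return f"node:{key}"

async def iter_children(keys: list[str]):
    # one task per key, all scheduled up front
    tasks = tuple(asyncio.create_task(fake_getitem(key), name=key) for key in keys)
    # yield results in completion order, not submission order
    for finished in asyncio.as_completed(tasks):
        yield await finished

async def main() -> None:
    async for child in iter_children(["a", "b", "c"]):
        print(child)

asyncio.run(main())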
@@ -1958,15 +1898,14 @@ async def iter_members(
             case _:
                 raise ValueError(f"Unexpected type: {type(fetched_node)}")
 
+
 async def iter_members_deep(
-    group: AsyncGroup,
-    *,
-    prefix: str,
-    max_depth: int | None,
-    skip_keys: tuple[str, ...]
-) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]:
+    group: AsyncGroup, *, prefix: str, max_depth: int | None, skip_keys: tuple[str, ...]
+) -> AsyncGenerator[
+    tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
+]:
     """
-    Iterate over the arrays and groups contained in a group, and optionally the
+    Iterate over the arrays and groups contained in a group, and optionally the
     arrays and groups contained in those groups.
     """
 
@@ -1978,34 +1917,65 @@ async def iter_members_deep(
         new_depth = max_depth - 1
 
     async for name, node in iter_members(group, skip_keys=skip_keys):
-        yield f'{prefix}/{name}'.lstrip('/'), node
+        yield f"{prefix}/{name}".lstrip("/"), node
         if isinstance(node, AsyncGroup) and do_recursion:
-            to_recurse.append(iter_members_deep(
-                node,
-                max_depth=new_depth,
-                prefix=f'{prefix}/{name}',
-                skip_keys=skip_keys))
+            to_recurse.append(
+                iter_members_deep(
+                    node, max_depth=new_depth, prefix=f"{prefix}/{name}", skip_keys=skip_keys
+                )
+            )
 
     for subgroup in to_recurse:
         async for name, node in subgroup:
             yield name, node
-
 
-def resolve_metadata_v2(blobs: tuple[str | bytes | bytearray, str | bytes | bytearray]) -> ArrayV2Metadata | GroupMetadata:
+
+def resolve_metadata_v2(
+    blobs: tuple[str | bytes | bytearray, str | bytes | bytearray],
+) -> ArrayV2Metadata | GroupMetadata:
     zarr_metadata = json.loads(blobs[0])
     attrs = json.loads(blobs[1])
-    if 'shape' in zarr_metadata:
-        return ArrayV2Metadata.from_dict(zarr_metadata | {'attrs': attrs})
+    if "shape" in zarr_metadata:
+        return ArrayV2Metadata.from_dict(zarr_metadata | {"attrs": attrs})
     else:
-        return GroupMetadata.from_dict(zarr_metadata | {'attrs': attrs})
+        return GroupMetadata.from_dict(zarr_metadata | {"attrs": attrs})
+
 
-def resolve_metadata_v3(blob: str | bytes | bytearray) -> ArrayV3Metadata | GroupMetadata:
-    zarr_json = json.loads(blob)
+def build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata:
+    """
+    Take a dict and convert it into the correct metadata type.
+    """
     if "node_type" not in zarr_json:
-        raise ValueError("missing node_type in metadata document")
-    if zarr_json["node_type"] == "array":
-        return ArrayV3Metadata.from_dict(zarr_json)
-    elif zarr_json["node_type"] == "group":
-        return GroupMetadata.from_dict(zarr_json)
-    else:
-        raise ValueError("invalid node_type in metadata document")
+        raise KeyError("missing `node_type` key in metadata document.")
+    match zarr_json:
+        case {"node_type": "array"}:
+            return ArrayV3Metadata.from_dict(zarr_json)
+        case {"node_type": "group"}:
+            return GroupMetadata.from_dict(zarr_json)
+        case _:
+            raise ValueError("invalid value for `node_type` key in metadata document")
+
+
+def build_metadata_v2(
+    zarr_json: dict[str, Any], attrs_json: dict[str, Any]
+) -> ArrayV2Metadata | GroupMetadata:
+    match zarr_json:
+        case {"shape": _}:
+            return ArrayV2Metadata.from_dict(zarr_json | {"attributes": attrs_json})
+        case _:
+            return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json})
+
+
+def build_node_v3(
+    metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath
+) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
+    """
+    Take a metadata object and return a node (AsyncArray or AsyncGroup).
+    """
+    match metadata:
+        case ArrayV3Metadata():
+            return AsyncArray(metadata, store_path=store_path)
+        case GroupMetadata():
+            return AsyncGroup(metadata, store_path=store_path)
+        case _:
+            raise ValueError(f"Unexpected metadata type: {type(metadata)}")
