Skip to content

Commit 50b02b4

Browse files
committed
ensure we always have a root group
1 parent e74445b commit 50b02b4

File tree

2 files changed

+130
-31
lines changed

2 files changed

+130
-31
lines changed

src/zarr/core/group.py

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,8 +1434,7 @@ async def create_hierarchy(
14341434
)
14351435
raise ValueError(msg)
14361436

1437-
# insert ImplicitGroupMetadata to represent self
1438-
nodes_rooted = nodes | {"": ImplicitGroupMarker(zarr_format=self.metadata.zarr_format)}
1437+
nodes_rooted = nodes
14391438

14401439
async for key, node in create_hierarchy(
14411440
store=self.store,
@@ -2913,20 +2912,43 @@ async def create_hierarchy(
29132912

29142913
# empty hierarchies should be a no-op
29152914
if len(nodes_parsed) > 0:
2915+
# figure out which zarr format we are using
2916+
zarr_format = next(iter(nodes_parsed.values())).zarr_format
2917+
2918+
# check which implicit groups will require materialization
2919+
implicit_group_keys = tuple(
2920+
filter(lambda k: isinstance(nodes_parsed[k], ImplicitGroupMarker), nodes_parsed)
2921+
)
2922+
# read potential group metadata for each implicit group
2923+
maybe_extant_group_coros = (
2924+
_read_group_metadata(store, k, zarr_format=zarr_format) for k in implicit_group_keys
2925+
)
2926+
maybe_extant_groups = await asyncio.gather(
2927+
*maybe_extant_group_coros, return_exceptions=True
2928+
)
2929+
2930+
for key, value in zip(implicit_group_keys, maybe_extant_groups, strict=True):
2931+
if isinstance(value, BaseException):
2932+
if isinstance(value, FileNotFoundError):
2933+
# this is fine -- there was no group there, so we will create one
2934+
pass
2935+
else:
2936+
raise value
2937+
else:
2938+
# a loop exists already at ``key``, so we can avoid creating anything there
2939+
redundant_implicit_groups.append(key)
2940+
29162941
if overwrite:
2917-
# only remove elements from the store if they would be overwritten by nodes
2918-
should_delete_keys = (
2919-
k for k, v in nodes_parsed.items() if not isinstance(v, ImplicitGroupMarker)
2920-
)
2921-
await asyncio.gather(
2922-
*(store.delete_dir(key) for key in should_delete_keys), return_exceptions=True
2942+
# we will remove any nodes that collide with arrays and non-implicit groups defined in
2943+
# nodes
2944+
2945+
# track the keys of nodes we need to delete
2946+
to_delete_keys = []
2947+
to_delete_keys.extend(
2948+
[k for k, v in nodes_parsed.items() if k not in implicit_group_keys]
29232949
)
2950+
await asyncio.gather(*(store.delete_dir(key) for key in to_delete_keys))
29242951
else:
2925-
# attempt to fetch all of the metadata described in hierarchy
2926-
# first figure out which zarr format we are dealing with
2927-
sample, *_ = nodes_parsed.values()
2928-
2929-
zarr_format = sample.zarr_format
29302952
# This type is long.
29312953
coros: (
29322954
Generator[Coroutine[Any, Any, ArrayV2Metadata | GroupMetadata], None, None]
@@ -3084,7 +3106,7 @@ def _parse_hierarchy_dict(
30843106
data: Mapping[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
30853107
) -> dict[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]:
30863108
"""
3087-
Take an input Mapping of str: node pairs, and parse it into
3109+
Take an input with type Mapping[str, ArrayMetadata | GroupMetadata] and parse it into
30883110
a dict of str: node pairs that models a valid, complete Zarr hierarchy.
30893111
30903112
If the input represents a complete Zarr hierarchy, i.e. one with no implicit groups,
@@ -3093,36 +3115,39 @@ def _parse_hierarchy_dict(
30933115
Otherwise, return a dict derived from the input with GroupMetadata inserted as needed to make
30943116
the hierarchy complete.
30953117
3096-
For example, an input of {'a/b/c': ArrayMetadata} is incomplete, because it references two
3097-
groups ('a' and 'a/b') that are not specified in the input. Applying this function
3118+
For example, an input of {'a/b': ArrayMetadata} is incomplete, because it references two
3119+
groups (the root group '' and a group at 'a') that are not specified in the input. Applying this function
30983120
to that input will result in a return value of
3099-
{'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}, i.e. the implied groups
3121+
{'': GroupMetadata, 'a': GroupMetadata, 'a/b': ArrayMetadata}, i.e. the implied groups
31003122
were added.
31013123
31023124
The input is also checked for the following conditions; an error is raised if any are violated:
31033125
31043126
- No arrays can contain group or arrays (i.e., all arrays must be leaf nodes).
31053127
- All arrays and groups must have the same ``zarr_format`` value.
31063128
3107-
if ``allow_root`` is set to False, then the input is also checked to ensure that it does not
3108-
contain a key that normalizes to the empty string (''), as this is reserved for the root node,
3109-
and in some situations creating a root node is not permitted, for example, when creating a
3110-
hierarchy relative to an existing group.
3111-
31123129
This function ensures that the input is transformed into a specification of a complete and valid
31133130
Zarr hierarchy.
31143131
"""
31153132

3133+
# ensure that all nodes have the same zarr format
31163134
data_purified = _ensure_consistent_zarr_format(data)
31173135

3136+
# ensure that keys are normalized to zarr paths
31183137
data_normed_keys = _normalize_path_keys(data_purified)
31193138

3120-
out: dict[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {
3121-
**data_normed_keys
3122-
}
3139+
# insert an implicit root group if a root was not specified
3140+
# but not if an empty dict was provided, because any empty hierarchy has no nodes
3141+
if len(data_normed_keys) > 0 and "" not in data_normed_keys:
3142+
z_format = next(iter(data_normed_keys.values())).zarr_format
3143+
data_normed_keys = data_normed_keys | {"": ImplicitGroupMarker(zarr_format=z_format)}
31233144

3124-
for k, v in data.items():
3145+
out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data_normed_keys}
3146+
3147+
for k, v in data_normed_keys.items():
31253148
key_split = k.split("/")
3149+
3150+
# get every parent path
31263151
*subpaths, _ = accumulate(key_split, lambda a, b: _join_paths([a, b]))
31273152

31283153
for subpath in subpaths:
@@ -3136,7 +3161,6 @@ def _parse_hierarchy_dict(
31363161
"This is invalid. Only Zarr groups can contain other nodes."
31373162
)
31383163
raise ValueError(msg)
3139-
31403164
return out
31413165

31423166

@@ -3338,6 +3362,34 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM
33383362
return _build_metadata_v2(zmeta, zattrs)
33393363

33403364

3365+
async def _read_group_metadata_v2(store: Store, path: str) -> GroupMetadata:
3366+
"""
3367+
Read group metadata or error
3368+
"""
3369+
meta = await _read_metadata_v2(store=store, path=path)
3370+
if not isinstance(meta, GroupMetadata):
3371+
raise FileNotFoundError(f"Group metadata was not found in {store} at {path}")
3372+
return meta
3373+
3374+
3375+
async def _read_group_metadata_v3(store: Store, path: str) -> GroupMetadata:
3376+
"""
3377+
Read group metadata or error
3378+
"""
3379+
meta = await _read_metadata_v3(store=store, path=path)
3380+
if not isinstance(meta, GroupMetadata):
3381+
raise FileNotFoundError(f"Group metadata was not found in {store} at {path}")
3382+
return meta
3383+
3384+
3385+
async def _read_group_metadata(
3386+
store: Store, path: str, *, zarr_format: ZarrFormat
3387+
) -> GroupMetadata:
3388+
if zarr_format == 2:
3389+
return await _read_group_metadata_v2(store=store, path=path)
3390+
return await _read_group_metadata_v3(store=store, path=path)
3391+
3392+
33413393
def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMetadata:
33423394
"""
33433395
Convert a dict representation of Zarr V3 metadata into the corresponding metadata class.

tests/test_group.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
from zarr.core.group import (
2727
ConsolidatedMetadata,
2828
GroupMetadata,
29+
ImplicitGroupMarker,
2930
_build_metadata_v3,
3031
_get_roots,
32+
_parse_hierarchy_dict,
3133
create_hierarchy,
3234
create_nodes,
3335
create_rooted_hierarchy,
@@ -38,7 +40,7 @@
3840
from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError
3941
from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore
4042
from zarr.storage._common import make_store_path
41-
from zarr.storage._utils import _join_paths
43+
from zarr.storage._utils import _join_paths, normalize_path
4244
from zarr.testing.store import LatencyStore
4345

4446
from .conftest import meta_from_array, parse_store
@@ -1651,7 +1653,8 @@ async def test_group_create_hierarchy(
16511653
Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict.
16521654
Also test that off-target nodes are not deleted, and that the root group is not deleted
16531655
"""
1654-
g = Group.from_store(store, zarr_format=zarr_format)
1656+
root_attrs = {"root": True}
1657+
g = Group.from_store(store, zarr_format=zarr_format, attributes=root_attrs)
16551658

16561659
node_spec = {
16571660
"a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}),
@@ -1683,10 +1686,10 @@ async def test_group_create_hierarchy(
16831686
for k, v in extant_spec.items():
16841687
if overwrite:
16851688
assert k in all_members
1686-
# check that we did not erase the root group
1687-
assert sync_group.get_node(store=store, path="", zarr_format=zarr_format) == g
16881689
else:
16891690
assert all_members[k].metadata == v == extant_created[k].metadata
1691+
# ensure that we left the root group as-is
1692+
assert sync_group.get_node(store=store, path="", zarr_format=zarr_format).attrs == root_attrs
16901693

16911694

16921695
@pytest.mark.parametrize("store", ["memory"], indirect=True)
@@ -1707,6 +1710,50 @@ def test_group_create_hierarchy_no_root(
17071710
_ = dict(g.create_hierarchy(tree, overwrite=overwrite))
17081711

17091712

1713+
class TestParseHierarchyDict:
1714+
"""
1715+
Tests for the function that parses dicts of str : Metadata pairs, ensuring that the output models a
1716+
valid Zarr hierarchy
1717+
"""
1718+
1719+
@staticmethod
1720+
def test_normed_keys() -> None:
1721+
"""
1722+
Test that keys get normalized properly
1723+
"""
1724+
1725+
nodes = {
1726+
"a": GroupMetadata(),
1727+
"/b": GroupMetadata(),
1728+
"": GroupMetadata(),
1729+
"/a//c////": GroupMetadata(),
1730+
}
1731+
observed = _parse_hierarchy_dict(data=nodes)
1732+
expected = {normalize_path(k): v for k, v in nodes.items()}
1733+
assert observed == expected
1734+
1735+
@staticmethod
1736+
def test_empty() -> None:
1737+
"""
1738+
Test that an empty dict passes through
1739+
"""
1740+
assert _parse_hierarchy_dict(data={}) == {}
1741+
1742+
@staticmethod
1743+
def test_implicit_groups() -> None:
1744+
"""
1745+
Test that implicit groups were added as needed.
1746+
"""
1747+
requested = {"a/b/c": GroupMetadata()}
1748+
expected = requested | {
1749+
"": ImplicitGroupMarker(),
1750+
"a": ImplicitGroupMarker(),
1751+
"a/b": ImplicitGroupMarker(),
1752+
}
1753+
observed = _parse_hierarchy_dict(data=requested)
1754+
assert observed == expected
1755+
1756+
17101757
@pytest.mark.parametrize("store", ["memory"], indirect=True)
17111758
def test_group_create_hierarchy_invalid_mixed_zarr_format(
17121759
store: Store, zarr_format: ZarrFormat

0 commit comments

Comments
 (0)