Skip to content

Commit 787d6bf

Browse files
committed
add path normalization routines
1 parent 036fd2a commit 787d6bf

File tree

2 files changed

+72
-10
lines changed

2 files changed

+72
-10
lines changed

src/zarr/core/group.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from zarr.errors import MetadataValidationError
6363
from zarr.storage import StoreLike, StorePath
6464
from zarr.storage._common import ensure_no_existing_node, make_store_path
65+
from zarr.storage._utils import normalize_path
6566

6667
if TYPE_CHECKING:
6768
from collections.abc import (
@@ -2961,7 +2962,8 @@ def _get_roots(
29612962
data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
29622963
) -> tuple[str, ...]:
29632964
"""
2964-
Return the keys of the root(s) of the hierarchy
2965+
Return the keys of the root(s) of the hierarchy. A root is a key with the fewest number of
2966+
path segments.
29652967
"""
29662968
keys_split = sorted((key.split("/") for key in data), key=len)
29672969
groups: defaultdict[int, list[str]] = defaultdict(list)
@@ -2978,8 +2980,8 @@ def _parse_hierarchy_dict(
29782980
then return an identical copy of that dict. Otherwise, return a version of the input dict
29792981
with groups added where they are needed to make the hierarchy explicit.
29802982
2981-
For example, an input of {'a/b/c': ...} will result in a return value of
2982-
{'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}.
2983+
For example, an input of {'a/b/c': ArrayMetadata} will result in a return value of
2984+
{'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}.
29832985
29842986
The input is also checked for the following conditions, and an error is raised if any
29852987
of them are violated:
@@ -2990,8 +2992,6 @@ def _parse_hierarchy_dict(
29902992
This function ensures that the input is transformed into a specification of a complete and valid
29912993
Zarr hierarchy.
29922994
"""
2993-
# Create a copy of the input dict
2994-
out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data}
29952995

29962996
observed_zarr_formats: dict[ZarrFormat, list[str]] = {2: [], 3: []}
29972997

@@ -3007,16 +3007,15 @@ def _parse_hierarchy_dict(
30073007
f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}."
30083008
"Ensure that all nodes have the same Zarr format."
30093009
)
3010-
30113010
raise ValueError(msg)
30123011

3012+
out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {}
3013+
30133014
for k, v in data.items():
30143015
# TODO: ensure that the key is a valid path
3015-
# Split the key into its path components
3016-
key_split = k.split("/")
30173016

3018-
# Iterate over the intermediate path components
3019-
*subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}")
3017+
*subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b]))
3018+
30203019
for subpath in subpaths:
30213020
# If a component is not already in the output dict, add a group
30223021
if subpath not in out:
@@ -3032,6 +3031,35 @@ def _parse_hierarchy_dict(
30323031
return out
30333032

30343033

3034+
def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]:
3035+
"""
3036+
Normalize the input paths according to the normalization scheme used for zarr node paths.
3037+
If any two paths normalize to the same value, raise a ValueError.
3038+
"""
3039+
path_map: dict[str, str] = {}
3040+
for path in paths:
3041+
parsed = normalize_path(path)
3042+
if parsed in path_map:
3043+
msg = (
3044+
f"After normalization, the value '{path}' collides with '{path_map[parsed]}'. "
3045+
f"Both '{path}' and '{path_map[parsed]}' normalize to the same value: '{parsed}'. "
3046+
f"You should use either '{path}' or '{path_map[parsed]}', but not both."
3047+
)
3048+
raise ValueError(msg)
3049+
path_map[parsed] = path
3050+
return tuple(path_map.keys())
3051+
3052+
3053+
def _normalize_path_keys(data: dict[str, T]) -> dict[str, T]:
3054+
"""
3055+
Normalize the keys of the input dict according to the normalization scheme used for zarr node
3056+
paths. If any two keys in the input normalize to the value, raise a ValueError. Return the
3057+
values of data with the normalized keys.
3058+
"""
3059+
parsed_keys = _normalize_paths(data.keys())
3060+
return dict(zip(parsed_keys, data.values(), strict=False))
3061+
3062+
30353063
async def _getitem_semaphore(
30363064
node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None
30373065
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup:

tests/test_group.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@
2525
ConsolidatedMetadata,
2626
GroupMetadata,
2727
_from_flat,
28+
_normalize_path_keys,
29+
_normalize_paths,
2830
create_hierarchy,
2931
create_nodes,
3032
)
3133
from zarr.core.sync import _collect_aiterator, sync
3234
from zarr.errors import ContainsArrayError, ContainsGroupError
3335
from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore
3436
from zarr.storage._common import make_store_path
37+
from zarr.storage._utils import normalize_path
3538
from zarr.testing.store import LatencyStore
3639

3740
from .conftest import meta_from_array, parse_store
@@ -1615,6 +1618,37 @@ async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: s
16151618
assert members_observed_meta == members_expected_meta_relative
16161619

16171620

1621+
@pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")])
1622+
def test_normalize_paths_invalid(paths: tuple[str, str]):
1623+
"""
1624+
Ensure that calling _normalize_paths on values that will normalize to the same value
1625+
will generate a ValueError.
1626+
"""
1627+
a, b = paths
1628+
msg = f"After normalization, the value '{b}' collides with '{a}'. "
1629+
with pytest.raises(ValueError, match=msg):
1630+
_normalize_paths(paths)
1631+
1632+
1633+
@pytest.mark.parametrize(
1634+
"paths", [("/a", "a/b"), ("a", "a/b"), ("a/", "a///b"), ("/a/", "//a/b///")]
1635+
)
1636+
def test_normalize_paths_valid(paths: tuple[str, str]):
1637+
"""
1638+
Ensure that calling _normalize_paths on values that normalize to distinct values
1639+
returns a tuple of those normalized values.
1640+
"""
1641+
expected = tuple(map(normalize_path, paths))
1642+
assert _normalize_paths(paths) == expected
1643+
1644+
1645+
def test_normalize_path_keys():
1646+
data = {"": 10, "a": "hello", "a/b": None, "/a/b/c/d": None}
1647+
observed = _normalize_path_keys(data)
1648+
expected = {normalize_path(k): v for k, v in data.items()}
1649+
assert observed == expected
1650+
1651+
16181652
@pytest.mark.parametrize("store", ["memory"], indirect=True)
16191653
def test_group_members_performance(store: Store) -> None:
16201654
"""

0 commit comments

Comments
 (0)