Skip to content

Commit d07435b

Browse files
committed
use _join_paths for safer path concatenation
1 parent 036fd2a commit d07435b

File tree

2 files changed

+120
-61
lines changed

2 files changed

+120
-61
lines changed

src/zarr/core/group.py

Lines changed: 90 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from collections import defaultdict
1010
from dataclasses import asdict, dataclass, field, fields, replace
1111
from itertools import accumulate
12-
from pathlib import PurePosixPath
1312
from typing import (
1413
TYPE_CHECKING,
1514
Literal,
@@ -2079,7 +2078,9 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group],
20792078

20802079
def create_hierarchy(
20812080
self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata]
2082-
) -> dict[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
2081+
) -> Iterator[
2082+
tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]
2083+
]:
20832084
"""
20842085
Create a hierarchy of arrays or groups rooted at this group.
20852086
@@ -2097,6 +2098,14 @@ def create_hierarchy(
20972098
-------
20982099
An iterator over (path, node) tuples, one for each node that was created.
20992100
"""
2101+
if "" in nodes:
2102+
msg = (
2103+
"Found the key '' in nodes, which denotes the root group. Creating the root group "
2104+
"from an existing group is not supported. If you want to create an entire Zarr group, "
2105+
"including the root group, from a dict then use the _from_flat method."
2106+
)
2107+
raise ValueError(msg)
2108+
21002109
# check that all the nodes have the same zarr_format as Self
21012110
for key, value in nodes.items():
21022111
if value.zarr_format != self.metadata.zarr_format:
@@ -2107,12 +2116,8 @@ def create_hierarchy(
21072116
)
21082117
raise ValueError(msg)
21092118
nodes_created = self._sync_iter(self._async_group.create_hierarchy(nodes))
2110-
if self.path == "":
2111-
root = "/"
2112-
else:
2113-
root = self.path
2114-
# TODO: make this safe against invalid path inputs
2115-
return {str(PurePosixPath(n.name).relative_to(root)): n for n in nodes_created}
2119+
for n in nodes_created:
2120+
yield (_join_paths([self.path, n.name]), n)
21162121

21172122
def keys(self) -> Generator[str, None]:
21182123
"""Return an iterator over group member names.
@@ -2884,8 +2889,12 @@ async def create_hierarchy(
28842889
The created nodes in the order they are created.
28852890
"""
28862891
nodes_parsed = _parse_hierarchy_dict(nodes)
2892+
28872893
if overwrite:
28882894
await store_path.delete_dir()
2895+
else:
2896+
# TODO: check if any of the nodes already exist, and error if so
2897+
raise NotImplementedError
28892898
async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore):
28902899
yield node
28912900

@@ -2897,60 +2906,74 @@ async def create_nodes(
28972906
semaphore: asyncio.Semaphore | None = None,
28982907
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
28992908
"""
2900-
Create a collection of zarr v2 arrays and groups concurrently and atomically. To ensure atomicity,
2909+
Create a collection of zarr arrays and groups concurrently and atomically. To ensure atomicity,
29012910
no attempt is made to ensure that intermediate groups are created.
29022911
"""
29032912
ctx: asyncio.Semaphore | contextlib.nullcontext[None]
2913+
29042914
if semaphore is None:
29052915
ctx = contextlib.nullcontext()
29062916
else:
29072917
ctx = semaphore
29082918

29092919
create_tasks: list[Coroutine[None, None, str]] = []
29102920
for key, value in nodes.items():
2911-
write_key = f"{store_path.path}/{key}".lstrip("/")
2912-
create_tasks.extend(_persist_metadata(store_path.store, write_key, value))
2921+
# transform the key, which is relative to a store_path.path, to a key in the store
2922+
write_prefix = _join_paths([store_path.path, key])
2923+
create_tasks.extend(_persist_metadata(store_path.store, write_prefix, value))
29132924

2914-
created_keys = []
2925+
created_object_keys = []
29152926
async with ctx:
29162927
for coro in asyncio.as_completed(create_tasks):
29172928
created_key = await coro
2918-
# the created key will be in the store key space. we have to remove the store_path.path
2929+
2930+
# the created key will be in the store key space, and it will end with the name of
2931+
# a metadata document.
2932+
# we have to remove the store_path.path
29192933
# component of that path to bring it back to the relative key space of store_path
29202934

2921-
relative_path = created_key.removeprefix(store_path.path).lstrip("/")
2922-
created_keys.append(relative_path)
2935+
# the relative path of the object we just created -- we need this to track which metadata documents
2936+
# were written so that we can yield a complete v2 Array / Group class after both .zattrs
2937+
# and the metadata JSON have been created.
2938+
object_path_relative = created_key.removeprefix(store_path.path).lstrip("/")
2939+
created_object_keys.append(object_path_relative)
29232940

2924-
if len(relative_path.split("/")) == 1:
2941+
# get the node name from the object key
2942+
if len(object_path_relative.split("/")) == 1:
2943+
# this is the root node
2944+
meta_out = nodes[""]
29252945
node_name = ""
29262946
else:
2927-
node_name = "/".join(["", *relative_path.split("/")[:-1]])
2928-
2929-
meta_out = nodes[node_name]
2947+
# turn "foo/<anything>" into "foo"
2948+
node_name = object_path_relative[: object_path_relative.rfind("/")]
2949+
meta_out = nodes[node_name]
29302950

29312951
if meta_out.zarr_format == 3:
2952+
# yes, it is silly that we relativize, then de-relativize this same path
2953+
node_store_path = store_path / node_name
29322954
if isinstance(meta_out, GroupMetadata):
2933-
yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name)
2955+
yield AsyncGroup(metadata=meta_out, store_path=node_store_path)
29342956
else:
2935-
yield AsyncArray(metadata=meta_out, store_path=store_path / node_name)
2957+
yield AsyncArray(metadata=meta_out, store_path=node_store_path)
29362958
else:
29372959
# For zarr v2
29382960
# we only want to yield when both the metadata and attributes are created
29392961
# so we track which keys have been created, and wait for both the meta key and
29402962
# the attrs key to be created before yielding back the AsyncArray / AsyncGroup
29412963

2942-
attrs_done = f"{node_name}/.zattrs".lstrip("/") in created_keys
2964+
attrs_done = _join_paths([node_name, ZATTRS_JSON]) in created_object_keys
29432965

29442966
if isinstance(meta_out, GroupMetadata):
2945-
meta_done = f"{node_name}/.zgroup".lstrip("/") in created_keys
2967+
meta_done = _join_paths([node_name, ZGROUP_JSON]) in created_object_keys
29462968
else:
2947-
meta_done = f"{node_name}/.zarray".lstrip("/") in created_keys
2969+
meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys
29482970

29492971
if meta_done and attrs_done:
2972+
node_store_path = store_path / node_name
29502973
if isinstance(meta_out, GroupMetadata):
2951-
yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name)
2974+
yield AsyncGroup(metadata=meta_out, store_path=node_store_path)
29522975
else:
2953-
yield AsyncArray(metadata=meta_out, store_path=store_path / node_name)
2976+
yield AsyncArray(metadata=meta_out, store_path=node_store_path)
29542977
continue
29552978

29562979

@@ -2963,13 +2986,24 @@ def _get_roots(
29632986
"""
29642987
Return the keys of the root(s) of the hierarchy
29652988
"""
2989+
if "" in data:
2990+
return ("",)
29662991
keys_split = sorted((key.split("/") for key in data), key=len)
29672992
groups: defaultdict[int, list[str]] = defaultdict(list)
29682993
for key_split in keys_split:
29692994
groups[len(key_split)].append("/".join(key_split))
29702995
return tuple(groups[min(groups.keys())])
29712996

29722997

2998+
def _join_paths(paths: Iterable[str]) -> str:
2999+
"""
3000+
Filter out instances of '' and join the remaining strings with '/'.
3001+
3002+
Because the root node of a zarr hierarchy is represented by an empty string, dropping '' before joining avoids a leading or doubled '/' in the result.
3003+
"""
3004+
return "/".join(filter(lambda v: v != "", paths))
3005+
3006+
29733007
def _parse_hierarchy_dict(
29743008
data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
29753009
) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]:
@@ -2993,7 +3027,7 @@ def _parse_hierarchy_dict(
29933027
# Create a copy of the input dict
29943028
out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data}
29953029

2996-
observed_zarr_formats: dict[ZarrFormat, list[str]] = {2: [], 3: []}
3030+
observed_zarr_formats: dict[ZarrFormat, list[str | None]] = {2: [], 3: []}
29973031

29983032
# We will iterate over the dict again, but a full pass here ensures that the error message
29993033
# is comprehensive, and I think the performance cost will be negligible.
@@ -3011,23 +3045,30 @@ def _parse_hierarchy_dict(
30113045
raise ValueError(msg)
30123046

30133047
for k, v in data.items():
3014-
# TODO: ensure that the key is a valid path
3015-
# Split the key into its path components
3016-
key_split = k.split("/")
3017-
3018-
# Iterate over the intermediate path components
3019-
*subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}")
3020-
for subpath in subpaths:
3021-
# If a component is not already in the output dict, add a group
3022-
if subpath not in out:
3023-
out[subpath] = GroupMetadata(zarr_format=v.zarr_format)
3024-
else:
3025-
if not isinstance(out[subpath], GroupMetadata):
3026-
msg = (
3027-
f"The node at {subpath} contains other nodes, but it is not a Zarr group. "
3028-
"This is invalid. Only Zarr groups can contain other nodes."
3029-
)
3030-
raise ValueError(msg)
3048+
if k is None:
3049+
# root node
3050+
pass
3051+
else:
3052+
if k.startswith("/"):
3053+
msg = f"Keys of hierarchy dicts must be relative paths, i.e. they cannot start with '/'. Got {k}, which violates this rule."
3054+
raise ValueError(k)
3055+
# TODO: ensure that the key is a valid path
3056+
# Split the key into its path components
3057+
key_split = k.split("/")
3058+
3059+
# Iterate over the intermediate path components
3060+
*subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}")
3061+
for subpath in subpaths:
3062+
# If a component is not already in the output dict, add a group
3063+
if subpath not in out:
3064+
out[subpath] = GroupMetadata(zarr_format=v.zarr_format)
3065+
else:
3066+
if not isinstance(out[subpath], GroupMetadata):
3067+
msg = (
3068+
f"The node at {subpath} contains other nodes, but it is not a Zarr group. "
3069+
"This is invalid. Only Zarr groups can contain other nodes."
3070+
)
3071+
raise ValueError(msg)
30313072

30323073
return out
30333074

@@ -3258,7 +3299,7 @@ def _persist_metadata(
32583299

32593300
to_save = metadata.to_buffer_dict(default_buffer_prototype())
32603301
return tuple(
3261-
_set_return_key(store=store, key=f"{path}/{key}".lstrip("/"), value=value, replace=True)
3302+
_set_return_key(store=store, key=_join_paths([path, key]), value=value, replace=True)
32623303
for key, value in to_save.items()
32633304
)
32643305

@@ -3278,7 +3319,7 @@ async def _from_flat(
32783319
"The input does not specify a root node. "
32793320
"This function can only create hierarchies that contain a root node, which is "
32803321
"defined as a group that is ancestral to all the other arrays and "
3281-
"groups in the hierarchy."
3322+
"groups in the hierarchy, or a single array."
32823323
)
32833324
raise ValueError(msg)
32843325
else:
@@ -3292,7 +3333,9 @@ async def _from_flat(
32923333
store_path=store_path, nodes=nodes, semaphore=semaphore, overwrite=overwrite
32933334
)
32943335
}
3295-
root_group = nodes_created[root]
3336+
# the names of the created nodes will be relative to the store_path instance
3337+
root_relative_to_store_path = _join_paths([store_path.path, root])
3338+
root_group = nodes_created[root_relative_to_store_path]
32963339
if not isinstance(root_group, AsyncGroup):
32973340
raise TypeError("Invalid root node returned from create_hierarchy.")
32983341
return root_group

tests/test_group.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ConsolidatedMetadata,
2626
GroupMetadata,
2727
_from_flat,
28+
_join_paths,
2829
create_hierarchy,
2930
create_nodes,
3031
)
@@ -1492,7 +1493,7 @@ async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: Zarr
14921493
expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)}
14931494
spath = await make_store_path(store, path=path)
14941495
# initialize the group with some nodes
1495-
await _collect_aiterator(_from_flat(store_path=spath, nodes=pre_existing_nodes))
1496+
await _from_flat(store_path=spath, nodes=pre_existing_nodes)
14961497
observed_nodes = {
14971498
str(PurePosixPath(a.name).relative_to("/" + path)): a
14981499
async for a in create_hierarchy(store_path=spath, nodes=expected_meta, overwrite=overwrite)
@@ -1501,7 +1502,8 @@ async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: Zarr
15011502

15021503

15031504
@pytest.mark.parametrize("store", ["memory"], indirect=True)
1504-
def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat):
1505+
@pytest.mark.parametrize("overwrite", [True, False])
1506+
def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat, overwrite: bool):
15051507
"""
15061508
Test that the Group.create_hierarchy method creates the specified nodes and yields them.
15071509
"""
@@ -1533,7 +1535,7 @@ def test_group_create_hierarchy_invalid_mixed_zarr_format(store: Store, zarr_for
15331535

15341536
msg = "The zarr_format of the nodes must be the same as the parent group."
15351537
with pytest.raises(ValueError, match=msg):
1536-
_ = g.create_hierarchy(tree)
1538+
_ = tuple(g.create_hierarchy(tree))
15371539

15381540

15391541
@pytest.mark.parametrize("store", ["memory"], indirect=True)
@@ -1588,29 +1590,43 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store):
15881590
)
15891591

15901592

1591-
@pytest.mark.parametrize("store", ["memory"], indirect=True)
1593+
@pytest.mark.parametrize("store", ["memory", "local"], indirect=True)
15921594
@pytest.mark.parametrize("zarr_format", [2, 3])
1593-
@pytest.mark.parametrize("root_key", ["", "a", "a/b"])
1595+
@pytest.mark.parametrize("root_key", ["", "root"])
15941596
@pytest.mark.parametrize("path", ["", "foo"])
15951597
async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: str):
15961598
"""
15971599
Test that the AsyncGroup.from_flat method creates a zarr group in one shot.
15981600
"""
15991601
spath = await make_store_path(store, path=path)
16001602
root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})}
1601-
members_expected_meta = {
1602-
f"{root_key}/b": GroupMetadata(
1603-
zarr_format=zarr_format, attributes={"path": f"{root_key}/b"}
1604-
),
1605-
f"{root_key}/b/c": GroupMetadata(
1606-
zarr_format=zarr_format, attributes={"path": f"{root_key}/b/c"}
1607-
),
1603+
group_names = ["a", "a/b"]
1604+
array_names = ["a/b/c", "a/b/d"]
1605+
1606+
# just to ensure that we don't use the same name twice in tests
1607+
assert set(group_names) & set(array_names) == set()
1608+
1609+
groups_expected_meta = {
1610+
_join_paths([root_key, node_name]): GroupMetadata(
1611+
zarr_format=zarr_format, attributes={"path": node_name}
1612+
)
1613+
for node_name in group_names
1614+
}
1615+
arrays_expected_meta = {
1616+
_join_paths([root_key, node_name]): meta_from_array(np.zeros(4), zarr_format=zarr_format)
1617+
for node_name in array_names
16081618
}
1609-
g = await _from_flat(spath, nodes=root_meta | members_expected_meta)
1619+
1620+
nodes_create = root_meta | groups_expected_meta | arrays_expected_meta
1621+
1622+
g = await _from_flat(spath, nodes=nodes_create, overwrite=True)
1623+
assert g.metadata.attributes == {"path": root_key}
1624+
16101625
members = await _collect_aiterator(g.members(max_depth=None))
16111626
members_observed_meta = {k: v.metadata for k, v in members}
16121627
members_expected_meta_relative = {
1613-
str(PurePosixPath(k).relative_to(root_key)): v for k, v in members_expected_meta.items()
1628+
k.removeprefix(root_key).lstrip("/"): v
1629+
for k, v in (groups_expected_meta | arrays_expected_meta).items()
16141630
}
16151631
assert members_observed_meta == members_expected_meta_relative
16161632

0 commit comments

Comments
 (0)