Skip to content

Commit 8245e80

Browse files
committed
fix group.create_hierarchy to properly prefix keys with the name of the group
1 parent 7c56b87 commit 8245e80

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

src/zarr/core/group.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,9 +1414,11 @@ async def create_hierarchy(
14141414
14151415
Returns
14161416
-------
1417-
An asynchronous iterator over the created arrays and / or groups.
1417+
An asynchronous iterator of (str, AsyncArray | AsyncGroup) pairs.
14181418
"""
14191419
# check that all the nodes have the same zarr_format as Self
1420+
prefix = self.path
1421+
nodes_parsed = {}
14201422
for key, value in nodes.items():
14211423
if value.zarr_format != self.metadata.zarr_format:
14221424
msg = (
@@ -1433,15 +1435,19 @@ async def create_hierarchy(
14331435
"create_rooted_hierarchy to create a rooted hierarchy."
14341436
)
14351437
raise ValueError(msg)
1436-
1437-
nodes_rooted = nodes
1438+
else:
1439+
nodes_parsed[_join_paths([prefix, key])] = value
14381440

14391441
async for key, node in create_hierarchy(
14401442
store=self.store,
1441-
nodes=nodes_rooted,
1443+
nodes=nodes_parsed,
14421444
overwrite=overwrite,
14431445
):
1444-
yield key, node
1446+
if prefix == "":
1447+
out_key = key
1448+
else:
1449+
out_key = key.removeprefix(prefix + "/")
1450+
yield out_key, node
14451451

14461452
async def keys(self) -> AsyncGenerator[str, None]:
14471453
"""Iterate over member names."""

tests/test_group.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,17 +1645,24 @@ async def test_create_hierarchy_existing_nodes(
16451645

16461646
@pytest.mark.parametrize("store", ["memory"], indirect=True)
16471647
@pytest.mark.parametrize("overwrite", [True, False])
1648+
@pytest.mark.parametrize("group_path", ["", "foo"])
16481649
@pytest.mark.parametrize("impl", ["async", "sync"])
16491650
async def test_group_create_hierarchy(
1650-
store: Store, zarr_format: ZarrFormat, overwrite: bool, impl: Literal["async", "sync"]
1651+
store: Store,
1652+
zarr_format: ZarrFormat,
1653+
overwrite: bool,
1654+
group_path: str,
1655+
impl: Literal["async", "sync"],
16511656
) -> None:
16521657
"""
16531658
Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict.
16541659
Also test that off-target nodes are not deleted, and that the root group is not deleted
16551660
"""
16561661
root_attrs = {"root": True}
1657-
g = Group.from_store(store, zarr_format=zarr_format, attributes=root_attrs)
1658-
1662+
g = sync_group.create_rooted_hierarchy(
1663+
store=store,
1664+
nodes={group_path: GroupMetadata(zarr_format=zarr_format, attributes=root_attrs)},
1665+
)
16591666
node_spec = {
16601667
"a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}),
16611668
"a/b": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a/b"}),
@@ -1689,7 +1696,10 @@ async def test_group_create_hierarchy(
16891696
else:
16901697
assert all_members[k].metadata == v == extant_created[k].metadata
16911698
# ensure that we left the root group as-is
1692-
assert sync_group.get_node(store=store, path="", zarr_format=zarr_format).attrs == root_attrs
1699+
assert (
1700+
sync_group.get_node(store=store, path=group_path, zarr_format=zarr_format).attrs.asdict()
1701+
== root_attrs
1702+
)
16931703

16941704

16951705
@pytest.mark.parametrize("store", ["memory"], indirect=True)

0 commit comments

Comments
 (0)