Skip to content

Commit e2cff8c

Browse files
committed
group-level create_hierarchy
1 parent cf72834 commit e2cff8c

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

src/zarr/core/group.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dataclasses import asdict, dataclass, field, fields, replace
1010
from functools import partial
1111
from itertools import accumulate
12+
from pathlib import PurePosixPath
1213
from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload
1314

1415
import numpy as np
@@ -1424,6 +1425,33 @@ async def _members(
14241425
):
14251426
yield member
14261427

1428+
# TODO: find a better name for this. create_tree could work.
1429+
# TODO: include an example in the docstring
1430+
async def create_hierarchy(
    self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata]
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
    """
    Create a hierarchy of arrays or groups rooted at this group.

    Each key of ``nodes`` names an array or group to create, and each value is the
    metadata object describing that node. The work is delegated to the
    ``create_hierarchy`` function in module scope, which is called with this
    group's store path.

    Parameters
    ----------
    nodes : dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata]
        A dictionary representing the hierarchy to create.

    Yields
    ------
    AsyncGroup | AsyncArray
        Each node produced by the underlying ``create_hierarchy`` call, in the
        order that call yields them.
    """
    # One semaphore, sized by the "async.concurrency" config value, is handed to
    # the helper for the whole operation.
    limit = config.get("async.concurrency")
    created = create_hierarchy(
        store_path=self.store_path,
        nodes=nodes,
        semaphore=asyncio.Semaphore(limit),
    )
    async for item in created:
        yield item
1454+
14271455
async def keys(self) -> AsyncGenerator[str, None]:
14281456
"""Iterate over member names."""
14291457
async for key, _ in self.members():
@@ -2046,6 +2074,32 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group],
20462074

20472075
return tuple((kv[0], _parse_async_node(kv[1])) for kv in _members)
20482076

2077+
def create_hierarchy(
    self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata]
) -> dict[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
    """
    Create a hierarchy of arrays or groups rooted at this group.

    This method takes a dictionary where the keys are the names of the arrays or groups
    to create and the values are the metadata objects for the arrays or groups.

    The method returns a dictionary of the created nodes.

    Parameters
    ----------
    nodes : dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata]
        A dictionary representing the hierarchy to create.

    Returns
    -------
    dict[str, AsyncGroup | AsyncArray]
        A dict containing the created nodes. Each key is the name of the created
        node expressed relative to this group.

    Raises
    ------
    ValueError
        If the name of a created node does not lie under this group's path.
    """
    nodes_created = self._sync_iter(self._async_group.create_hierarchy(nodes))
    # The root-group case has always used "/" as the anchor, which only works when
    # node names are absolute ("/a/b"). Anchor non-root groups the same way so that
    # a path stored without a leading slash ("foo") does not make ``relative_to``
    # raise ValueError by mixing an absolute name with a relative root.
    # NOTE(review): assumes every node ``.name`` starts with "/" -- confirm.
    root = PurePosixPath("/" + self.path.lstrip("/"))
    return {str(PurePosixPath(n.name).relative_to(root)): n for n in nodes_created}
2102+
20492103
def keys(self) -> Generator[str, None]:
20502104
"""Return an iterator over group member names.
20512105

tests/test_group.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,6 +1488,24 @@ async def test_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None:
14881488
assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()}
14891489

14901490

1491+
@pytest.mark.parametrize("store", ["memory"], indirect=True)
def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None:
    """
    Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict.
    """
    # NOTE(review): the root group is created with the default format while the
    # nodes below use ``zarr_format`` -- confirm that the mix is intentional.
    g = Group.from_store(store)
    tree = {
        "a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}),
        "a/b": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a/b"}),
        "a/b/c": meta_from_array(
            np.zeros(5), zarr_format=zarr_format, attributes={"name": "a/b/c"}
        ),
    }
    nodes = g.create_hierarchy(tree)
    # Without this check the loop below would pass vacuously if the returned dict
    # were empty or missing some of the requested nodes.
    missing = tree.keys() - nodes.keys()
    assert not missing, f"nodes not returned by create_hierarchy: {sorted(missing)}"
    for k, v in nodes.items():
        assert v.metadata == tree[k]
1507+
1508+
14911509
def test_group_members_performance(store: MemoryStore) -> None:
14921510
"""
14931511
Test that the execution time of Group.members is less than the number of members times the

0 commit comments

Comments
 (0)