Skip to content

Commit 57ceb64

Browse files
committed
tests and proper implementation for create_nodes and create_hierarchy
1 parent b6bf2dd commit 57ceb64

File tree

4 files changed

+338
-241
lines changed

4 files changed

+338
-241
lines changed

src/zarr/core/group.py

Lines changed: 72 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import logging
77
import warnings
88
from collections import defaultdict
9-
from collections.abc import AsyncIterator, Awaitable
109
from dataclasses import asdict, dataclass, field, fields, replace
1110
from functools import partial
11+
from itertools import accumulate
1212
from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload
1313

1414
import numpy as np
@@ -50,13 +50,20 @@
5050
from zarr.core.config import config
5151
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
5252
from zarr.core.metadata.v3 import V3JsonEncoder
53-
from zarr.core.sync import SyncMixin, sync
53+
from zarr.core.sync import SyncMixin, _with_semaphore, sync
5454
from zarr.errors import MetadataValidationError
5555
from zarr.storage import StoreLike, StorePath, make_store_path
5656
from zarr.storage._common import ensure_no_existing_node
5757

5858
if TYPE_CHECKING:
59-
from collections.abc import AsyncGenerator, Callable, Generator, Iterable, Iterator
59+
from collections.abc import (
60+
AsyncGenerator,
61+
AsyncIterator,
62+
Generator,
63+
Iterable,
64+
Iterator,
65+
Mapping,
66+
)
6067
from typing import Any
6168

6269
from zarr.core.array_spec import ArrayConfig, ArrayConfigLike
@@ -1265,36 +1272,14 @@ async def require_array(
12651272

12661273
return ds
12671274

1268-
async def create_nodes(
1275+
async def _create_nodes(
12691276
self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
1270-
) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]]:
1277+
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
12711278
"""
12721279
Create a set of arrays or groups rooted at this group.
12731280
"""
1274-
_nodes: (
1275-
dict[str, GroupMetadata | ArrayV3Metadata] | dict[str, GroupMetadata | ArrayV2Metadata]
1276-
)
1277-
match self.metadata.zarr_format:
1278-
case 2:
1279-
if not all(
1280-
isinstance(node, ArrayV2Metadata | GroupMetadata) for node in nodes.values()
1281-
):
1282-
raise ValueError("Only v2 arrays and groups are supported")
1283-
_nodes = cast(dict[str, ArrayV2Metadata | GroupMetadata], nodes)
1284-
return await create_nodes_v2(
1285-
store=self.store_path.store, path=self.path, nodes=_nodes
1286-
)
1287-
case 3:
1288-
if not all(
1289-
isinstance(node, ArrayV3Metadata | GroupMetadata) for node in nodes.values()
1290-
):
1291-
raise ValueError("Only v3 arrays and groups are supported")
1292-
_nodes = cast(dict[str, ArrayV3Metadata | GroupMetadata], nodes)
1293-
return await create_nodes_v3(
1294-
store=self.store_path.store, path=self.path, nodes=_nodes
1295-
)
1296-
case _:
1297-
raise ValueError(f"Unsupported zarr format: {self.metadata.zarr_format}")
1281+
async for node in create_hierarchy(store_path=self.store_path, nodes=nodes):
1282+
yield node
12981283

12991284
async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
13001285
"""Update group attributes.
@@ -2818,15 +2803,6 @@ def array(
28182803
)
28192804

28202805

2821-
async def _with_semaphore(
2822-
func: Callable[[Any], Awaitable[T]], semaphore: asyncio.Semaphore | None = None
2823-
) -> T:
2824-
if semaphore is None:
2825-
return await func(None)
2826-
async with semaphore:
2827-
return await func(None)
2828-
2829-
28302806
async def _save_metadata(
28312807
node: AsyncArray[Any] | AsyncGroup,
28322808
) -> AsyncArray[Any] | AsyncGroup:
@@ -2843,6 +2819,43 @@ async def _save_metadata(
28432819
return node
28442820

28452821

2822+
async def create_hierarchy(
2823+
*,
2824+
store_path: StorePath,
2825+
nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata],
2826+
semaphore: asyncio.Semaphore | None = None,
2827+
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
2828+
"""
2829+
Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input
2830+
``nodes`` will be created as needed.
2831+
2832+
This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy
2833+
concurrently. The groups and arrays in the hierarchy are created in a single pass, and the
2834+
function yields the created nodes in the order they are created.
2835+
2836+
Parameters
2837+
----------
2838+
store_path : StorePath
2839+
The StorePath object pointing to the root of the hierarchy.
2840+
nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata]
2841+
A dictionary defining the hierarchy. The keys are the paths of the nodes
2842+
in the hierarchy, and the values are the metadata of the nodes. The
2843+
metadata must be either an instance of GroupMetadata, ArrayV3Metadata
2844+
or ArrayV2Metadata.
2845+
semaphore : asyncio.Semaphore | None
2846+
An optional semaphore to limit the number of concurrent tasks. If not
2847+
provided, the number of concurrent tasks is not limited.
2848+
2849+
Yields
2850+
------
2851+
AsyncGroup | AsyncArray
2852+
The created nodes in the order they are created.
2853+
"""
2854+
nodes_parsed = parse_hierarchy_dict(nodes)
2855+
async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore):
2856+
yield node
2857+
2858+
28462859
async def create_nodes(
28472860
*,
28482861
store_path: StorePath,
@@ -2875,28 +2888,28 @@ async def create_nodes(
28752888
T = TypeVar("T")
28762889

28772890

2878-
def _split_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]:
2891+
def parse_hierarchy_dict(
2892+
data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
2893+
) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]:
28792894
"""
2880-
Given a dict of {string: T} pairs, where the keys are strings separated by some separator,
2881-
return the result of splitting each key with the separator.
2882-
2883-
Parameters
2884-
----------
2885-
data : dict[str, T]
2886-
A dict of {string:, T} pairs.
2887-
2888-
Returns
2889-
-------
2890-
dict[tuple[str,...], T]
2891-
The same values, but the keys have been split and converted to tuples.
2895+
If the input represents a complete Zarr hierarchy, i.e. one with no implicit groups,
2896+
then return an identical copy of that dict. Otherwise, return a version of the input dict
2897+
with groups added where they are needed to make the hierarchy explicit.
28922898
2893-
Examples
2894-
--------
2895-
>>> _split_keys({"a": 1}, separator='/')
2896-
{("a",): 1}
2899+
For example, an input of {'a/b/c': ...} will result in a return value of
2900+
{'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}.
28972901
2898-
>>> _split_keys({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/')
2899-
{("a", "b"): 1, ("a", "b", "c"): 2, ("c",): 3}
2902+
This function is useful for ensuring that the input to create_hierarchy is a complete
2903+
Zarr hierarchy.
29002904
"""
2901-
2902-
return {tuple(k.split(separator)): v for k, v in data.items()}
2905+
# Create a copy of the input dict
2906+
out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data}
2907+
for k, v in data.items():
2908+
# Split the key into its path components
2909+
key_split = k.split("/")
2910+
# Iterate over the path components
2911+
for subpath in accumulate(key_split, lambda a, b: f"{a}/{b}"):
2912+
# If a component is not already in the output dict, add it
2913+
if subpath not in out:
2914+
out[subpath] = GroupMetadata(zarr_format=v.zarr_format)
2915+
return out

src/zarr/core/sync.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import logging
66
import threading
77
from concurrent.futures import ThreadPoolExecutor, wait
8-
from typing import TYPE_CHECKING, TypeVar
8+
from typing import TYPE_CHECKING, Any, TypeVar
99

1010
from typing_extensions import ParamSpec
1111

1212
from zarr.core.config import config
1313

1414
if TYPE_CHECKING:
15-
from collections.abc import AsyncIterator, Coroutine
15+
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
1616
from typing import Any
1717

1818
logger = logging.getLogger(__name__)
@@ -192,3 +192,17 @@ async def iter_to_list() -> list[T]:
192192
return [item async for item in async_iterator]
193193

194194
return self._sync(iter_to_list())
195+
196+
197+
async def _with_semaphore(
198+
func: Callable[[], Awaitable[T]], semaphore: asyncio.Semaphore | None = None
199+
) -> T:
200+
"""
201+
Await the result of invoking the no-argument-callable ``func`` within the context manager
202+
provided by a Semaphore, if one is provided. Otherwise, just await the result of invoking
203+
``func``.
204+
"""
205+
if semaphore is None:
206+
return await func()
207+
async with semaphore:
208+
return await func()

0 commit comments

Comments
 (0)