66import logging
77import warnings
88from collections import defaultdict
9- from collections .abc import AsyncIterator , Awaitable
109from dataclasses import asdict , dataclass , field , fields , replace
1110from functools import partial
11+ from itertools import accumulate
1212from typing import TYPE_CHECKING , Literal , TypeVar , assert_never , cast , overload
1313
1414import numpy as np
5050from zarr .core .config import config
5151from zarr .core .metadata import ArrayV2Metadata , ArrayV3Metadata
5252from zarr .core .metadata .v3 import V3JsonEncoder
53- from zarr .core .sync import SyncMixin , sync
53+ from zarr .core .sync import SyncMixin , _with_semaphore , sync
5454from zarr .errors import MetadataValidationError
5555from zarr .storage import StoreLike , StorePath , make_store_path
5656from zarr .storage ._common import ensure_no_existing_node
5757
5858if TYPE_CHECKING :
59- from collections .abc import AsyncGenerator , Callable , Generator , Iterable , Iterator
59+ from collections .abc import (
60+ AsyncGenerator ,
61+ AsyncIterator ,
62+ Generator ,
63+ Iterable ,
64+ Iterator ,
65+ Mapping ,
66+ )
6067 from typing import Any
6168
6269 from zarr .core .array_spec import ArrayConfig , ArrayConfigLike
@@ -1265,36 +1272,14 @@ async def require_array(
12651272
12661273 return ds
12671274
async def _create_nodes(
    self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
    """
    Create a set of arrays or groups rooted at this group.

    This is an async generator that delegates to ``create_hierarchy``, passing
    this group's ``store_path`` as the root, and re-yields every node that
    ``create_hierarchy`` produces.

    Parameters
    ----------
    nodes : dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
        Mapping from node path to the metadata document of the node to create.

    Yields
    ------
    AsyncGroup | AsyncArray
        Each node, in the order ``create_hierarchy`` yields it.
    """
    # No semaphore is forwarded here, so ``create_hierarchy`` runs with its
    # own default for that argument.
    async for node in create_hierarchy(store_path=self.store_path, nodes=nodes):
        yield node
12981283
12991284 async def update_attributes (self , new_attributes : dict [str , Any ]) -> AsyncGroup :
13001285 """Update group attributes.
@@ -2818,15 +2803,6 @@ def array(
28182803 )
28192804
28202805
2821- async def _with_semaphore (
2822- func : Callable [[Any ], Awaitable [T ]], semaphore : asyncio .Semaphore | None = None
2823- ) -> T :
2824- if semaphore is None :
2825- return await func (None )
2826- async with semaphore :
2827- return await func (None )
2828-
2829-
28302806async def _save_metadata (
28312807 node : AsyncArray [Any ] | AsyncGroup ,
28322808) -> AsyncArray [Any ] | AsyncGroup :
@@ -2843,6 +2819,43 @@ async def _save_metadata(
28432819 return node
28442820
28452821
async def create_hierarchy(
    *,
    store_path: StorePath,
    nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata],
    semaphore: asyncio.Semaphore | None = None,
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
    """
    Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input
    ``nodes`` will be created as needed.

    The input is first passed through ``parse_hierarchy_dict``, which adds an explicit
    ``GroupMetadata`` entry for every missing intermediate path. The resulting dict is then
    handed to ``create_nodes``, and each node it produces is re-yielded unchanged.

    Parameters
    ----------
    store_path : StorePath
        The StorePath object pointing to the root of the hierarchy.
    nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata]
        A dictionary defining the hierarchy. The keys are the paths of the nodes
        in the hierarchy, and the values are the metadata of the nodes. The
        metadata must be either an instance of GroupMetadata, ArrayV3Metadata
        or ArrayV2Metadata.
    semaphore : asyncio.Semaphore | None
        An optional semaphore to limit the number of concurrent tasks. If not
        provided, the number of concurrent tasks is not limited. It is forwarded
        as-is to ``create_nodes``.

    Yields
    ------
    AsyncGroup | AsyncArray
        The created nodes in the order they are created.
    """
    # Make implicit parent groups explicit before any node is written.
    nodes_parsed = parse_hierarchy_dict(nodes)
    async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore):
        yield node
2857+
2858+
28462859async def create_nodes (
28472860 * ,
28482861 store_path : StorePath ,
@@ -2875,28 +2888,28 @@ async def create_nodes(
28752888T = TypeVar ("T" )
28762889
28772890
def parse_hierarchy_dict(
    data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]:
    """
    If the input represents a complete Zarr hierarchy, i.e. one with no implicit groups,
    then return an identical copy of that dict. Otherwise, return a version of the input dict
    with groups added where they are needed to make the hierarchy explicit.

    For example, an input of {'a/b/c': ...} will result in a return value of
    {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}.

    This function is useful for ensuring that the input to create_hierarchy is a complete
    Zarr hierarchy.

    Parameters
    ----------
    data : Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
        Mapping from "/"-separated node path to node metadata. It is not mutated.

    Returns
    -------
    dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
        A new dict holding every entry of ``data`` plus one ``GroupMetadata`` for each
        intermediate path that was missing. An added group takes its ``zarr_format``
        from the node whose key introduced it.

    Notes
    -----
    Only the *absence* of a parent path is handled. An existing parent entry is left
    untouched and is not checked to be a group.
    """
    # Work on a shallow copy so the caller's mapping is left unchanged.
    out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data}
    for k, v in data.items():
        # Split the key into its path components.
        key_split = k.split("/")
        # ``accumulate`` yields every cumulative prefix of the key in turn:
        # "a", "a/b", "a/b/c". The last one is the key itself, which is already in ``out``.
        for subpath in accumulate(key_split, lambda a, b: f"{a}/{b}"):
            # A prefix with no entry is an implicit group; make it explicit.
            if subpath not in out:
                out[subpath] = GroupMetadata(zarr_format=v.zarr_format)
    return out
0 commit comments