11from __future__ import annotations
22
33import asyncio
4+ import contextlib
45import itertools
56import json
67import logging
78import warnings
89from collections import defaultdict
910from dataclasses import asdict , dataclass , field , fields , replace
10- from functools import partial
1111from itertools import accumulate
1212from pathlib import PurePosixPath
13- from typing import TYPE_CHECKING , Literal , Self , TypeVar , assert_never , cast , overload
13+ from typing import (
14+ TYPE_CHECKING ,
15+ Literal ,
16+ Self ,
17+ TypeVar ,
18+ assert_never ,
19+ cast ,
20+ overload ,
21+ )
1422
1523import numpy as np
1624import numpy .typing as npt
5159from zarr .core .config import config
5260from zarr .core .metadata import ArrayV2Metadata , ArrayV3Metadata
5361from zarr .core .metadata .v3 import V3JsonEncoder
54- from zarr .core .sync import SyncMixin , _with_semaphore , sync
62+ from zarr .core .sync import SyncMixin , sync
5563from zarr .errors import MetadataValidationError
5664from zarr .storage import StoreLike , StorePath
5765from zarr .storage ._common import ensure_no_existing_node , make_store_path
6068 from collections .abc import (
6169 AsyncGenerator ,
6270 AsyncIterator ,
71+ Coroutine ,
6372 Generator ,
6473 Iterable ,
6574 Iterator ,
@@ -431,21 +440,32 @@ class AsyncGroup:
431440 @classmethod
432441 async def from_flat (
433442 cls ,
434- store : StoreLike ,
435- * ,
436- nodes : dict [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ],
437- overwrite : bool = False ) -> Self :
438-
443+ store : StoreLike ,
444+ * ,
445+ nodes : dict [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ],
446+ overwrite : bool = False ,
447+ ) -> Self :
448+ if not _is_rooted (nodes ):
449+ msg = (
450+ "The input does not specify a root node. " ,
451+ "This function can only create hierarchies that contain a root node, which is " ,
452+ "defined as a group that is ancestral to all the other arrays and " ,
453+ "groups in the hierarchy." ,
454+ )
455+ raise ValueError (msg )
456+
439457 if overwrite :
440- store_path = await make_store_path (store , mode = 'w' )
458+ store_path = await make_store_path (store , mode = "w" )
441459 else :
442- store_path = await make_store_path (store , mode = 'w-' )
460+ store_path = await make_store_path (store , mode = "w-" )
461+
443462 semaphore = asyncio .Semaphore (config .get ("async.concurrency" ))
444-
445- nodes_created = {x .name : x async for x in create_hierarchy (
446- store_path = store_path , nodes = nodes , semaphore = semaphore
447- )}
448- return nodes_created ['' ]
463+
464+ nodes_created = {
465+ x .name : x
466+ async for x in create_hierarchy (store_path = store_path , nodes = nodes , semaphore = semaphore )
467+ }
468+ # TODO: make this work
449469
450470 @classmethod
451471 async def from_store (
@@ -1743,17 +1763,18 @@ async def move(self, source: str, dest: str) -> None:
17431763@dataclass (frozen = True )
17441764class Group (SyncMixin ):
17451765 _async_group : AsyncGroup
1746-
1766+
17471767 @classmethod
17481768 def from_flat (
1749- cls ,
1769+ cls ,
17501770 store : StoreLike ,
1751- * ,
1771+ * ,
17521772 nodes : dict [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ],
1753- overwrite : bool = False ) -> Group :
1773+ overwrite : bool = False ,
1774+ ) -> Group :
17541775 nodes = sync (AsyncGroup .from_flat (store , nodes = nodes , overwrite = overwrite ))
17551776 # return the root node of the hierarchy
1756- return nodes ['' ]
1777+ return nodes ["" ]
17571778
17581779 @classmethod
17591780 def from_store (
@@ -2110,17 +2131,28 @@ def create_hierarchy(
21102131
21112132 Parameters
21122133 ----------
2113- nodes : A dictionary representing the hierarchy to create
2134+ nodes : A dictionary representing the hierarchy to create. The keys should be relative paths
2135+ and the values should be the metadata for the arrays or groups to create.
21142136
21152137 Returns
21162138 -------
2117- A dict containing the created nodes.The keys are the same as th
2118- """
2139+ A dict containing the created nodes, with the same keys as the input
2140+ """
2141+ # check that all the nodes have the same zarr_format as Self
2142+ for key , value in nodes .items ():
2143+ if value .zarr_format != self .metadata .zarr_format :
2144+ msg = (
2145+ "The zarr_format of the nodes must be the same as the parent group. "
2146+ f"The node at { key } has zarr_format { value .zarr_format } , but the parent group"
2147+ f" has zarr_format { self .metadata .zarr_format } ."
2148+ )
2149+ raise ValueError (msg )
21192150 nodes_created = self ._sync_iter (self ._async_group .create_hierarchy (nodes ))
21202151 if self .path == "" :
21212152 root = "/"
21222153 else :
21232154 root = self .path
2155+ # TODO: make this safe against invalid path inputs
21242156 return {str (PurePosixPath (n .name ).relative_to (root )): n for n in nodes_created }
21252157
21262158 def keys (self ) -> Generator [str , None ]:
@@ -2859,6 +2891,7 @@ def array(
28592891
28602892async def _save_metadata (
28612893 node : AsyncArray [Any ] | AsyncGroup ,
2894+ overwrite : bool ,
28622895) -> AsyncArray [Any ] | AsyncGroup :
28632896 """
28642897 Save the metadata for an array or group, and return the array or group
@@ -2878,6 +2911,7 @@ async def create_hierarchy(
28782911 store_path : StorePath ,
28792912 nodes : dict [str , GroupMetadata | ArrayV3Metadata | ArrayV2Metadata ],
28802913 semaphore : asyncio .Semaphore | None = None ,
2914+ overwrite : bool = False ,
28812915) -> AsyncIterator [AsyncGroup | AsyncArray [ArrayV2Metadata ] | AsyncArray [ArrayV3Metadata ]]:
28822916 """
28832917 Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input
@@ -2906,42 +2940,81 @@ async def create_hierarchy(
29062940 The created nodes in the order they are created.
29072941 """
29082942 nodes_parsed = _parse_hierarchy_dict (nodes )
2943+
29092944 async for node in create_nodes (store_path = store_path , nodes = nodes_parsed , semaphore = semaphore ):
29102945 yield node
29112946
29122947
async def create_nodes(
    *,
    store_path: StorePath,
    nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
    semaphore: asyncio.Semaphore | None = None,
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
    """
    Create a collection of arrays and groups concurrently.

    The metadata documents of every node are written to storage concurrently, and a node is
    yielded once all of its documents have been written. No attempt is made to ensure that
    intermediate groups are created.

    Parameters
    ----------
    store_path : StorePath
        The location that the keys of ``nodes`` are relative to.
    nodes : dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
        A mapping from relative paths to the metadata of the node to create at that path.
    semaphore : asyncio.Semaphore | None
        If provided, each individual write acquires this semaphore, which bounds the number
        of writes that are in flight at the same time.

    Yields
    ------
    AsyncGroup | AsyncArray
        The created nodes, in the order in which their metadata finished saving.
    """

    async def _save(node_key: str, coro: Coroutine[None, None, str]) -> str:
        # The semaphore is held per write, and released before anything is yielded to the
        # caller. The node key is returned so that a finished write can be attributed to its
        # node without parsing the storage key.
        if semaphore is None:
            await coro
        else:
            async with semaphore:
                await coro
        return node_key

    save_tasks: list[Coroutine[None, None, str]] = []
    # Number of metadata documents that still have to be written for each node.
    pending: dict[str, int] = {}

    for key, value in nodes.items():
        # Skip empty components so that an empty store path or an empty (root) key does not
        # introduce a leading, trailing or doubled "/" into the path handed to the helper.
        node_path = "/".join(part for part in (store_path.path, key) if part)
        coros = _prepare_save_metadata(store_path.store, node_path, value)
        pending[key] = len(coros)
        save_tasks.extend(_save(key, coro) for coro in coros)

    for finished in asyncio.as_completed(save_tasks):
        node_key = await finished
        pending[node_key] -= 1
        if pending[node_key] > 0:
            # Some documents of this node are still being written; the node is only yielded
            # once all of them exist.
            continue

        meta_out = nodes[node_key]
        node_store_path = store_path / node_key
        if isinstance(meta_out, GroupMetadata):
            yield AsyncGroup(metadata=meta_out, store_path=node_store_path)
        else:
            yield AsyncArray(metadata=meta_out, store_path=node_store_path)
29403004
29413005
29423006T = TypeVar ("T" )
29433007
29443008
3009+ def _is_rooted (data : dict [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ]) -> bool :
3010+ """
3011+ Check if the data describes a hierarchy that's rooted, which means there is a single node with
3012+ the least number of components in its key
3013+ """
3014+ # a dict
3015+ return False
3016+
3017+
29453018def _parse_hierarchy_dict (
29463019 data : Mapping [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ],
29473020) -> dict [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ]:
@@ -2953,19 +3026,54 @@ def _parse_hierarchy_dict(
29533026 For example, an input of {'a/b/c': ...} will result in a return value of
29543027 {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}.
29553028
2956- This function is useful for ensuring that the input to create_hierarchy is a complete
3029+ The input is also checked for the following conditions, and an error is raised if any
3030+ of them are violated:
3031+
3032+ - No arrays can contain group or arrays (i.e., all arrays must be leaf nodes).
3033+ - All arrays and groups must have the same ``zarr_format`` value.
3034+
3035+ This function ensures that the input is transformed into a specification of a complete and valid
29573036 Zarr hierarchy.
29583037 """
29593038 # Create a copy of the input dict
29603039 out : dict [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ] = {** data }
3040+
3041+ observed_zarr_formats : dict [ZarrFormat , list [str ]] = {2 : [], 3 : []}
3042+
3043+ # We will iterate over the dict again, but a full pass here ensures that the error message
3044+ # is comprehensive, and I think the performance cost will be negligible.
3045+ for k , v in data .items ():
3046+ observed_zarr_formats [v .zarr_format ].append (k )
3047+
3048+ if len (observed_zarr_formats [2 ]) > 0 and len (observed_zarr_formats [3 ]) > 0 :
3049+ msg = (
3050+ "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. "
3051+ f"The following keys map to Zarr v2 nodes: { observed_zarr_formats .get (2 )} . "
3052+ f"The following keys map to Zarr v3 nodes: { observed_zarr_formats .get (3 )} ."
3053+ "Ensure that all nodes have the same Zarr format."
3054+ )
3055+
3056+ raise ValueError (msg )
3057+
29613058 for k , v in data .items ():
3059+ # TODO: ensure that the key is a valid path
29623060 # Split the key into its path components
29633061 key_split = k .split ("/" )
2964- # Iterate over the path components
2965- for subpath in accumulate (key_split , lambda a , b : f"{ a } /{ b } " ):
3062+
3063+ # Iterate over the intermediate path components
3064+ * subpaths , _ = accumulate (key_split , lambda a , b : f"{ a } /{ b } " )
3065+ for subpath in subpaths :
29663066 # If a component is not already in the output dict, add it
29673067 if subpath not in out :
29683068 out [subpath ] = GroupMetadata (zarr_format = v .zarr_format )
3069+ else :
3070+ if not isinstance (out [subpath ], GroupMetadata ):
3071+ msg = (
3072+ f"The node at { subpath } contains other nodes, but it is not a Zarr group. "
3073+ "This is invalid. Only Zarr groups can contain other nodes."
3074+ )
3075+ raise ValueError (msg )
3076+
29693077 return out
29703078
29713079
@@ -3155,3 +3263,24 @@ def _build_node_v2(
31553263 return AsyncGroup (metadata , store_path = store_path )
31563264 case _:
31573265 raise ValueError (f"Unexpected metadata type: { type (metadata )} " )
3266+
3267+
3268+ async def _set_return_key (store : Store , key : str , value : Buffer ) -> str :
3269+ """
3270+ Store.set, but the key and the value are returned.
3271+ Useful when saving metadata via asyncio.as_completed, because
3272+ we need to know which key was saved.
3273+ """
3274+ await store .set (key , value )
3275+ return key
3276+
3277+
def _prepare_save_metadata(
    store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata
) -> tuple[Coroutine[None, None, str], ...]:
    """
    Prepare to save a metadata document to storage.

    Parameters
    ----------
    store : Store
        The store that the metadata documents will be written to.
    path : str
        The path of the node that the metadata belongs to. Leading and trailing "/" are
        ignored, and an empty path places the documents at the top level.
    metadata : ArrayV2Metadata | ArrayV3Metadata | GroupMetadata
        The metadata to save.

    Returns
    -------
    tuple[Coroutine[None, None, str], ...]
        One coroutine per entry of ``metadata.to_buffer_dict``. Nothing is written until the
        coroutines are awaited; each one resolves to the key that it wrote.
    """
    to_save = metadata.to_buffer_dict(default_buffer_prototype())

    # Only prepend the node path when there is one, so that the resulting key never starts
    # with "/" and never contains a doubled "/".
    prefix = path.strip("/")
    return tuple(
        _set_return_key(store, f"{prefix}/{key}" if prefix else key, value)
        for key, value in to_save.items()
    )
0 commit comments