99from collections import defaultdict
1010from dataclasses import asdict , dataclass , field , fields , replace
1111from itertools import accumulate
12- from pathlib import PurePosixPath
1312from typing import (
1413 TYPE_CHECKING ,
1514 Literal ,
@@ -2079,7 +2078,9 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group],
20792078
20802079 def create_hierarchy (
20812080 self , nodes : dict [str , ArrayV2Metadata | ArrayV3Metadata | GroupMetadata ]
2082- ) -> dict [str , AsyncGroup | AsyncArray [ArrayV2Metadata ] | AsyncArray [ArrayV3Metadata ]]:
2081+ ) -> Iterator [
2082+ tuple [str , AsyncGroup | AsyncArray [ArrayV2Metadata ] | AsyncArray [ArrayV3Metadata ]]
2083+ ]:
20832084 """
20842085 Create a hierarchy of arrays or groups rooted at this group.
20852086
@@ -2097,6 +2098,14 @@ def create_hierarchy(
20972098 -------
20982099 An iterator of (path, node) tuples, one for each created node
20992100 """
2101+ if "" in nodes :
2102+ msg = (
2103+ "Found the key '' in nodes, which denotes the root group. Creating the root group "
2104+ "from an existing group is not supported. If you want to create an entire Zarr group, "
2105+ "including the root group, from a dict then use the _from_flat method."
2106+ )
2107+ raise ValueError (msg )
2108+
21002109 # check that all the nodes have the same zarr_format as Self
21012110 for key , value in nodes .items ():
21022111 if value .zarr_format != self .metadata .zarr_format :
@@ -2107,12 +2116,8 @@ def create_hierarchy(
21072116 )
21082117 raise ValueError (msg )
21092118 nodes_created = self ._sync_iter (self ._async_group .create_hierarchy (nodes ))
2110- if self .path == "" :
2111- root = "/"
2112- else :
2113- root = self .path
2114- # TODO: make this safe against invalid path inputs
2115- return {str (PurePosixPath (n .name ).relative_to (root )): n for n in nodes_created }
2119+ for n in nodes_created :
2120+ yield (_join_paths ([self .path , n .name ]), n )
21162121
21172122 def keys (self ) -> Generator [str , None ]:
21182123 """Return an iterator over group member names.
@@ -2884,8 +2889,12 @@ async def create_hierarchy(
28842889 The created nodes in the order they are created.
28852890 """
28862891 nodes_parsed = _parse_hierarchy_dict (nodes )
2892+
28872893 if overwrite :
28882894 await store_path .delete_dir ()
2895+ else :
2896+ # TODO: check if any of the nodes already exist, and error if so
2897+ raise NotImplementedError
28892898 async for node in create_nodes (store_path = store_path , nodes = nodes_parsed , semaphore = semaphore ):
28902899 yield node
28912900
@@ -2897,60 +2906,74 @@ async def create_nodes(
28972906 semaphore : asyncio .Semaphore | None = None ,
28982907) -> AsyncIterator [AsyncGroup | AsyncArray [ArrayV2Metadata ] | AsyncArray [ArrayV3Metadata ]]:
28992908 """
2900- Create a collection of zarr v2 arrays and groups concurrently and atomically. To ensure atomicity,
2909+ Create a collection of zarr arrays and groups concurrently and atomically. To ensure atomicity,
29012910 no attempt is made to ensure that intermediate groups are created.
29022911 """
29032912 ctx : asyncio .Semaphore | contextlib .nullcontext [None ]
2913+
29042914 if semaphore is None :
29052915 ctx = contextlib .nullcontext ()
29062916 else :
29072917 ctx = semaphore
29082918
29092919 create_tasks : list [Coroutine [None , None , str ]] = []
29102920 for key , value in nodes .items ():
2911- write_key = f"{ store_path .path } /{ key } " .lstrip ("/" )
2912- create_tasks .extend (_persist_metadata (store_path .store , write_key , value ))
2921+ # transform the key, which is relative to a store_path.path, to a key in the store
2922+ write_prefix = _join_paths ([store_path .path , key ])
2923+ create_tasks .extend (_persist_metadata (store_path .store , write_prefix , value ))
29132924
2914- created_keys = []
2925+ created_object_keys = []
29152926 async with ctx :
29162927 for coro in asyncio .as_completed (create_tasks ):
29172928 created_key = await coro
2918- # the created key will be in the store key space. we have to remove the store_path.path
2929+
2930+ # the created key will be in the store key space, and it will end with the name of
2931+ # a metadata document.
2932+ # we have to remove the store_path.path
29192933 # component of that path to bring it back to the relative key space of store_path
29202934
2921- relative_path = created_key .removeprefix (store_path .path ).lstrip ("/" )
2922- created_keys .append (relative_path )
2935+ # the relative path of the object we just created -- we need this to track which metadata documents
2936+ # were written so that we can yield a complete v2 Array / Group class after both .zattrs
2937+ # and the metadata JSON was created.
2938+ object_path_relative = created_key .removeprefix (store_path .path ).lstrip ("/" )
2939+ created_object_keys .append (object_path_relative )
29232940
2924- if len (relative_path .split ("/" )) == 1 :
2941+ # get the node name from the object key
2942+ if len (object_path_relative .split ("/" )) == 1 :
2943+ # this is the root node
2944+ meta_out = nodes ["" ]
29252945 node_name = ""
29262946 else :
2927- node_name = "/" . join ([ "" , * relative_path . split ( "/" )[: - 1 ]])
2928-
2929- meta_out = nodes [node_name ]
2947+ # turn "foo/<anything>" into "foo"
2948+ node_name = object_path_relative [: object_path_relative . rfind ( "/" )]
2949+ meta_out = nodes [node_name ]
29302950
29312951 if meta_out .zarr_format == 3 :
2952+ # yes, it is silly that we relativize, then de-relativize this same path
2953+ node_store_path = store_path / node_name
29322954 if isinstance (meta_out , GroupMetadata ):
2933- yield AsyncGroup (metadata = meta_out , store_path = store_path / node_name )
2955+ yield AsyncGroup (metadata = meta_out , store_path = node_store_path )
29342956 else :
2935- yield AsyncArray (metadata = meta_out , store_path = store_path / node_name )
2957+ yield AsyncArray (metadata = meta_out , store_path = node_store_path )
29362958 else :
29372959 # For zarr v2
29382960 # we only want to yield when both the metadata and attributes are created
29392961 # so we track which keys have been created, and wait for both the meta key and
29402962 # the attrs key to be created before yielding back the AsyncArray / AsyncGroup
29412963
2942- attrs_done = f" { node_name } /.zattrs" . lstrip ( "/" ) in created_keys
2964+ attrs_done = _join_paths ([ node_name , ZATTRS_JSON ] ) in created_object_keys
29432965
29442966 if isinstance (meta_out , GroupMetadata ):
2945- meta_done = f" { node_name } /.zgroup" . lstrip ( "/" ) in created_keys
2967+ meta_done = _join_paths ([ node_name , ZGROUP_JSON ] ) in created_object_keys
29462968 else :
2947- meta_done = f" { node_name } /.zarray" . lstrip ( "/" ) in created_keys
2969+ meta_done = _join_paths ([ node_name , ZARRAY_JSON ] ) in created_object_keys
29482970
29492971 if meta_done and attrs_done :
2972+ node_store_path = store_path / node_name
29502973 if isinstance (meta_out , GroupMetadata ):
2951- yield AsyncGroup (metadata = meta_out , store_path = store_path / node_name )
2974+ yield AsyncGroup (metadata = meta_out , store_path = node_store_path )
29522975 else :
2953- yield AsyncArray (metadata = meta_out , store_path = store_path / node_name )
2976+ yield AsyncArray (metadata = meta_out , store_path = node_store_path )
29542977 continue
29552978
29562979
@@ -2963,13 +2986,24 @@ def _get_roots(
29632986 """
29642987 Return the keys of the root(s) of the hierarchy
29652988 """
2989+ if "" in data :
2990+ return ("" ,)
29662991 keys_split = sorted ((key .split ("/" ) for key in data ), key = len )
29672992 groups : defaultdict [int , list [str ]] = defaultdict (list )
29682993 for key_split in keys_split :
29692994 groups [len (key_split )].append ("/" .join (key_split ))
29702995 return tuple (groups [min (groups .keys ())])
29712996
29722997
2998+ def _join_paths (paths : Iterable [str ]) -> str :
2999+ """
3000+ Filter out instances of '' and join the remaining strings with '/'.
3001+
3002+ Because the root node of a zarr hierarchy is represented by an empty string,
3003+ """
3004+ return "/" .join (filter (lambda v : v != "" , paths ))
3005+
3006+
29733007def _parse_hierarchy_dict (
29743008 data : Mapping [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ],
29753009) -> dict [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ]:
@@ -2993,7 +3027,7 @@ def _parse_hierarchy_dict(
29933027 # Create a copy of the input dict
29943028 out : dict [str , GroupMetadata | ArrayV2Metadata | ArrayV3Metadata ] = {** data }
29953029
2996- observed_zarr_formats : dict [ZarrFormat , list [str ]] = {2 : [], 3 : []}
3030+ observed_zarr_formats : dict [ZarrFormat , list [str | None ]] = {2 : [], 3 : []}
29973031
29983032 # We will iterate over the dict again, but a full pass here ensures that the error message
29993033 # is comprehensive, and I think the performance cost will be negligible.
@@ -3011,23 +3045,30 @@ def _parse_hierarchy_dict(
30113045 raise ValueError (msg )
30123046
30133047 for k , v in data .items ():
3014- # TODO: ensure that the key is a valid path
3015- # Split the key into its path components
3016- key_split = k .split ("/" )
3017-
3018- # Iterate over the intermediate path components
3019- * subpaths , _ = accumulate (key_split , lambda a , b : f"{ a } /{ b } " )
3020- for subpath in subpaths :
3021- # If a component is not already in the output dict, add a group
3022- if subpath not in out :
3023- out [subpath ] = GroupMetadata (zarr_format = v .zarr_format )
3024- else :
3025- if not isinstance (out [subpath ], GroupMetadata ):
3026- msg = (
3027- f"The node at { subpath } contains other nodes, but it is not a Zarr group. "
3028- "This is invalid. Only Zarr groups can contain other nodes."
3029- )
3030- raise ValueError (msg )
3048+ if k is None :
3049+ # root node
3050+ pass
3051+ else :
3052+ if k .startswith ("/" ):
3053+ msg = f"Keys of hierarchy dicts must be relative paths, i.e. they cannot start with '/'. Got { k } , which violates this rule."
3054+ raise ValueError (k )
3055+ # TODO: ensure that the key is a valid path
3056+ # Split the key into its path components
3057+ key_split = k .split ("/" )
3058+
3059+ # Iterate over the intermediate path components
3060+ * subpaths , _ = accumulate (key_split , lambda a , b : f"{ a } /{ b } " )
3061+ for subpath in subpaths :
3062+ # If a component is not already in the output dict, add a group
3063+ if subpath not in out :
3064+ out [subpath ] = GroupMetadata (zarr_format = v .zarr_format )
3065+ else :
3066+ if not isinstance (out [subpath ], GroupMetadata ):
3067+ msg = (
3068+ f"The node at { subpath } contains other nodes, but it is not a Zarr group. "
3069+ "This is invalid. Only Zarr groups can contain other nodes."
3070+ )
3071+ raise ValueError (msg )
30313072
30323073 return out
30333074
@@ -3258,7 +3299,7 @@ def _persist_metadata(
32583299
32593300 to_save = metadata .to_buffer_dict (default_buffer_prototype ())
32603301 return tuple (
3261- _set_return_key (store = store , key = f" { path } / { key } " . lstrip ( "/" ), value = value , replace = True )
3302+ _set_return_key (store = store , key = _join_paths ([ path , key ] ), value = value , replace = True )
32623303 for key , value in to_save .items ()
32633304 )
32643305
@@ -3278,7 +3319,7 @@ async def _from_flat(
32783319 "The input does not specify a root node. "
32793320 "This function can only create hierarchies that contain a root node, which is "
32803321 "defined as a group that is ancestral to all the other arrays and "
3281- "groups in the hierarchy."
3322+ "groups in the hierarchy, or a single array ."
32823323 )
32833324 raise ValueError (msg )
32843325 else :
@@ -3292,7 +3333,9 @@ async def _from_flat(
32923333 store_path = store_path , nodes = nodes , semaphore = semaphore , overwrite = overwrite
32933334 )
32943335 }
3295- root_group = nodes_created [root ]
3336+ # the names of the created nodes will be relative to the store_path instance
3337+ root_relative_to_store_path = _join_paths ([store_path .path , root ])
3338+ root_group = nodes_created [root_relative_to_store_path ]
32963339 if not isinstance (root_group , AsyncGroup ):
32973340 raise TypeError ("Invalid root node returned from create_hierarchy." )
32983341 return root_group
0 commit comments