1919from zarr .abc .store import Store , set_or_delete
2020from zarr .core .array import Array , AsyncArray , _build_parents
2121from zarr .core .attributes import Attributes
22- from zarr .core .buffer import default_buffer_prototype , Buffer
22+ from zarr .core .buffer import Buffer , default_buffer_prototype
2323from zarr .core .common import (
2424 JSON ,
2525 ZARR_JSON ,
@@ -645,12 +645,10 @@ async def getitem(
645645 raise KeyError (key )
646646 else :
647647 zarr_json = json .loads (zarr_json_bytes .to_bytes ())
648- if zarr_json ["node_type" ] == "group" :
649- return type (self ).from_dict (store_path , zarr_json )
650- elif zarr_json ["node_type" ] == "array" :
651- return AsyncArray .from_dict (store_path , zarr_json )
652- else :
653- raise ValueError (f"unexpected node_type: { zarr_json ['node_type' ]} " )
648+ metadata = build_metadata_v3 (zarr_json )
649+ node = build_node_v3 (metadata , store_path )
650+ return node
651+
654652 elif self .metadata .zarr_format == 2 :
655653 # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
656654 # This guarantees that we will always make at least one extra request to the store
@@ -1154,74 +1152,14 @@ async def members(
11541152 async for item in self ._members (max_depth = max_depth ):
11551153 yield item
11561154
    async def _members_old(
        self, max_depth: int | None, current_depth: int
    ) -> AsyncGenerator[
        tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
        None,
    ]:
        """
        Yield ``(key, node)`` pairs for the members of this group.

        When consolidated metadata is present, members are read from it with no
        extra I/O; otherwise the store is listed via ``list_dir``. Subgroups are
        recursed into while ``current_depth < max_depth`` (or without limit when
        ``max_depth`` is None).

        Raises
        ------
        ValueError
            If the underlying store does not support listing.
        """
        if self.metadata.consolidated_metadata is not None:
            # we should be able to do members without any additional I/O
            members = self._members_consolidated(max_depth)
            for member in members:
                yield member
            return

        if not self.store_path.store.supports_listing:
            msg = (
                f"The store associated with this group ({type(self.store_path.store)}) "
                "does not support listing, "
                "specifically via the `list_dir` method. "
                "This function requires a store that supports listing."
            )

            raise ValueError(msg)
        # would be nice to make these special keys accessible programmatically,
        # and scoped to specific zarr versions
        # especially true for `.zmetadata` which is configurable
        _skip_keys = ("zarr.json", ".zgroup", ".zattrs", ".zmetadata")

        # hmm lots of I/O and logic interleaved here.
        # We *could* have an async gen over self.metadata.consolidated_metadata.metadata.keys()
        # and plug in here. `getitem` will skip I/O.
        # Kinda a shame to have all the asyncio task overhead though, when it isn't needed.

        async for key in self.store_path.store.list_dir(self.store_path.path):
            if key in _skip_keys:
                continue
            try:
                obj = await self.getitem(key)
                yield (key, obj)

                if (
                    ((max_depth is None) or (current_depth < max_depth))
                    and hasattr(obj.metadata, "node_type")
                    and obj.metadata.node_type == "group"
                ):
                    # the assert is just for mypy to know that `obj.metadata.node_type`
                    # implies an AsyncGroup, not an AsyncArray
                    assert isinstance(obj, AsyncGroup)
                    # NOTE(review): recursion delegates to `_members`, not
                    # `_members_old`, so `current_depth` is not advanced here —
                    # presumably intentional after the refactor; confirm.
                    async for child_key, val in obj._members(max_depth=max_depth):
                        yield f"{key}/{child_key}", val
            except KeyError:
                # keyerror is raised when `key` names an object (in the object storage sense),
                # as opposed to a prefix, in the store under the prefix associated with this group
                # in which case `key` cannot be the name of a sub-array or sub-group.
                warnings.warn(
                    f"Object at {key} is not recognized as a component of a Zarr hierarchy.",
                    UserWarning,
                    stacklevel=1,
                )
12171155 def _members_consolidated (
12181156 self , max_depth : int | None , prefix : str = ""
12191157 ) -> Generator [
12201158 tuple [str , AsyncArray [ArrayV2Metadata ] | AsyncArray [ArrayV3Metadata ] | AsyncGroup ],
12211159 None ,
12221160 ]:
12231161 consolidated_metadata = self .metadata .consolidated_metadata
1224-
1162+
12251163 do_recursion = max_depth is None or max_depth > 0
12261164
12271165 # we kind of just want the top-level keys.
@@ -1233,23 +1171,23 @@ def _members_consolidated(
12331171 key = f"{ prefix } /{ key } " .lstrip ("/" )
12341172 yield key , obj
12351173
1236- if do_recursion and isinstance (
1237- obj , AsyncGroup
1238- ):
1174+ if do_recursion and isinstance (obj , AsyncGroup ):
12391175 if max_depth is None :
1240- new_depth = None
1176+ new_depth = None
12411177 else :
12421178 new_depth = max_depth - 1
12431179 yield from obj ._members_consolidated (new_depth , prefix = key )
1244-
1180+
12451181 async def _members (
1246- self ,
1247- max_depth : int | None ) -> AsyncGenerator [tuple [str , AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ] | AsyncGroup ], None ]:
1182+ self , max_depth : int | None
1183+ ) -> AsyncGenerator [
1184+ tuple [str , AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ] | AsyncGroup ], None
1185+ ]:
12481186 skip_keys : tuple [str , ...]
12491187 if self .metadata .zarr_format == 2 :
1250- skip_keys = (' .zattrs' , ' .zgroup' , ' .zarray' , ' .zmetadata' )
1188+ skip_keys = (" .zattrs" , " .zgroup" , " .zarray" , " .zmetadata" )
12511189 elif self .metadata .zarr_format == 3 :
1252- skip_keys = (' zarr.json' ,)
1190+ skip_keys = (" zarr.json" ,)
12531191 else :
12541192 raise ValueError (f"Unknown Zarr format: { self .metadata .zarr_format } " )
12551193
@@ -1268,7 +1206,9 @@ async def _members(
12681206 )
12691207
12701208 raise ValueError (msg )
1271- async for member in iter_members_deep (self , max_depth = max_depth , prefix = self .basename , skip_keys = skip_keys ):
1209+ async for member in iter_members_deep (
1210+ self , max_depth = max_depth , prefix = self .basename , skip_keys = skip_keys
1211+ ):
12721212 yield member
12731213
12741214 async def keys (self ) -> AsyncGenerator [str , None ]:
@@ -1913,31 +1853,31 @@ async def members_recursive(
19131853 key_body = "/" .join (key .split ("/" )[:- 1 ])
19141854
19151855 if blob is not None :
1916- resolved_metadata = resolve_metadata_v3 (blob .to_bytes ())
1856+ resolved_metadata = build_metadata_v3 (blob .to_bytes ())
19171857 members_flat += ((key_body , resolved_metadata ),)
19181858 if isinstance (resolved_metadata , GroupMetadata ):
1919- to_recurse .append (
1920- members_recursive (store , key_body ))
1859+ to_recurse .append (members_recursive (store , key_body ))
19211860
19221861 subgroups = await asyncio .gather (* to_recurse )
19231862 members_flat += tuple (subgroup for subgroup in subgroups )
19241863
19251864 return members_flat
19261865
1866+
19271867async def iter_members (
1928- node : AsyncGroup ,
1929- skip_keys : tuple [str , ...]
1930- ) -> AsyncGenerator [tuple [str , AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ] | AsyncGroup ], None ]:
1868+ node : AsyncGroup , skip_keys : tuple [str , ...]
1869+ ) -> AsyncGenerator [
1870+ tuple [str , AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ] | AsyncGroup ], None
1871+ ]:
19311872 """
19321873 Iterate over the arrays and groups contained in a group.
19331874 """
1934-
1875+
19351876 # retrieve keys from storage
19361877 keys = [key async for key in node .store .list_dir (node .path )]
19371878 keys_filtered = tuple (filter (lambda v : v not in skip_keys , keys ))
19381879
1939- node_tasks = tuple (asyncio .create_task (
1940- node .getitem (key ), name = key ) for key in keys_filtered )
1880+ node_tasks = tuple (asyncio .create_task (node .getitem (key ), name = key ) for key in keys_filtered )
19411881
19421882 for fetched_node_coro in asyncio .as_completed (node_tasks ):
19431883 try :
@@ -1958,15 +1898,14 @@ async def iter_members(
19581898 case _:
19591899 raise ValueError (f"Unexpected type: { type (fetched_node )} " )
19601900
1901+
19611902async def iter_members_deep (
1962- group : AsyncGroup ,
1963- * ,
1964- prefix : str ,
1965- max_depth : int | None ,
1966- skip_keys : tuple [str , ...]
1967- ) -> AsyncGenerator [tuple [str , AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ] | AsyncGroup ], None ]:
1903+ group : AsyncGroup , * , prefix : str , max_depth : int | None , skip_keys : tuple [str , ...]
1904+ ) -> AsyncGenerator [
1905+ tuple [str , AsyncArray [ArrayV3Metadata ] | AsyncArray [ArrayV2Metadata ] | AsyncGroup ], None
1906+ ]:
19681907 """
1969- Iterate over the arrays and groups contained in a group, and optionally the
1908+ Iterate over the arrays and groups contained in a group, and optionally the
19701909 arrays and groups contained in those groups.
19711910 """
19721911
@@ -1978,34 +1917,65 @@ async def iter_members_deep(
19781917 new_depth = max_depth - 1
19791918
19801919 async for name , node in iter_members (group , skip_keys = skip_keys ):
1981- yield f' { prefix } /{ name } ' .lstrip ('/' ), node
1920+ yield f" { prefix } /{ name } " .lstrip ("/" ), node
19821921 if isinstance (node , AsyncGroup ) and do_recursion :
1983- to_recurse .append (iter_members_deep (
1984- node ,
1985- max_depth = new_depth ,
1986- prefix = f' { prefix } / { name } ' ,
1987- skip_keys = skip_keys ) )
1922+ to_recurse .append (
1923+ iter_members_deep (
1924+ node , max_depth = new_depth , prefix = f" { prefix } / { name } " , skip_keys = skip_keys
1925+ )
1926+ )
19881927
19891928 for subgroup in to_recurse :
19901929 async for name , node in subgroup :
19911930 yield name , node
1992-
19931931
def resolve_metadata_v2(
    blobs: tuple[str | bytes | bytearray, str | bytes | bytearray],
) -> ArrayV2Metadata | GroupMetadata:
    """
    Resolve a pair of Zarr v2 metadata documents into a metadata object.

    Parameters
    ----------
    blobs : tuple
        Two JSON documents: the ``.zarray`` / ``.zgroup`` text first, the
        ``.zattrs`` text second.

    Returns
    -------
    ArrayV2Metadata | GroupMetadata
        Array metadata when the first document declares a ``shape`` key
        (only arrays carry ``shape`` in Zarr v2); group metadata otherwise.
    """
    zarr_metadata = json.loads(blobs[0])
    attrs = json.loads(blobs[1])
    # Merge the attributes under the "attributes" key, consistent with
    # build_metadata_v2 — the previous "attrs" key is not a field the
    # metadata models recognize, so attributes were silently dropped.
    if "shape" in zarr_metadata:
        return ArrayV2Metadata.from_dict(zarr_metadata | {"attributes": attrs})
    else:
        return GroupMetadata.from_dict(zarr_metadata | {"attributes": attrs})
1942+
20011943
def build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata:
    """
    Convert a parsed ``zarr.json`` document into the matching metadata object.

    Parameters
    ----------
    zarr_json : dict[str, Any]
        The parsed JSON metadata document.

    Returns
    -------
    ArrayV3Metadata | GroupMetadata

    Raises
    ------
    KeyError
        If the document has no ``node_type`` key.
    ValueError
        If ``node_type`` is neither ``"array"`` nor ``"group"``.
    """
    if "node_type" not in zarr_json:
        raise KeyError("missing `node_type` key in metadata document.")
    node_type = zarr_json["node_type"]
    if node_type == "array":
        return ArrayV3Metadata.from_dict(zarr_json)
    if node_type == "group":
        return GroupMetadata.from_dict(zarr_json)
    raise ValueError("invalid value for `node_type` key in metadata document")
1957+
1958+
def build_metadata_v2(
    zarr_json: dict[str, Any], attrs_json: dict[str, Any]
) -> ArrayV2Metadata | GroupMetadata:
    """
    Convert parsed Zarr v2 metadata and attributes documents into a metadata object.

    A document containing a ``shape`` key is treated as array metadata;
    anything else becomes group metadata. The attributes document is merged
    in under the ``"attributes"`` key in both cases.
    """
    merged = zarr_json | {"attributes": attrs_json}
    if "shape" in zarr_json:
        return ArrayV2Metadata.from_dict(merged)
    return GroupMetadata.from_dict(merged)
1967+
1968+
def build_node_v3(
    metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath
) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
    """
    Wrap a v3 metadata object in the corresponding node type.

    Returns an ``AsyncArray`` for array metadata and an ``AsyncGroup`` for
    group metadata, bound to ``store_path``.

    Raises
    ------
    ValueError
        If ``metadata`` is neither array nor group metadata.
    """
    if isinstance(metadata, ArrayV3Metadata):
        return AsyncArray(metadata, store_path=store_path)
    if isinstance(metadata, GroupMetadata):
        return AsyncGroup(metadata, store_path=store_path)
    raise ValueError(f"Unexpected metadata type: {type(metadata)}")
0 commit comments