Skip to content

Commit 6507e43

Browse files
committed
refactor sync / async functions, and make tests more compact accordingly
1 parent 5282534 commit 6507e43

File tree

2 files changed

+297
-180
lines changed

2 files changed

+297
-180
lines changed

src/zarr/core/group.py

Lines changed: 133 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
from zarr.core.config import config
5858
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
5959
from zarr.core.metadata.v3 import V3JsonEncoder
60-
from zarr.core.sync import SyncMixin, sync
60+
from zarr.core.sync import SyncMixin, _collect_aiterator, sync
6161
from zarr.errors import (
6262
ContainsArrayError,
6363
ContainsGroupError,
@@ -1422,7 +1422,7 @@ async def create_hierarchy(
14221422
An asynchronous iterator over the created arrays and / or groups.
14231423
"""
14241424
semaphore = asyncio.Semaphore(config.get("async.concurrency"))
1425-
async for node in create_hierarchy(
1425+
async for node in create_hierarchy_a(
14261426
store=self.store,
14271427
path=self.path,
14281428
nodes=nodes,
@@ -2837,7 +2837,7 @@ def array(
28372837
)
28382838

28392839

2840-
async def create_hierarchy(
2840+
async def create_hierarchy_a(
28412841
*,
28422842
store: Store,
28432843
path: str,
@@ -2848,11 +2848,10 @@ async def create_hierarchy(
28482848
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
28492849
"""
28502850
Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input
2851-
``nodes`` will be created as needed.
2851+
will be created as needed.
28522852
28532853
This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy
2854-
concurrently. The groups and arrays in the hierarchy are created in a single pass, and the
2855-
function yields the created nodes in the order they are created.
2854+
concurrently. AsyncArrays and AsyncGroups are yielded in the order they are created.
28562855
28572856
Parameters
28582857
----------
@@ -2958,11 +2957,58 @@ async def create_hierarchy(
29582957
k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups
29592958
}
29602959

2961-
async for node in create_nodes(store=store, path=path, nodes=nodes_parsed, semaphore=semaphore):
2960+
async for node in create_nodes_a(
2961+
store=store, path=path, nodes=nodes_parsed, semaphore=semaphore
2962+
):
29622963
yield node
29632964

29642965

2965-
async def create_nodes(
2966+
def create_hierarchy(
2967+
store: Store,
2968+
path: str,
2969+
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
2970+
overwrite: bool = False,
2971+
allow_root: bool = True,
2972+
) -> Iterator[Group | Array]:
2973+
"""
2974+
Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input
2975+
will be created as needed.
2976+
2977+
This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy
2978+
concurrently. Arrays and Groups are yielded in the order they are created.
2979+
2980+
Parameters
2981+
----------
2982+
store : Store
2983+
The storage backend to use.
2984+
path : str
2985+
The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with
2986+
``path`` prior to creating nodes.
2987+
nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata]
2988+
A dictionary defining the hierarchy. The keys are the paths of the nodes
2989+
in the hierarchy, and the values are the metadata of the nodes. The
2990+
metadata must be either an instance of GroupMetadata, ArrayV3Metadata
2991+
or ArrayV2Metadata.
2992+
allow_root : bool
2993+
Whether to allow a root node to be created. If ``False``, attempting to create a root node
2994+
will result in an error. Use this option when calling this function as part of a method
2995+
defined on ``AsyncGroup`` instances, because in this case the root node has already been
2996+
created.
2997+
2998+
Yields
2999+
------
3000+
Group | Array
3001+
The created nodes in the order they are created.
3002+
"""
3003+
coro = create_hierarchy_a(
3004+
store=store, path=path, nodes=nodes, overwrite=overwrite, allow_root=allow_root
3005+
)
3006+
3007+
for result in sync(_collect_aiterator(coro)):
3008+
yield _parse_async_node(result)
3009+
3010+
3011+
async def create_nodes_a(
29663012
*,
29673013
store: Store,
29683014
path: str,
@@ -3056,14 +3102,53 @@ async def create_nodes(
30563102
meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys
30573103

30583104
if meta_done and attrs_done:
3059-
node_store_path = StorePath(store=store, path=path) / node_name
3060-
if isinstance(meta_out, GroupMetadata):
3061-
yield AsyncGroup(metadata=meta_out, store_path=node_store_path)
3062-
else:
3063-
yield AsyncArray(metadata=meta_out, store_path=node_store_path)
3105+
yield _build_node(
3106+
store=store, path=_join_paths([path, node_name]), metadata=meta_out
3107+
)
3108+
30643109
continue
30653110

30663111

3112+
def create_nodes(
3113+
*,
3114+
store: Store,
3115+
path: str,
3116+
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
3117+
semaphore: asyncio.Semaphore | None = None,
3118+
) -> Iterator[Group | Array]:
3119+
"""Create a collection of arrays and / or groups concurrently.
3120+
3121+
Note: no attempt is made to validate that these arrays and / or groups collectively form a
3122+
valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that
3123+
the ``nodes`` parameter satisfies any correctness constraints.
3124+
3125+
Parameters
3126+
----------
3127+
store : Store
3128+
The storage backend to use.
3129+
path : str
3130+
The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with
3131+
``path`` prior to creating nodes.
3132+
nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata]
3133+
A dictionary defining the hierarchy. The keys are the paths of the nodes
3134+
in the hierarchy, and the values are the metadata of the nodes. The
3135+
metadata must be either an instance of GroupMetadata, ArrayV3Metadata
3136+
or ArrayV2Metadata.
3137+
semaphore : asyncio.Semaphore | None
3138+
An optional semaphore to limit the number of concurrent tasks. If not
3139+
provided, the number of concurrent tasks is unlimited.
3140+
3141+
Yields
3142+
------
3143+
Group | Array
3144+
The created nodes in the order they are created.
3145+
"""
3146+
coro = create_nodes_a(store=store, path=path, nodes=nodes, semaphore=semaphore)
3147+
3148+
for result in sync(_collect_aiterator(coro)):
3149+
yield _parse_async_node(result)
3150+
3151+
30673152
T = TypeVar("T")
30683153

30693154

@@ -3393,34 +3478,31 @@ def _build_metadata_v2(
33933478
return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json})
33943479

33953480

3396-
def _build_node_v3(
3397-
*,
3398-
store: Store,
3399-
path: str,
3400-
metadata: ArrayV3Metadata | GroupMetadata,
3401-
) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
3402-
"""
3403-
Take a metadata object and return a node (AsyncArray or AsyncGroup).
3404-
"""
3405-
store_path = StorePath(store=store, path=path)
3406-
match metadata:
3407-
case ArrayV3Metadata():
3408-
return AsyncArray(metadata, store_path=store_path)
3409-
case GroupMetadata():
3410-
return AsyncGroup(metadata, store_path=store_path)
3411-
case _: # pragma: no cover
3412-
raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover
3481+
@overload
3482+
def _build_node(
3483+
*, store: Store, path: str, metadata: ArrayV2Metadata
3484+
) -> AsyncArray[ArrayV2Metadata]: ...
34133485

34143486

3415-
def _build_node_v2(
3416-
*, store: Store, path: str, metadata: ArrayV2Metadata | GroupMetadata
3417-
) -> AsyncArray[ArrayV2Metadata] | AsyncGroup:
3487+
@overload
3488+
def _build_node(
3489+
*, store: Store, path: str, metadata: ArrayV3Metadata
3490+
) -> AsyncArray[ArrayV3Metadata]: ...
3491+
3492+
3493+
@overload
3494+
def _build_node(*, store: Store, path: str, metadata: GroupMetadata) -> AsyncGroup: ...
3495+
3496+
3497+
def _build_node(
3498+
*, store: Store, path: str, metadata: ArrayV3Metadata | ArrayV2Metadata | GroupMetadata
3499+
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup:
34183500
"""
34193501
Take a metadata object and return a node (AsyncArray or AsyncGroup).
34203502
"""
34213503
store_path = StorePath(store=store, path=path)
34223504
match metadata:
3423-
case ArrayV2Metadata():
3505+
case ArrayV2Metadata() | ArrayV3Metadata():
34243506
return AsyncArray(metadata, store_path=store_path)
34253507
case GroupMetadata():
34263508
return AsyncGroup(metadata, store_path=store_path)
@@ -3433,22 +3515,22 @@ async def _read_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata]
34333515
Read a Zarr v2 AsyncArray or AsyncGroup from a path in a Store.
34343516
"""
34353517
metadata = await _read_metadata_v2(store=store, path=path)
3436-
return _build_node_v2(store=store, path=path, metadata=metadata)
3518+
return _build_node(store=store, path=path, metadata=metadata)
34373519

34383520

34393521
async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
34403522
"""
34413523
Read a Zarr v3 AsyncArray or AsyncGroup from a path in a Store.
34423524
"""
34433525
metadata = await _read_metadata_v3(store=store, path=path)
3444-
return _build_node_v3(store=store, path=path, metadata=metadata)
3526+
return _build_node(store=store, path=path, metadata=metadata)
34453527

34463528

3447-
async def _read_node(
3529+
async def _read_node_a(
34483530
store: Store, path: str, zarr_format: ZarrFormat
34493531
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup:
34503532
"""
3451-
Read an AsyncArray or AsyncGroup from a location defined by a StorePath.
3533+
Read an AsyncArray or AsyncGroup from a path in a Store.
34523534
"""
34533535
match zarr_format:
34543536
case 2:
@@ -3459,6 +3541,14 @@ async def _read_node(
34593541
raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover
34603542

34613543

3544+
def read_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group:
3545+
"""
3546+
Read an Array or Group from a path in a Store.
3547+
"""
3548+
3549+
return _parse_async_node(sync(_read_node_a(store=store, path=path, zarr_format=zarr_format)))
3550+
3551+
34623552
async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str:
34633553
"""
34643554
Write a value to storage at the given key. The key is returned.
@@ -3493,7 +3583,7 @@ def _persist_metadata(
34933583
)
34943584

34953585

3496-
async def _create_rooted_hierarchy(
3586+
async def _create_rooted_hierarchy_a(
34973587
*,
34983588
store: Store,
34993589
path: str,
@@ -3521,7 +3611,7 @@ async def _create_rooted_hierarchy(
35213611

35223612
nodes_created = {
35233613
x.path: x
3524-
async for x in create_hierarchy(
3614+
async for x in create_hierarchy_a(
35253615
store=store, path=path, nodes=nodes, semaphore=semaphore, overwrite=overwrite
35263616
)
35273617
}
@@ -3530,7 +3620,7 @@ async def _create_rooted_hierarchy(
35303620
return nodes_created[root_relative_to_store_path]
35313621

35323622

3533-
def _create_rooted_hierarchy_sync(
3623+
def _create_rooted_hierarchy(
35343624
*,
35353625
store: Store,
35363626
path: str,
@@ -3542,11 +3632,6 @@ def _create_rooted_hierarchy_sync(
35423632
``_create_rooted_hierarchy_a`` and waits for the result.
35433633
"""
35443634
async_node = sync(
3545-
_create_rooted_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite)
3635+
_create_rooted_hierarchy_a(store=store, path=path, nodes=nodes, overwrite=overwrite)
35463636
)
3547-
if isinstance(async_node, AsyncGroup):
3548-
return Group(async_node)
3549-
elif isinstance(async_node, AsyncArray):
3550-
return Array(async_node)
3551-
else:
3552-
raise TypeError(f"Unexpected node type: {type(async_node)}") # pragma: no cover
3637+
return _parse_async_node(async_node)

0 commit comments

Comments
 (0)