
Commit 116ab87

committed
better concurrency for v2
1 parent 089feef commit 116ab87

File tree

3 files changed: +270 −52 lines


src/zarr/core/config.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def reset(self) -> None:
                 ],
             },
         },
-        "async": {"concurrency": 10, "timeout": None},
+        "async": {"concurrency": 256, "timeout": None},
         "threading": {"max_workers": None},
         "json_indent": 2,
         "codec_pipeline": {

src/zarr/core/group.py

Lines changed: 173 additions & 44 deletions
@@ -1,16 +1,24 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import itertools
 import json
 import logging
 import warnings
 from collections import defaultdict
 from dataclasses import asdict, dataclass, field, fields, replace
-from functools import partial
 from itertools import accumulate
 from pathlib import PurePosixPath
-from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload
+from typing import (
+    TYPE_CHECKING,
+    Literal,
+    Self,
+    TypeVar,
+    assert_never,
+    cast,
+    overload,
+)
 
 import numpy as np
 import numpy.typing as npt
@@ -51,7 +59,7 @@
 from zarr.core.config import config
 from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
 from zarr.core.metadata.v3 import V3JsonEncoder
-from zarr.core.sync import SyncMixin, _with_semaphore, sync
+from zarr.core.sync import SyncMixin, sync
 from zarr.errors import MetadataValidationError
 from zarr.storage import StoreLike, StorePath
 from zarr.storage._common import ensure_no_existing_node, make_store_path
@@ -60,6 +68,7 @@
     from collections.abc import (
         AsyncGenerator,
         AsyncIterator,
+        Coroutine,
         Generator,
         Iterable,
         Iterator,
@@ -431,21 +440,32 @@ class AsyncGroup:
     @classmethod
     async def from_flat(
         cls,
-        store: StoreLike,
-        *,
-        nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
-        overwrite: bool = False) -> Self:
-
+        store: StoreLike,
+        *,
+        nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
+        overwrite: bool = False,
+    ) -> Self:
+        if not _is_rooted(nodes):
+            msg = (
+                "The input does not specify a root node. ",
+                "This function can only create hierarchies that contain a root node, which is ",
+                "defined as a group that is ancestral to all the other arrays and ",
+                "groups in the hierarchy.",
+            )
+            raise ValueError(msg)
+
         if overwrite:
-            store_path = await make_store_path(store, mode='w')
+            store_path = await make_store_path(store, mode="w")
         else:
-            store_path = await make_store_path(store, mode='w-')
+            store_path = await make_store_path(store, mode="w-")
+
         semaphore = asyncio.Semaphore(config.get("async.concurrency"))
-
-        nodes_created = {x.name: x async for x in create_hierarchy(
-            store_path=store_path, nodes=nodes, semaphore=semaphore
-        )}
-        return nodes_created['']
+
+        nodes_created = {
+            x.name: x
+            async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore)
+        }
+        # TODO: make this work
 
     @classmethod
     async def from_store(
@@ -1743,17 +1763,18 @@ async def move(self, source: str, dest: str) -> None:
 @dataclass(frozen=True)
 class Group(SyncMixin):
     _async_group: AsyncGroup
-
+
     @classmethod
     def from_flat(
-        cls,
+        cls,
         store: StoreLike,
-        *,
+        *,
         nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
-        overwrite: bool = False) -> Group:
+        overwrite: bool = False,
+    ) -> Group:
         nodes = sync(AsyncGroup.from_flat(store, nodes=nodes, overwrite=overwrite))
         # return the root node of the hierarchy
-        return nodes['']
+        return nodes[""]
 
     @classmethod
     def from_store(
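For orientation, this is the shape of the `nodes` mapping that `from_flat` expects: keys are paths relative to the store root, values are metadata objects, and the mapping must include a root entry (the `""` key) so that `_is_rooted` passes and `nodes[""]` can be returned. A minimal sketch; note that `AsyncGroup.from_flat` is still marked `# TODO: make this work` in this commit, so the call is illustrative only and the metadata values are placeholders:

from zarr.core.group import GroupMetadata

nodes = {
    "": GroupMetadata(),      # root group, required for a rooted hierarchy
    "a": GroupMetadata(),     # child group
    "a/b": GroupMetadata(),   # grandchild group; arrays would use ArrayV2Metadata / ArrayV3Metadata values
}

# root = Group.from_flat(store, nodes=nodes, overwrite=True)  # intended usage once the TODO is resolved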
@@ -2110,17 +2131,28 @@ def create_hierarchy(
 
         Parameters
         ----------
-        nodes : A dictionary representing the hierarchy to create
+        nodes : A dictionary representing the hierarchy to create. The keys should be relative paths
+            and the values should be the metadata for the arrays or groups to create.
 
         Returns
        -------
-        A dict containing the created nodes.The keys are the same as th
-        """
+        A dict containing the created nodes, with the same keys as the input
+        """
+        # check that all the nodes have the same zarr_format as Self
+        for key, value in nodes.items():
+            if value.zarr_format != self.metadata.zarr_format:
+                msg = (
+                    "The zarr_format of the nodes must be the same as the parent group. "
+                    f"The node at {key} has zarr_format {value.zarr_format}, but the parent group"
+                    f" has zarr_format {self.metadata.zarr_format}."
+                )
+                raise ValueError(msg)
         nodes_created = self._sync_iter(self._async_group.create_hierarchy(nodes))
         if self.path == "":
             root = "/"
         else:
             root = self.path
+        # TODO: make this safe against invalid path inputs
         return {str(PurePosixPath(n.name).relative_to(root)): n for n in nodes_created}
 
     def keys(self) -> Generator[str, None]:
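A short sketch of how the key convention described in the docstring above plays out, assuming an existing Zarr v3 group `g`; the exact return types come from the underlying async machinery, so only the keys are shown:

created = g.create_hierarchy({"foo/bar": GroupMetadata(zarr_format=3)})
# The returned dict is keyed by the same paths relative to `g`, e.g. "foo/bar"
# (plus any implicit parent groups filled in by the hierarchy parser).

# Nodes whose zarr_format differs from the parent group are rejected by the new check:
# g.create_hierarchy({"baz": GroupMetadata(zarr_format=2)})  -> ValueError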
@@ -2859,6 +2891,7 @@ def array(
 
 async def _save_metadata(
     node: AsyncArray[Any] | AsyncGroup,
+    overwrite: bool,
 ) -> AsyncArray[Any] | AsyncGroup:
     """
     Save the metadata for an array or group, and return the array or group
@@ -2878,6 +2911,7 @@ async def create_hierarchy(
     store_path: StorePath,
     nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata],
     semaphore: asyncio.Semaphore | None = None,
+    overwrite: bool = False,
 ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
     """
     Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input
@@ -2906,42 +2940,81 @@ async def create_hierarchy(
     The created nodes in the order they are created.
     """
     nodes_parsed = _parse_hierarchy_dict(nodes)
+
     async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore):
         yield node
 
 
 async def create_nodes(
     *,
     store_path: StorePath,
-    nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata],
+    nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
     semaphore: asyncio.Semaphore | None = None,
-) -> AsyncIterator[AsyncGroup | AsyncArray[Any]]:
+) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
     """
-    Create a collection of arrays and groups concurrently and atomically. To ensure atomicity,
+    Create a collection of zarr v2 arrays and groups concurrently and atomically. To ensure atomicity,
     no attempt is made to ensure that intermediate groups are created.
     """
-    create_tasks = []
+    ctx: asyncio.Semaphore | contextlib.nullcontext[None]
+    if semaphore is None:
+        ctx = contextlib.nullcontext()
+    else:
+        ctx = semaphore
+
+    create_tasks: list[Coroutine[None, None, str]] = []
+
     for key, value in nodes.items():
-        new_store_path = store_path / key
-        node: AsyncArray[Any] | AsyncGroup
-        match value:
-            case ArrayV3Metadata() | ArrayV2Metadata():
-                node = AsyncArray(value, store_path=new_store_path)
-            case GroupMetadata():
-                node = AsyncGroup(value, store_path=new_store_path)
-            case _:
-                raise ValueError(f"Unexpected metadata type {type(value)}")
-        partial_func = partial(_save_metadata, node)
-        fut = _with_semaphore(partial_func, semaphore)
-        create_tasks.append(fut)
+        create_tasks.extend(
+            _prepare_save_metadata(store_path.store, f"{store_path.path}/{key}", value)
+        )
+
+    created_keys = []
+    async with ctx:
+        for coro in asyncio.as_completed(create_tasks):
+            created_key = await coro
+            relative_path = PurePosixPath(created_key).relative_to(store_path.path)
+            created_keys.append(str(relative_path))
+            # convert /foo/bar/baz/.zattrs to bar/baz
+            node_name = str(relative_path.parent)
+            meta_out = nodes[node_name]
+
+            if meta_out.zarr_format == 3:
+                if isinstance(meta_out, GroupMetadata):
+                    yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name)
+                else:
+                    yield AsyncArray(metadata=meta_out, store_path=store_path / node_name)
+            else:
+                # For zarr v2
+                # we only want to yield when both the metadata and attributes are created
+                # so we track which keys have been created, and wait for both the meta key and
+                # the attrs key to be created before yielding back the AsyncArray / AsyncGroup
+
+                attrs_done = f"{node_name}/.zattrs" in created_keys
+
+                if isinstance(meta_out, GroupMetadata):
+                    meta_done = f"{node_name}/.zgroup" in created_keys
+                else:
+                    meta_done = f"{node_name}/.zarray" in created_keys
 
-    for coro in asyncio.as_completed(create_tasks):
-        yield await coro
+                if meta_done and attrs_done:
+                    if isinstance(meta_out, GroupMetadata):
+                        yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name)
+                    else:
+                        yield AsyncArray(metadata=meta_out, store_path=store_path / node_name)
 
 
 T = TypeVar("T")
 
 
+def _is_rooted(data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]) -> bool:
+    """
+    Check if the data describes a hierarchy that's rooted, which means there is a single node with
+    the least number of components in its key
+    """
+    # a dict
+    return False
+
+
 def _parse_hierarchy_dict(
     data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
 ) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]:
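`_is_rooted` is committed here as a stub: it always returns `False`, which is why `from_flat` is still marked TODO above. A minimal sketch of the check its docstring describes, i.e. that exactly one node sits at the shallowest level of the key space; this is not part of the commit:

def _is_rooted(data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]) -> bool:
    if not data:
        return False
    # Number of path components per key; the empty key "" is the store root itself.
    def depth(key: str) -> int:
        return 0 if key == "" else len(key.split("/"))
    min_depth = min(depth(k) for k in data)
    # Rooted per the docstring: exactly one node at the shallowest depth.
    return sum(1 for k in data if depth(k) == min_depth) == 1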
@@ -2953,19 +3026,54 @@ def _parse_hierarchy_dict(
     For example, an input of {'a/b/c': ...} will result in a return value of
     {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}.
 
-    This function is useful for ensuring that the input to create_hierarchy is a complete
+    The input is also checked for the following conditions, and an error is raised if any
+    of them are violated:
+
+    - No arrays can contain group or arrays (i.e., all arrays must be leaf nodes).
+    - All arrays and groups must have the same ``zarr_format`` value.
+
+    This function ensures that the input is transformed into a specification of a complete and valid
     Zarr hierarchy.
     """
     # Create a copy of the input dict
     out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data}
+
+    observed_zarr_formats: dict[ZarrFormat, list[str]] = {2: [], 3: []}
+
+    # We will iterate over the dict again, but a full pass here ensures that the error message
+    # is comprehensive, and I think the performance cost will be negligible.
+    for k, v in data.items():
+        observed_zarr_formats[v.zarr_format].append(k)
+
+    if len(observed_zarr_formats[2]) > 0 and len(observed_zarr_formats[3]) > 0:
+        msg = (
+            "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. "
+            f"The following keys map to Zarr v2 nodes: {observed_zarr_formats.get(2)}. "
+            f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}."
+            "Ensure that all nodes have the same Zarr format."
+        )
+
+        raise ValueError(msg)
+
     for k, v in data.items():
+        # TODO: ensure that the key is a valid path
         # Split the key into its path components
         key_split = k.split("/")
-        # Iterate over the path components
-        for subpath in accumulate(key_split, lambda a, b: f"{a}/{b}"):
+
+        # Iterate over the intermediate path components
+        *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}")
+        for subpath in subpaths:
             # If a component is not already in the output dict, add it
             if subpath not in out:
                 out[subpath] = GroupMetadata(zarr_format=v.zarr_format)
+            else:
+                if not isinstance(out[subpath], GroupMetadata):
+                    msg = (
+                        f"The node at {subpath} contains other nodes, but it is not a Zarr group. "
+                        "This is invalid. Only Zarr groups can contain other nodes."
+                    )
+                    raise ValueError(msg)
+
     return out
 
 
@@ -3155,3 +3263,24 @@ def _build_node_v2(
             return AsyncGroup(metadata, store_path=store_path)
         case _:
             raise ValueError(f"Unexpected metadata type: {type(metadata)}")
+
+
+async def _set_return_key(store: Store, key: str, value: Buffer) -> str:
+    """
+    Store.set, but the key and the value are returned.
+    Useful when saving metadata via asyncio.as_completed, because
+    we need to know which key was saved.
+    """
+    await store.set(key, value)
+    return key
+
+
+def _prepare_save_metadata(
+    store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata
+) -> tuple[Coroutine[None, None, str], ...]:
+    """
+    Prepare to save a metadata document to storage. Returns a tuple of coroutines that must be awaited.
+    """
+
+    to_save = metadata.to_buffer_dict(default_buffer_prototype())
+    return tuple(_set_return_key(store, f"{path}/{key}", value) for key, value in to_save.items())
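The reason `_set_return_key` echoes the key back is that `asyncio.as_completed` yields results in completion order rather than submission order, so `create_nodes` needs the key to know which metadata document just finished writing. For v2 nodes, `to_buffer_dict` produces two documents per node (`.zgroup` or `.zarray` plus `.zattrs`), which is why a node is only yielded once both keys have landed; v3 nodes have a single `zarr.json`. A self-contained sketch of the pattern, using a plain dict in place of a real `Store`:

import asyncio

async def write_and_report(storage: dict[str, bytes], key: str, value: bytes) -> str:
    # Stand-in for Store.set: persist the value, then hand the key back so the
    # caller can tell which write just completed.
    await asyncio.sleep(0)  # simulate I/O
    storage[key] = value
    return key

async def main() -> None:
    storage: dict[str, bytes] = {}
    tasks = [write_and_report(storage, k, b"{}") for k in ("a/.zgroup", "a/.zattrs")]
    for coro in asyncio.as_completed(tasks):
        finished = await coro  # arrives in completion order, not submission order
        print("wrote", finished)

asyncio.run(main())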

0 commit comments