diff --git a/changes/2665.feature.rst b/changes/2665.feature.rst new file mode 100644 index 0000000000..40bec542ce --- /dev/null +++ b/changes/2665.feature.rst @@ -0,0 +1 @@ +Adds functions for concurrently creating multiple arrays and groups. \ No newline at end of file diff --git a/docs/quickstart.rst b/docs/quickstart.rst index d520554593..66bdae2a2e 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -119,6 +119,27 @@ Zarr allows you to create hierarchical groups, similar to directories:: This creates a group with two datasets: ``foo`` and ``bar``. +Batch Hierarchy Creation +~~~~~~~~~~~~~~~~~~~~~~~~ + +Zarr provides tools for creating a collection of arrays and groups with a single function call. +Suppose we want to copy existing groups and arrays into a new storage backend: + + >>> # Create nested groups and add arrays + >>> root = zarr.group("data/example-3.zarr", attributes={'name': 'root'}) + >>> foo = root.create_group(name="foo") + >>> bar = root.create_array( + ... name="bar", shape=(100, 10), chunks=(10, 10), dtype="f4" + ... ) + >>> nodes = {'': root.metadata} | {k: v.metadata for k, v in root.members()} + >>> from zarr.storage import MemoryStore + >>> new_nodes = dict(zarr.create_hierarchy(store=MemoryStore(), nodes=nodes)) + >>> new_root = new_nodes[''] + >>> assert new_root.attrs == root.attrs + +Note that :func:`zarr.create_hierarchy` will only initialize arrays and groups -- copying array data must +be done in a separate step. + Persistent Storage ------------------ diff --git a/docs/user-guide/groups.rst b/docs/user-guide/groups.rst index 1e72df3478..4268004f70 100644 --- a/docs/user-guide/groups.rst +++ b/docs/user-guide/groups.rst @@ -75,6 +75,31 @@ For more information on groups see the :class:`zarr.Group` API docs. .. _user-guide-diagnostics: +Batch Group Creation +-------------------- + +You can also create multiple groups concurrently with a single function call. :func:`zarr.create_hierarchy` takes +a :class:`zarr.storage.Store` instance and a dict of ``key : metadata`` pairs, parses that dict, and +writes metadata documents to storage: + + >>> from zarr import create_hierarchy + >>> from zarr.core.group import GroupMetadata + >>> from zarr.storage import LocalStore + >>> node_spec = {'a/b/c': GroupMetadata()} + >>> nodes_created = dict(create_hierarchy(store=LocalStore(root='data'), nodes=node_spec)) + >>> print(sorted(nodes_created.items(), key=lambda kv: len(kv[0]))) + [('', <Group file://data>), ('a', <Group file://data/a>), ('a/b', <Group file://data/a/b>), ('a/b/c', <Group file://data/a/b/c>)] + +Note that we only specified a single group named ``a/b/c``, but 4 groups were created. These additional groups +were created to ensure that the desired node ``a/b/c`` is connected to the root group ``''`` by a sequence +of intermediate groups. :func:`zarr.create_hierarchy` normalizes the ``nodes`` keyword argument to +ensure that the resulting hierarchy is complete, i.e. all groups or arrays are connected to the root +of the hierarchy via intermediate groups. + +Because :func:`zarr.create_hierarchy` concurrently creates metadata documents, it's more efficient +than repeated calls to :func:`zarr.create_group` or :func:`zarr.create_array`, provided you can statically define +the metadata for the groups and arrays you want to create.
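+ +You can also mix groups and arrays in a single call. The following sketch is illustrative (the names, shape, and dtype are invented for this example); it stages one explicit group and one array, and the root group is added implicitly: + + >>> import zarr + >>> from zarr.storage import MemoryStore + >>> array_metadata = zarr.create_array(store={}, shape=(8,), dtype='uint8').metadata + >>> mixed_spec = {'g': GroupMetadata(), 'g/x': array_metadata} + >>> created = dict(create_hierarchy(store=MemoryStore(), nodes=mixed_spec)) + >>> print(sorted(created)) + ['', 'g', 'g/x']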
+ Array and group diagnostics --------------------------- diff --git a/src/zarr/__init__.py b/src/zarr/__init__.py index bcbdaf7c19..4ffa4c9bbc 100644 --- a/src/zarr/__init__.py +++ b/src/zarr/__init__.py @@ -8,6 +8,7 @@ create, create_array, create_group, + create_hierarchy, empty, empty_like, full, @@ -50,6 +51,7 @@ "create", "create_array", "create_group", + "create_hierarchy", "empty", "empty_like", "full", diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 3a3d03bb71..6059893920 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -23,7 +23,12 @@ _warn_write_empty_chunks_kwarg, parse_dtype, ) -from zarr.core.group import AsyncGroup, ConsolidatedMetadata, GroupMetadata +from zarr.core.group import ( + AsyncGroup, + ConsolidatedMetadata, + GroupMetadata, + create_hierarchy, +) from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v2 import _default_compressor, _default_filters from zarr.errors import NodeTypeValidationError @@ -48,6 +53,7 @@ "copy_store", "create", "create_array", + "create_hierarchy", "empty", "empty_like", "full", diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index e1f92633cd..9424ae1fde 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -10,6 +10,7 @@ from zarr.core.array import Array, AsyncArray from zarr.core.group import Group from zarr.core.sync import sync +from zarr.core.sync_group import create_hierarchy if TYPE_CHECKING: from collections.abc import Iterable @@ -46,6 +47,7 @@ "copy_store", "create", "create_array", + "create_hierarchy", "empty", "empty_like", "full", diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 1f5d57c0ab..a7f8a6c022 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -6,7 +6,9 @@ import logging import warnings from collections import defaultdict +from collections.abc import Iterator, Mapping from dataclasses import asdict, dataclass, field, fields, replace +from itertools import accumulate from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload import numpy as np @@ -49,12 +51,19 @@ from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v3 import V3JsonEncoder from zarr.core.sync import SyncMixin, sync -from zarr.errors import MetadataValidationError +from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError from zarr.storage import StoreLike, StorePath from zarr.storage._common import ensure_no_existing_node, make_store_path +from zarr.storage._utils import _join_paths, _normalize_path_keys, normalize_path if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator, Iterable, Iterator + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Coroutine, + Generator, + Iterable, + ) from typing import Any from zarr.core.array_spec import ArrayConfig, ArrayConfigLike @@ -407,6 +416,15 @@ def to_dict(self) -> dict[str, Any]: return result +@dataclass(frozen=True) +class ImplicitGroupMarker(GroupMetadata): + """ + Marker for an implicit group. 
Instances of this class are only used in the context of group + creation as a placeholder to represent groups that should only be created if they do not + already exist in storage + """ + + @dataclass(frozen=True) class AsyncGroup: """ @@ -416,6 +434,9 @@ class AsyncGroup: metadata: GroupMetadata store_path: StorePath + # TODO: make this correct and work + # TODO: ensure that this can be bound properly to subclass of AsyncGroup + @classmethod async def from_store( cls, @@ -662,55 +683,16 @@ async def getitem( """ store_path = self.store_path / key logger.debug("key=%s, store_path=%s", key, store_path) - metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata # Consolidated metadata lets us avoid some I/O operations so try that first. if self.metadata.consolidated_metadata is not None: return self._getitem_consolidated(store_path, key, prefix=self.name) - - # Note: - # in zarr-python v2, we first check if `key` references an Array, else if `key` references - # a group,using standalone `contains_array` and `contains_group` functions. These functions - # are reusable, but for v3 they would perform redundant I/O operations. - # Not clear how much of that strategy we want to keep here. - elif self.metadata.zarr_format == 3: - zarr_json_bytes = await (store_path / ZARR_JSON).get() - if zarr_json_bytes is None: - raise KeyError(key) - else: - zarr_json = json.loads(zarr_json_bytes.to_bytes()) - metadata = _build_metadata_v3(zarr_json) - return _build_node_v3(metadata, store_path) - - elif self.metadata.zarr_format == 2: - # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs? - # This guarantees that we will always make at least one extra request to the store - zgroup_bytes, zarray_bytes, zattrs_bytes = await asyncio.gather( - (store_path / ZGROUP_JSON).get(), - (store_path / ZARRAY_JSON).get(), - (store_path / ZATTRS_JSON).get(), + try: + return await get_node( + store=store_path.store, path=store_path.path, zarr_format=self.metadata.zarr_format ) - - if zgroup_bytes is None and zarray_bytes is None: - raise KeyError(key) - - # unpack the zarray, if this is None then we must be opening a group - zarray = json.loads(zarray_bytes.to_bytes()) if zarray_bytes else None - zgroup = json.loads(zgroup_bytes.to_bytes()) if zgroup_bytes else None - # unpack the zattrs, this can be None if no attrs were written - zattrs = json.loads(zattrs_bytes.to_bytes()) if zattrs_bytes is not None else {} - - if zarray is not None: - metadata = _build_metadata_v2(zarray, zattrs) - return _build_node_v2(metadata=metadata, store_path=store_path) - else: - # this is just for mypy - if TYPE_CHECKING: - assert zgroup is not None - metadata = _build_metadata_v2(zgroup, zattrs) - return _build_node_v2(metadata=metadata, store_path=store_path) - else: - raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") + except FileNotFoundError as e: + raise KeyError(key) from e def _getitem_consolidated( self, store_path: StorePath, key: str, prefix: str @@ -1407,6 +1389,84 @@ async def _members( ): yield member + async def create_hierarchy( + self, + nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], + *, + overwrite: bool = False, + ) -> AsyncIterator[ + tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] + ]: + """ + Create a hierarchy of arrays or groups rooted at this group. + + This function will parse its input to ensure that the hierarchy is complete. Any implicit groups + will be inserted as needed. 
For example, an input like + ``{'a/b': GroupMetadata}`` will be parsed to + ``{'': GroupMetadata, 'a': GroupMetadata, 'a/b': GroupMetadata}``. + + Explicitly specifying a root group, e.g. with ``nodes = {'': GroupMetadata()}``, is an error + because this group instance is the root group. + + After input parsing, this function then creates all the nodes in the hierarchy concurrently. + + Arrays and Groups are yielded in the order they are created. This order is not stable and + should not be relied on. + + Parameters + ---------- + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, + relative to the path of the group. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that + all values must have the same ``zarr_format`` as the parent group -- it is an error to mix zarr versions in the + same hierarchy. + + Leading "/" characters from keys will be removed. + overwrite : bool + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is + raised instead of overwriting an existing array or group. + + This function will not erase an existing group unless that group is explicitly named in + ``nodes``. If ``nodes`` defines implicit groups, e.g. ``{'a/b/c': GroupMetadata}``, and a + group already exists at path ``a``, then this function will leave the group at ``a`` as-is. + + Yields + ------ + tuple[str, AsyncArray | AsyncGroup] + """ + # check that all the nodes have the same zarr_format as this group + prefix = self.path + nodes_parsed = {} + for key, value in nodes.items(): + if value.zarr_format != self.metadata.zarr_format: + msg = ( + "The zarr_format of the nodes must be the same as the parent group. " + f"The node at {key} has zarr_format {value.zarr_format}, but the parent group" + f" has zarr_format {self.metadata.zarr_format}." + ) + raise ValueError(msg) + if normalize_path(key) == "": + msg = ( + "The input defines a root node, but a root node already exists, namely this Group instance. " + "It is an error to use this method to create a root node. " + "Remove the root node from the input dict, or use a function like " + "create_rooted_hierarchy to create a rooted hierarchy." + ) + raise ValueError(msg) + else: + nodes_parsed[_join_paths([prefix, key])] = value + + async for key, node in create_hierarchy( + store=self.store, + nodes=nodes_parsed, + overwrite=overwrite, + ): + if prefix == "": + out_key = key + else: + out_key = key.removeprefix(prefix + "/") + yield out_key, node + async def keys(self) -> AsyncGenerator[str, None]: """Iterate over member names.""" async for key, _ in self.members(): @@ -2030,6 +2090,66 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], return tuple((kv[0], _parse_async_node(kv[1])) for kv in _members) + def create_hierarchy( + self, + nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], + *, + overwrite: bool = False, + ) -> Iterator[tuple[str, Group | Array]]: + """ + Create a hierarchy of arrays or groups rooted at this group. + + This function will parse its input to ensure that the hierarchy is complete. Any implicit groups + will be inserted as needed. For example, an input like + ``{'a/b': GroupMetadata}`` will be parsed to + ``{'': GroupMetadata, 'a': GroupMetadata, 'a/b': GroupMetadata}``. + + Explicitly specifying a root group, e.g. with ``nodes = {'': GroupMetadata()}``, is an error + because this group instance is the root group.
+ + After input parsing, this function then creates all the nodes in the hierarchy concurrently. + + Arrays and Groups are yielded in the order they are created. This order is not stable and + should not be relied on. + + Parameters + ---------- + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, + relative to the path of the group. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that + all values must have the same ``zarr_format`` as the parent group -- it is an error to mix zarr versions in the + same hierarchy. + + Leading "/" characters from keys will be removed. + overwrite : bool + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is + raised instead of overwriting an existing array or group. + + This function will not erase an existing group unless that group is explicitly named in + ``nodes``. If ``nodes`` defines implicit groups, e.g. ``{'a/b/c': GroupMetadata}``, and a + group already exists at path ``a``, then this function will leave the group at ``a`` as-is. + + Yields + ------ + tuple[str, Array | Group] + + Examples + -------- + >>> import zarr + >>> from zarr.core.group import GroupMetadata + >>> root = zarr.create_group(store={}) + >>> for key, val in root.create_hierarchy({'a/b/c': GroupMetadata()}): + ... print(key, val) + ... + a <Group memory://.../a> + a/b <Group memory://.../a/b> + a/b/c <Group memory://.../a/b/c> + """ + for key, node in self._sync_iter( + self._async_group.create_hierarchy(nodes, overwrite=overwrite) + ): + yield (key, _parse_async_node(node)) + def keys(self) -> Generator[str, None]: """Return an iterator over group member names. @@ -2774,11 +2894,361 @@ def array( ) +async def create_hierarchy( + *, + store: Store, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> AsyncIterator[ + tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] +]: + """ + Create a complete Zarr hierarchy from a collection of metadata objects. + + This function will parse its input to ensure that the hierarchy is complete. Any implicit groups + will be inserted as needed. For example, an input like + ``{'a/b': GroupMetadata}`` will be parsed to + ``{'': GroupMetadata, 'a': GroupMetadata, 'a/b': GroupMetadata}``. + + After input parsing, this function then creates all the nodes in the hierarchy concurrently. + + Arrays and Groups are yielded in the order they are created. This order is not stable and + should not be relied on. + + Parameters + ---------- + store : Store + The storage backend to use. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, + relative to the root of the ``Store``. The root of the store can be specified with the empty + string ``''``. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that + all values must have the same ``zarr_format`` -- it is an error to mix zarr versions in the + same hierarchy. + + Leading "/" characters from keys will be removed. + overwrite : bool + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is + raised instead of overwriting an existing array or group. + + This function will not erase an existing group unless that group is explicitly named in + ``nodes``. If ``nodes`` defines implicit groups, e.g.
``{'a/b/c': GroupMetadata}``, and a + group already exists at path ``a``, then this function will leave the group at ``a`` as-is. + + Yields + ------ + tuple[str, AsyncGroup | AsyncArray] + This function yields (path, node) pairs, in the order the nodes were created. + + Examples + -------- + >>> from zarr.api.asynchronous import create_hierarchy + >>> from zarr.storage import MemoryStore + >>> from zarr.core.group import GroupMetadata + >>> import asyncio + >>> store = MemoryStore() + >>> nodes = {'a': GroupMetadata(attributes={'name': 'leaf'})} + >>> async def run(): + ... print(dict([x async for x in create_hierarchy(store=store, nodes=nodes)])) + >>> asyncio.run(run()) + # {'a': <AsyncGroup memory://.../a>, '': <AsyncGroup memory://...>} + """ + # normalize the keys to be valid paths + nodes_normed_keys = _normalize_path_keys(nodes) + + # ensure that all nodes have the same zarr_format, and add implicit groups as needed + nodes_parsed = _parse_hierarchy_dict(data=nodes_normed_keys) + redundant_implicit_groups = [] + + # empty hierarchies should be a no-op + if len(nodes_parsed) > 0: + # figure out which zarr format we are using + zarr_format = next(iter(nodes_parsed.values())).zarr_format + + # check which implicit groups will require materialization + implicit_group_keys = tuple( + filter(lambda k: isinstance(nodes_parsed[k], ImplicitGroupMarker), nodes_parsed) + ) + # read potential group metadata for each implicit group + maybe_extant_group_coros = ( + _read_group_metadata(store, k, zarr_format=zarr_format) for k in implicit_group_keys + ) + maybe_extant_groups = await asyncio.gather( + *maybe_extant_group_coros, return_exceptions=True + ) + + for key, value in zip(implicit_group_keys, maybe_extant_groups, strict=True): + if isinstance(value, BaseException): + if isinstance(value, FileNotFoundError): + # this is fine -- there was no group there, so we will create one + pass + else: + raise value + else: + # a group already exists at ``key``, so we can avoid creating anything there + redundant_implicit_groups.append(key) + + if overwrite: + # we will remove any nodes that collide with arrays and non-implicit groups defined in + # nodes + + # track the keys of nodes we need to delete + to_delete_keys = [] + to_delete_keys.extend( + [k for k, v in nodes_parsed.items() if k not in implicit_group_keys] + ) + await asyncio.gather(*(store.delete_dir(key) for key in to_delete_keys)) + else: + # This type is long.
+ coros: ( + Generator[Coroutine[Any, Any, ArrayV2Metadata | GroupMetadata], None, None] + | Generator[Coroutine[Any, Any, ArrayV3Metadata | GroupMetadata], None, None] + ) + if zarr_format == 2: + coros = (_read_metadata_v2(store=store, path=key) for key in nodes_parsed) + elif zarr_format == 3: + coros = (_read_metadata_v3(store=store, path=key) for key in nodes_parsed) + else: # pragma: no cover + raise ValueError(f"Invalid zarr_format: {zarr_format}") # pragma: no cover + + extant_node_query = dict( + zip( + nodes_parsed.keys(), + await asyncio.gather(*coros, return_exceptions=True), + strict=False, + ) + ) + # iterate over the existing arrays / groups and figure out which of them conflict + # with the arrays / groups we want to create + for key, extant_node in extant_node_query.items(): + proposed_node = nodes_parsed[key] + if isinstance(extant_node, BaseException): + if isinstance(extant_node, FileNotFoundError): + # ignore FileNotFoundError, because they represent nodes we can safely create + pass + else: + # Any other exception is a real error + raise extant_node + else: + # this is a node that already exists, but a node with the same key was specified + # in nodes_parsed. + if isinstance(extant_node, GroupMetadata): + # a group already exists where we want to create a group + if isinstance(proposed_node, ImplicitGroupMarker): + # we have proposed an implicit group, which is OK -- we will just skip + # creating this particular metadata document + redundant_implicit_groups.append(key) + else: + # we have proposed an explicit group, which is an error, given that a + # group already exists. + raise ContainsGroupError(store, key) + elif isinstance(extant_node, ArrayV2Metadata | ArrayV3Metadata): + # we are trying to overwrite an existing array. this is an error. + raise ContainsArrayError(store, key) + + nodes_explicit: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {} + + for k, v in nodes_parsed.items(): + if k not in redundant_implicit_groups: + if isinstance(v, ImplicitGroupMarker): + nodes_explicit[k] = GroupMetadata(zarr_format=v.zarr_format) + else: + nodes_explicit[k] = v + + async for key, node in create_nodes(store=store, nodes=nodes_explicit): + yield key, node + + +async def create_nodes( + *, + store: Store, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], +) -> AsyncIterator[ + tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] +]: + """Create a collection of arrays and / or groups concurrently. + + Note: no attempt is made to validate that these arrays and / or groups collectively form a + valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that + the ``nodes`` parameter satisfies any correctness constraints. + + Parameters + ---------- + store : Store + The storage backend to use. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + + Yields + ------ + AsyncGroup | AsyncArray + The created nodes in the order they are created. + """ + + # Note: the only way to alter this value is via the config. 
If that's undesirable for some reason, + then we should consider adding a keyword argument to this function + semaphore = asyncio.Semaphore(config.get("async.concurrency")) + create_tasks: list[Coroutine[None, None, str]] = [] + + for key, value in nodes.items(): + # make the key absolute + create_tasks.extend(_persist_metadata(store, key, value, semaphore=semaphore)) + + created_object_keys = [] + + for coro in asyncio.as_completed(create_tasks): + created_key = await coro + # we need this to track which metadata documents were written so that we can yield a + # complete v2 Array / Group class after both .zattrs and the metadata JSON were created. + created_object_keys.append(created_key) + + # get the node name from the object key + if len(created_key.split("/")) == 1: + # this is the root node + meta_out = nodes[""] + node_name = "" + else: + # strip the metadata document name, e.g. turn "foo/zarr.json" into "foo" + node_name = created_key[: created_key.rfind("/")] + meta_out = nodes[node_name] + if meta_out.zarr_format == 3: + yield node_name, _build_node(store=store, path=node_name, metadata=meta_out) + else: + # For zarr v2 + # we only want to yield when both the metadata and attributes are created + # so we track which keys have been created, and wait for both the meta key and + # the attrs key to be created before yielding back the AsyncArray / AsyncGroup + + attrs_done = _join_paths([node_name, ZATTRS_JSON]) in created_object_keys + + if isinstance(meta_out, GroupMetadata): + meta_done = _join_paths([node_name, ZGROUP_JSON]) in created_object_keys + else: + meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys + + if meta_done and attrs_done: + yield node_name, _build_node(store=store, path=node_name, metadata=meta_out) + + continue + + +def _get_roots( + data: Iterable[str], +) -> tuple[str, ...]: + """ + Return the keys of the root(s) of the hierarchy. A root is a key with the fewest number of + path segments. + """ + if "" in data: + return ("",) + keys_split = sorted((key.split("/") for key in data), key=len) + groups: defaultdict[int, list[str]] = defaultdict(list) + for key_split in keys_split: + groups[len(key_split)].append("/".join(key_split)) + return tuple(groups[min(groups.keys())]) + + +def _parse_hierarchy_dict( + *, + data: Mapping[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], +) -> dict[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: + """ + Take an input with type Mapping[str, ArrayMetadata | GroupMetadata] and parse it into + a dict of str: node pairs that models a valid, complete Zarr hierarchy. + + If the input represents a complete Zarr hierarchy, i.e. one with no implicit groups, + then return a dict with the exact same data as the input. + + Otherwise, return a dict derived from the input with GroupMetadata inserted as needed to make + the hierarchy complete. + + For example, an input of {'a/b': ArrayMetadata} is incomplete, because it references two + groups (the root group '' and a group at 'a') that are not specified in the input. Applying this function + to that input will result in a return value of + {'': GroupMetadata, 'a': GroupMetadata, 'a/b': ArrayMetadata}, i.e. the implied groups + were added. + + The input is also checked for the following conditions; an error is raised if any are violated: + + - No array can contain groups or arrays (i.e., all arrays must be leaf nodes). + - All arrays and groups must have the same ``zarr_format`` value.
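+ + As an illustrative sketch (the keys and metadata values here are hypothetical): an input of + ``{'a/b/c': GroupMetadata()}`` expands to + ``{'': ImplicitGroupMarker(), 'a': ImplicitGroupMarker(), 'a/b': ImplicitGroupMarker(), 'a/b/c': GroupMetadata()}``, + while an input like ``{'a': ArrayV3Metadata(...), 'a/b': GroupMetadata()}`` raises a ``ValueError``, + because the array at ``'a'`` would contain the node at ``'a/b'``.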
+ + This function ensures that the input is transformed into a specification of a complete and valid + Zarr hierarchy. + """ + + # ensure that all nodes have the same zarr format + data_purified = _ensure_consistent_zarr_format(data) + + # ensure that keys are normalized to zarr paths + data_normed_keys = _normalize_path_keys(data_purified) + + # insert an implicit root group if a root was not specified + # but not if an empty dict was provided, because an empty hierarchy has no nodes + if len(data_normed_keys) > 0 and "" not in data_normed_keys: + z_format = next(iter(data_normed_keys.values())).zarr_format + data_normed_keys = data_normed_keys | {"": ImplicitGroupMarker(zarr_format=z_format)} + + out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data_normed_keys} + + for k, v in data_normed_keys.items(): + key_split = k.split("/") + + # get every parent path + *subpaths, _ = accumulate(key_split, lambda a, b: _join_paths([a, b])) + + for subpath in subpaths: + # if a parent path is not already in the output dict, add an ImplicitGroupMarker for it + if subpath not in out: + out[subpath] = ImplicitGroupMarker(zarr_format=v.zarr_format) + else: + if not isinstance(out[subpath], GroupMetadata | ImplicitGroupMarker): + msg = ( + f"The node at {subpath} contains other nodes, but it is not a Zarr group. " + "This is invalid. Only Zarr groups can contain other nodes." + ) + raise ValueError(msg) + return out + + +def _ensure_consistent_zarr_format( + data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], +) -> Mapping[str, GroupMetadata | ArrayV2Metadata] | Mapping[str, GroupMetadata | ArrayV3Metadata]: + """ + Ensure that all values of the input dict have the same zarr format. If any do not, + then a ``ValueError`` is raised. + """ + observed_zarr_formats: dict[ZarrFormat, list[str]] = {2: [], 3: []} + + for k, v in data.items(): + observed_zarr_formats[v.zarr_format].append(k) + + if len(observed_zarr_formats[2]) > 0 and len(observed_zarr_formats[3]) > 0: + msg = ( + "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. " + f"The following keys map to Zarr v2 nodes: {observed_zarr_formats.get(2)}. " + f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}. " + "Ensure that all nodes have the same Zarr format." + ) + raise ValueError(msg) + + return cast( + Mapping[str, GroupMetadata | ArrayV2Metadata] + | Mapping[str, GroupMetadata | ArrayV3Metadata], + data, + ) + + +async def _getitem_semaphore( + node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None ) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup: """ - Combine node.getitem with an optional semaphore. If the semaphore parameter is an + Wrap Group.getitem with an optional semaphore. + + If the semaphore parameter is an asyncio.Semaphore instance, then the getitem operation is performed inside an async context manager provided by that semaphore. If the semaphore parameter is None, then getitem is invoked without a context manager.
@@ -2892,71 +3362,283 @@ async def _iter_members_deep( yield key, node -def _resolve_metadata_v2( - blobs: tuple[str | bytes | bytearray, str | bytes | bytearray], -) -> ArrayV2Metadata | GroupMetadata: - zarr_metadata = json.loads(blobs[0]) - attrs = json.loads(blobs[1]) - if "shape" in zarr_metadata: - return ArrayV2Metadata.from_dict(zarr_metadata | {"attrs": attrs}) +async def _read_metadata_v3(store: Store, path: str) -> ArrayV3Metadata | GroupMetadata: + """ + Given a store and a path, return the ArrayV3Metadata or GroupMetadata defined by the + metadata document stored at ``path``/zarr.json. If no such document is found, raise a + FileNotFoundError. + """ + zarr_json_bytes = await store.get( + _join_paths([path, ZARR_JSON]), prototype=default_buffer_prototype() + ) + if zarr_json_bytes is None: + raise FileNotFoundError(path) + else: + zarr_json = json.loads(zarr_json_bytes.to_bytes()) + return _build_metadata_v3(zarr_json) + + +async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupMetadata: + """ + Given a store and a path, return the ArrayV2Metadata or GroupMetadata defined by the + metadata document stored at ``path``/.zarray or ``path``/.zgroup. If no such document is + found, raise a FileNotFoundError. + """ + # TODO: consider first fetching array metadata, and only fetching group metadata when we don't + # find an array + zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather( + store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()), + store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()), + store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()), + ) + + if zattrs_bytes is None: + zattrs = {} + else: + zattrs = json.loads(zattrs_bytes.to_bytes()) + + # TODO: decide how to handle finding both array and group metadata. The spec does not seem to + # consider this situation. A practical approach would be to ignore that combination, and only + # return the array metadata. + if zarray_bytes is not None: + zmeta = json.loads(zarray_bytes.to_bytes()) else: - return GroupMetadata.from_dict(zarr_metadata | {"attrs": attrs}) + if zgroup_bytes is None: + # neither .zarray nor .zgroup was found, so the node does not exist + raise FileNotFoundError(path) + else: + zmeta = json.loads(zgroup_bytes.to_bytes()) + + return _build_metadata_v2(zmeta, zattrs) + + +async def _read_group_metadata_v2(store: Store, path: str) -> GroupMetadata: + """ + Read group metadata, or raise FileNotFoundError if no group metadata exists at ``path``. + """ + meta = await _read_metadata_v2(store=store, path=path) + if not isinstance(meta, GroupMetadata): + raise FileNotFoundError(f"Group metadata was not found in {store} at {path}") + return meta + + +async def _read_group_metadata_v3(store: Store, path: str) -> GroupMetadata: + """ + Read group metadata, or raise FileNotFoundError if no group metadata exists at ``path``. + """ + meta = await _read_metadata_v3(store=store, path=path) + if not isinstance(meta, GroupMetadata): + raise FileNotFoundError(f"Group metadata was not found in {store} at {path}") + return meta -def _build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata: +async def _read_group_metadata( + store: Store, path: str, *, zarr_format: ZarrFormat +) -> GroupMetadata: + if zarr_format == 2: + return await _read_group_metadata_v2(store=store, path=path) + return await _read_group_metadata_v3(store=store, path=path) + + +def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMetadata: """ - Take a dict and convert it into the correct metadata type.
+ Convert a dict representation of Zarr V3 metadata into the corresponding metadata class. """ if "node_type" not in zarr_json: - raise KeyError("missing `node_type` key in metadata document.") + raise MetadataValidationError("node_type", "array or group", "nothing (the key is missing)") match zarr_json: case {"node_type": "array"}: return ArrayV3Metadata.from_dict(zarr_json) case {"node_type": "group"}: return GroupMetadata.from_dict(zarr_json) - case _: - raise ValueError("invalid value for `node_type` key in metadata document") + case _: # pragma: no cover + raise ValueError( + "invalid value for `node_type` key in metadata document" + ) # pragma: no cover def _build_metadata_v2( - zarr_json: dict[str, Any], attrs_json: dict[str, Any] + zarr_json: dict[str, object], attrs_json: dict[str, JSON] ) -> ArrayV2Metadata | GroupMetadata: """ - Take a dict and convert it into the correct metadata type. + Convert a dict representation of Zarr V2 metadata into the corresponding metadata class. """ match zarr_json: case {"shape": _}: return ArrayV2Metadata.from_dict(zarr_json | {"attributes": attrs_json}) - case _: + case _: # pragma: no cover return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json}) -def _build_node_v3( - metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath -) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: +@overload +def _build_node( + *, store: Store, path: str, metadata: ArrayV2Metadata +) -> AsyncArray[ArrayV2Metadata]: ... + + +@overload +def _build_node( + *, store: Store, path: str, metadata: ArrayV3Metadata +) -> AsyncArray[ArrayV3Metadata]: ... + + +@overload +def _build_node(*, store: Store, path: str, metadata: GroupMetadata) -> AsyncGroup: ... + + +def _build_node( + *, store: Store, path: str, metadata: ArrayV3Metadata | ArrayV2Metadata | GroupMetadata +) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: """ Take a metadata object and return a node (AsyncArray or AsyncGroup). """ + store_path = StorePath(store=store, path=path) match metadata: - case ArrayV3Metadata(): + case ArrayV2Metadata() | ArrayV3Metadata(): return AsyncArray(metadata, store_path=store_path) case GroupMetadata(): return AsyncGroup(metadata, store_path=store_path) - case _: - raise ValueError(f"Unexpected metadata type: {type(metadata)}") + case _: # pragma: no cover + raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover -def _build_node_v2( - metadata: ArrayV2Metadata | GroupMetadata, store_path: StorePath -) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: +async def _get_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: """ - Take a metadata object and return a node (AsyncArray or AsyncGroup). + Read a Zarr v2 AsyncArray or AsyncGroup from a path in a Store. + + Parameters + ---------- + store : Store + The store-like object to read from. + path : str + The path to the node to read. + + Returns + ------- + AsyncArray | AsyncGroup """ + metadata = await _read_metadata_v2(store=store, path=path) + return _build_node(store=store, path=path, metadata=metadata) - match metadata: - case ArrayV2Metadata(): - return AsyncArray(metadata, store_path=store_path) - case GroupMetadata(): - return AsyncGroup(metadata, store_path=store_path) - case _: - raise ValueError(f"Unexpected metadata type: {type(metadata)}") + +async def _get_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: + """ + Read a Zarr v3 AsyncArray or AsyncGroup from a path in a Store. 
+ + Parameters + ---------- + store : Store + The store-like object to read from. + path : str + The path to the node to read. + + Returns + ------- + AsyncArray | AsyncGroup + """ + metadata = await _read_metadata_v3(store=store, path=path) + return _build_node(store=store, path=path, metadata=metadata) + + +async def get_node( + store: Store, path: str, zarr_format: ZarrFormat +) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: + """ + Get an AsyncArray or AsyncGroup from a path in a Store. + + Parameters + ---------- + store : Store + The store-like object to read from. + path : str + The path to the node to read. + zarr_format : {2, 3} + The zarr format of the node to read. + + Returns + ------- + AsyncArray | AsyncGroup + """ + + match zarr_format: + case 2: + return await _get_node_v2(store=store, path=path) + case 3: + return await _get_node_v3(store=store, path=path) + case _: # pragma: no cover + raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover + + +async def _set_return_key( + *, store: Store, key: str, value: Buffer, semaphore: asyncio.Semaphore | None = None +) -> str: + """ + Write a value to storage at the given key. The key is returned. + Useful when saving values via routines that return results in execution order, + like asyncio.as_completed, because in this case we need to know which key was saved in order + to yield the right object to the caller. + + Parameters + ---------- + store : Store + The store to save the value to. + key : str + The key to save the value to. + value : Buffer + The value to save. + semaphore : asyncio.Semaphore | None + An optional semaphore to use to limit the number of concurrent writes. + """ + + if semaphore is not None: + async with semaphore: + await store.set(key, value) + else: + await store.set(key, value) + return key + + +def _persist_metadata( + store: Store, + path: str, + metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata, + semaphore: asyncio.Semaphore | None = None, +) -> tuple[Coroutine[None, None, str], ...]: + """ + Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. + """ + + to_save = metadata.to_buffer_dict(default_buffer_prototype()) + return tuple( + _set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore) + for key, value in to_save.items() + ) + + +async def create_rooted_hierarchy( + *, + store: Store, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: + """ + Create an ``AsyncGroup`` or ``AsyncArray`` from a store and a dict of metadata documents. + This function ensures that its input contains a specification of a root node, + calls ``create_hierarchy`` to create nodes, and returns the root node of the hierarchy. + """ + roots = _get_roots(nodes.keys()) + if len(roots) != 1: + msg = ( + "The input does not specify a root node. " + "This function can only create hierarchies that contain a root node, which is " + "defined as a group that is ancestral to all the other arrays and " + "groups in the hierarchy, or a single array." 
+ ) + raise ValueError(msg) + else: + root_key = roots[0] + + nodes_created = [ + x async for x in create_hierarchy(store=store, nodes=nodes, overwrite=overwrite) + ] + return dict(nodes_created)[root_key] diff --git a/src/zarr/core/sync.py b/src/zarr/core/sync.py index 2bb5f24802..d9b4839e8e 100644 --- a/src/zarr/core/sync.py +++ b/src/zarr/core/sync.py @@ -6,14 +6,14 @@ import os import threading from concurrent.futures import ThreadPoolExecutor, wait -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from typing_extensions import ParamSpec from zarr.core.config import config if TYPE_CHECKING: - from collections.abc import AsyncIterator, Coroutine + from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine from typing import Any logger = logging.getLogger(__name__) @@ -215,3 +215,17 @@ async def iter_to_list() -> list[T]: return [item async for item in async_iterator] return self._sync(iter_to_list()) + + +async def _with_semaphore( + func: Callable[[], Awaitable[T]], semaphore: asyncio.Semaphore | None = None +) -> T: + """ + Await the result of invoking the no-argument callable ``func`` within the context manager + provided by a Semaphore, if one is provided. Otherwise, just await the result of invoking + ``func``. + """ + if semaphore is None: + return await func() + async with semaphore: + return await func() diff --git a/src/zarr/core/sync_group.py b/src/zarr/core/sync_group.py new file mode 100644 index 0000000000..39d8a17992 --- /dev/null +++ b/src/zarr/core/sync_group.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from zarr.core.group import Group, GroupMetadata, _parse_async_node +from zarr.core.group import create_hierarchy as create_hierarchy_async +from zarr.core.group import create_nodes as create_nodes_async +from zarr.core.group import create_rooted_hierarchy as create_rooted_hierarchy_async +from zarr.core.group import get_node as get_node_async +from zarr.core.sync import _collect_aiterator, sync + +if TYPE_CHECKING: + from collections.abc import Iterator + + from zarr.abc.store import Store + from zarr.core.array import Array + from zarr.core.common import ZarrFormat + from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata + + +def create_nodes( + *, store: Store, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] +) -> Iterator[tuple[str, Group | Array]]: + """Create a collection of arrays and / or groups concurrently. + + Note: no attempt is made to validate that these arrays and / or groups collectively form a + valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that + the ``nodes`` parameter satisfies any correctness constraints. + + Parameters + ---------- + store : Store + The storage backend to use. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + + Yields + ------ + Group | Array + The created nodes in the order they are created.
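+ + Examples + -------- + A minimal sketch; the node path and attributes here are illustrative: + + >>> from zarr.core.sync_group import create_nodes + >>> from zarr.core.group import GroupMetadata + >>> from zarr.storage import MemoryStore + >>> spec = {'a': GroupMetadata(attributes={'name': 'leaf'})} + >>> created = dict(create_nodes(store=MemoryStore(), nodes=spec)) + >>> sorted(created) + ['a']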
+ """ + coro = create_nodes_async(store=store, nodes=nodes) + + for key, value in sync(_collect_aiterator(coro)): + yield key, _parse_async_node(value) + + +def create_hierarchy( + *, + store: Store, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> Iterator[tuple[str, Group | Array]]: + """ + Create a complete zarr hierarchy from a collection of metadata objects. + + This function will parse its input to ensure that the hierarchy is complete. Any implicit groups + will be inserted as needed. For example, an input like + ```{'a/b': GroupMetadata}``` will be parsed to + ```{'': GroupMetadata, 'a': GroupMetadata, 'b': Groupmetadata}``` + + After input parsing, this function then creates all the nodes in the hierarchy concurrently. + + Arrays and Groups are yielded in the order they are created. This order is not stable and + should not be relied on. + + Parameters + ---------- + store : Store + The storage backend to use. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, + relative to the root of the ``Store``. The root of the store can be specified with the empty + string ``''``. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that + all values must have the same ``zarr_format`` -- it is an error to mix zarr versions in the + same hierarchy. + + Leading "/" characters from keys will be removed. + overwrite : bool + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is + raised instead of overwriting an existing array or group. + + This function will not erase an existing group unless that group is explicitly named in + ``nodes``. If ``nodes`` defines implicit groups, e.g. ``{`'a/b/c': GroupMetadata}``, and a + group already exists at path ``a``, then this function will leave the group at ``a`` as-is. + + Yields + ------ + tuple[str, Group | Array] + This function yields (path, node) pairs, in the order the nodes were created. + + Examples + -------- + >>> from zarr import create_hierarchy + >>> from zarr.storage import MemoryStore + >>> from zarr.core.group import GroupMetadata + + >>> store = MemoryStore() + >>> nodes = {'a': GroupMetadata(attributes={'name': 'leaf'})} + >>> nodes_created = dict(create_hierarchy(store=store, nodes=nodes)) + >>> print(nodes) + # {'a': GroupMetadata(attributes={'name': 'leaf'}, zarr_format=3, consolidated_metadata=None, node_type='group')} + """ + coro = create_hierarchy_async(store=store, nodes=nodes, overwrite=overwrite) + + for key, value in sync(_collect_aiterator(coro)): + yield key, _parse_async_node(value) + + +def create_rooted_hierarchy( + *, + store: Store, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> Group | Array: + """ + Create a Zarr hierarchy with a root, and return the root node, which could be a ``Group`` + or ``Array`` instance. + + Parameters + ---------- + store : Store + The storage backend to use. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + overwrite : bool + Whether to overwrite existing nodes. Default is ``False``. 
+ + Returns + ------- + Group | Array + """ + async_node = sync(create_rooted_hierarchy_async(store=store, nodes=nodes, overwrite=overwrite)) + return _parse_async_node(async_node) + + +def get_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group: + """ + Get an Array or Group from a path in a Store. + + Parameters + ---------- + store : Store + The store to read from. + path : str + The path to the node to read. + zarr_format : {2, 3} + The zarr format of the node to read. + + Returns + ------- + Array | Group + """ + + return _parse_async_node(sync(get_node_async(store=store, path=path, zarr_format=zarr_format))) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 4fc3171eb8..eda4342f47 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -2,11 +2,13 @@ import re from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from zarr.abc.store import ByteRequest from zarr.core.buffer import Buffer @@ -66,3 +68,45 @@ def _normalize_byte_range_index(data: Buffer, byte_range: ByteRequest | None) -> else: raise ValueError(f"Unexpected byte_range, got {byte_range}.") return (start, stop) + + +def _join_paths(paths: Iterable[str]) -> str: + """ + Filter out instances of '' and join the remaining strings with '/'. + + Because the root node of a Zarr hierarchy is represented by the empty string, joining paths + naively would produce keys with leading or doubled '/' characters; filtering out empty + components avoids this. + """ + return "/".join(filter(lambda v: v != "", paths)) + + +def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: + """ + Normalize the input paths according to the normalization scheme used for zarr node paths. + If any two paths normalize to the same value, raise a ValueError. + """ + path_map: dict[str, str] = {} + for path in paths: + parsed = normalize_path(path) + if parsed in path_map: + msg = ( + f"After normalization, the value '{path}' collides with '{path_map[parsed]}'. " + f"Both '{path}' and '{path_map[parsed]}' normalize to the same value: '{parsed}'. " + f"You should use either '{path}' or '{path_map[parsed]}', but not both." + ) + raise ValueError(msg) + path_map[parsed] = path + return tuple(path_map.keys()) + + +T = TypeVar("T") + + +def _normalize_path_keys(data: Mapping[str, T]) -> dict[str, T]: + """ + Normalize the keys of the input dict according to the normalization scheme used for zarr node + paths. If any two keys in the input normalize to the same value, raise a ValueError. + Returns a dict where the keys are the normalized versions of the input keys, and the values + are unchanged.
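+ + For example (illustrative): ``_normalize_path_keys({'/a': 1, 'b/c/': 2})`` returns + ``{'a': 1, 'b/c': 2}``, while ``_normalize_path_keys({'a': 1, '/a': 2})`` raises a + ``ValueError`` because both keys normalize to ``'a'``.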
+ """ + parsed_keys = _normalize_paths(data.keys()) + return dict(zip(parsed_keys, data.values(), strict=True)) diff --git a/tests/conftest.py b/tests/conftest.py index 9be675cb20..04034cb5b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,16 +11,30 @@ from zarr import AsyncGroup, config from zarr.abc.store import Store +from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation +from zarr.core.array import ( + _parse_chunk_encoding_v2, + _parse_chunk_encoding_v3, + _parse_chunk_key_encoding, +) +from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition +from zarr.core.common import JSON, parse_dtype, parse_shapelike +from zarr.core.config import config as zarr_config +from zarr.core.metadata.v2 import ArrayV2Metadata +from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import sync from zarr.storage import FsspecStore, LocalStore, MemoryStore, StorePath, ZipStore if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Iterable from typing import Any, Literal from _pytest.compat import LEGACY_PATH - from zarr.core.common import ChunkCoords, MemoryOrder, ZarrFormat + from zarr.abc.codec import Codec + from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike + from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike + from zarr.core.common import ChunkCoords, MemoryOrder, ShapeLike, ZarrFormat async def parse_store( @@ -177,3 +191,212 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None: suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow], verbosity=Verbosity.verbose, ) + +# TODO: uncomment these overrides when we can get mypy to accept them +""" +@overload +def create_array_metadata( + *, + shape: ShapeLike, + dtype: npt.DTypeLike, + chunks: ChunkCoords | Literal["auto"], + shards: None, + filters: FiltersLike, + compressors: CompressorsLike, + serializer: SerializerLike, + fill_value: Any | None, + order: MemoryOrder | None, + zarr_format: Literal[2], + attributes: dict[str, JSON] | None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None, + dimension_names: None, +) -> ArrayV2Metadata: ... + + +@overload +def create_array_metadata( + *, + shape: ShapeLike, + dtype: npt.DTypeLike, + chunks: ChunkCoords | Literal["auto"], + shards: ShardsLike | None, + filters: FiltersLike, + compressors: CompressorsLike, + serializer: SerializerLike, + fill_value: Any | None, + order: None, + zarr_format: Literal[3], + attributes: dict[str, JSON] | None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None, + dimension_names: Iterable[str] | None, +) -> ArrayV3Metadata: ... 
+""" + + +def create_array_metadata( + *, + shape: ShapeLike, + dtype: npt.DTypeLike, + chunks: ChunkCoords | Literal["auto"] = "auto", + shards: ShardsLike | None = None, + filters: FiltersLike = "auto", + compressors: CompressorsLike = "auto", + serializer: SerializerLike = "auto", + fill_value: Any | None = None, + order: MemoryOrder | None = None, + zarr_format: ZarrFormat, + attributes: dict[str, JSON] | None = None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None = None, + dimension_names: Iterable[str] | None = None, +) -> ArrayV2Metadata | ArrayV3Metadata: + """ + Create array metadata + """ + dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format) + shape_parsed = parse_shapelike(shape) + chunk_key_encoding_parsed = _parse_chunk_key_encoding( + chunk_key_encoding, zarr_format=zarr_format + ) + + shard_shape_parsed, chunk_shape_parsed = _auto_partition( + array_shape=shape_parsed, shard_shape=shards, chunk_shape=chunks, dtype=dtype_parsed + ) + + if order is None: + order_parsed = zarr_config.get("array.order") + else: + order_parsed = order + chunks_out: tuple[int, ...] + + if zarr_format == 2: + filters_parsed, compressor_parsed = _parse_chunk_encoding_v2( + compressor=compressors, filters=filters, dtype=np.dtype(dtype) + ) + return ArrayV2Metadata( + shape=shape_parsed, + dtype=np.dtype(dtype), + chunks=chunk_shape_parsed, + order=order_parsed, + dimension_separator=chunk_key_encoding_parsed.separator, + fill_value=fill_value, + compressor=compressor_parsed, + filters=filters_parsed, + attributes=attributes, + ) + elif zarr_format == 3: + array_array, array_bytes, bytes_bytes = _parse_chunk_encoding_v3( + compressors=compressors, + filters=filters, + serializer=serializer, + dtype=dtype_parsed, + ) + + sub_codecs: tuple[Codec, ...] = (*array_array, array_bytes, *bytes_bytes) + codecs_out: tuple[Codec, ...] + if shard_shape_parsed is not None: + index_location = None + if isinstance(shards, dict): + index_location = ShardingCodecIndexLocation(shards.get("index_location", None)) + if index_location is None: + index_location = ShardingCodecIndexLocation.end + sharding_codec = ShardingCodec( + chunk_shape=chunk_shape_parsed, + codecs=sub_codecs, + index_location=index_location, + ) + sharding_codec.validate( + shape=chunk_shape_parsed, + dtype=dtype_parsed, + chunk_grid=RegularChunkGrid(chunk_shape=shard_shape_parsed), + ) + codecs_out = (sharding_codec,) + chunks_out = shard_shape_parsed + else: + chunks_out = chunk_shape_parsed + codecs_out = sub_codecs + + return ArrayV3Metadata( + shape=shape_parsed, + data_type=dtype_parsed, + chunk_grid=RegularChunkGrid(chunk_shape=chunks_out), + chunk_key_encoding=chunk_key_encoding_parsed, + fill_value=fill_value, + codecs=codecs_out, + attributes=attributes, + dimension_names=dimension_names, + ) + + raise ValueError(f"Invalid Zarr format: {zarr_format}") + + +# TODO: uncomment these overrides when we can get mypy to accept them +""" +@overload +def meta_from_array( + array: np.ndarray[Any, Any], + chunks: ChunkCoords | Literal["auto"], + shards: None, + filters: FiltersLike, + compressors: CompressorsLike, + serializer: SerializerLike, + fill_value: Any | None, + order: MemoryOrder | None, + zarr_format: Literal[2], + attributes: dict[str, JSON] | None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None, + dimension_names: Iterable[str] | None, +) -> ArrayV2Metadata: ... 
+ + +@overload +def meta_from_array( + array: np.ndarray[Any, Any], + chunks: ChunkCoords | Literal["auto"], + shards: ShardsLike | None, + filters: FiltersLike, + compressors: CompressorsLike, + serializer: SerializerLike, + fill_value: Any | None, + order: None, + zarr_format: Literal[3], + attributes: dict[str, JSON] | None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None, + dimension_names: Iterable[str] | None, +) -> ArrayV3Metadata: ... + +""" + + +def meta_from_array( + array: np.ndarray[Any, Any], + *, + chunks: ChunkCoords | Literal["auto"] = "auto", + shards: ShardsLike | None = None, + filters: FiltersLike = "auto", + compressors: CompressorsLike = "auto", + serializer: SerializerLike = "auto", + fill_value: Any | None = None, + order: MemoryOrder | None = None, + zarr_format: ZarrFormat = 3, + attributes: dict[str, JSON] | None = None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None = None, + dimension_names: Iterable[str] | None = None, +) -> ArrayV3Metadata | ArrayV2Metadata: + """ + Create array metadata from an array + """ + return create_array_metadata( + shape=array.shape, + dtype=array.dtype, + chunks=chunks, + shards=shards, + filters=filters, + compressors=compressors, + serializer=serializer, + fill_value=fill_value, + order=order, + zarr_format=zarr_format, + attributes=attributes, + chunk_key_encoding=chunk_key_encoding, + dimension_names=dimension_names, + ) diff --git a/tests/test_api.py b/tests/test_api.py index e9db33f6c5..3b565f8e60 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,4 +1,13 @@ -import pathlib +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pathlib + + from zarr.abc.store import Store + from zarr.core.common import JSON, MemoryOrder, ZarrFormat + import warnings from typing import Literal @@ -8,9 +17,9 @@ import zarr import zarr.api.asynchronous +import zarr.api.synchronous import zarr.core.group from zarr import Array, Group -from zarr.abc.store import Store from zarr.api.synchronous import ( create, create_array, @@ -23,7 +32,6 @@ save_array, save_group, ) -from zarr.core.common import JSON, MemoryOrder, ZarrFormat from zarr.errors import MetadataValidationError from zarr.storage import MemoryStore from zarr.storage._utils import normalize_path @@ -1124,6 +1132,13 @@ def test_open_array_with_mode_r_plus(store: Store) -> None: z2[:] = 3 +def test_api_exports() -> None: + """ + Test that the sync API and the async API export the same objects + """ + assert zarr.api.asynchronous.__all__ == zarr.api.synchronous.__all__ + + @gpu_test @pytest.mark.parametrize( "store", diff --git a/tests/test_group.py b/tests/test_group.py index 144054605e..521819ea0e 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,8 +1,10 @@ from __future__ import annotations import contextlib +import inspect import operator import pickle +import re import time import warnings from typing import TYPE_CHECKING, Any, Literal @@ -17,18 +19,35 @@ import zarr.storage from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store +from zarr.core import sync_group from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype -from zarr.core.group import ConsolidatedMetadata, GroupMetadata -from zarr.core.sync import sync -from zarr.errors import ContainsArrayError, ContainsGroupError +from zarr.core.config import config as zarr_config +from zarr.core.group import ( + ConsolidatedMetadata, + GroupMetadata, + 
@@ -1443,7 +1462,501 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None
 
 
 @pytest.mark.parametrize("store", ["memory"], indirect=True)
-def test_group_members_performance(store: MemoryStore) -> None:
+@pytest.mark.parametrize("impl", ["async", "sync"])
+async def test_create_nodes(
+    impl: Literal["async", "sync"], store: Store, zarr_format: ZarrFormat
+) -> None:
+    """
+    Ensure that ``create_nodes`` can create a Zarr hierarchy from a model of that
+    hierarchy in dict form. Note that this creates an incomplete Zarr hierarchy.
+    """
+    node_spec = {
+        "group": GroupMetadata(attributes={"foo": 10}),
+        "group/array_0": meta_from_array(np.arange(3), zarr_format=zarr_format),
+        "group/array_1": meta_from_array(np.arange(4), zarr_format=zarr_format),
+        "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format),
+        "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format),
+    }
+    if impl == "sync":
+        observed_nodes = dict(sync_group.create_nodes(store=store, nodes=node_spec))
+    elif impl == "async":
+        observed_nodes = dict(await _collect_aiterator(create_nodes(store=store, nodes=node_spec)))
+    else:
+        raise ValueError(f"Invalid impl: {impl}")
+
+    assert node_spec == {k: v.metadata for k, v in observed_nodes.items()}
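Unlike ``create_hierarchy``, ``create_nodes`` writes exactly the metadata documents it is given, without filling in missing parent groups, so it can leave the store in an incomplete state. A minimal sketch using the synchronous wrapper (``zarr.core.sync_group`` is a private module; the public entry point is ``zarr.create_hierarchy``):

    >>> from zarr.core.group import GroupMetadata
    >>> from zarr.core.sync_group import create_nodes
    >>> from zarr.storage import MemoryStore
    >>> store = MemoryStore()
    >>> created = dict(create_nodes(store=store, nodes={"a/b": GroupMetadata()}))
    >>> list(created)  # no implicit parent groups are created
    ['a/b']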
+ """ + set_latency = 0.02 + num_groups = 10 + groups = {str(idx): GroupMetadata() for idx in range(num_groups)} + + latency_store = LatencyStore(store, set_latency=set_latency) + + # check how long it takes to iterate over the groups + # if create_nodes is sensitive to IO latency, + # this should take (num_groups * get_latency) seconds + # otherwise, it should take only marginally more than get_latency seconds + + with zarr_config.set({"async.concurrency": 1}): + start = time.time() + _ = tuple(sync_group.create_nodes(store=latency_store, nodes=groups)) + elapsed = time.time() - start + assert elapsed > num_groups * set_latency + + +@pytest.mark.parametrize( + ("a_func", "b_func"), + [ + (zarr.core.group.AsyncGroup.create_hierarchy, zarr.core.group.Group.create_hierarchy), + (zarr.core.group.create_hierarchy, zarr.core.sync_group.create_hierarchy), + (zarr.core.group.create_nodes, zarr.core.sync_group.create_nodes), + (zarr.core.group.create_rooted_hierarchy, zarr.core.sync_group.create_rooted_hierarchy), + (zarr.core.group.get_node, zarr.core.sync_group.get_node), + ], +) +def test_consistent_signatures( + a_func: Callable[[object], object], b_func: Callable[[object], object] +) -> None: + """ + Ensure that pairs of functions have consistent signatures + """ + base_sig = inspect.signature(a_func) + test_sig = inspect.signature(b_func) + assert test_sig.parameters == base_sig.parameters + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("overwrite", [True, False]) +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_hierarchy( + impl: Literal["async", "sync"], store: Store, overwrite: bool, zarr_format: ZarrFormat +) -> None: + """ + Test that ``create_hierarchy`` can create a complete Zarr hierarchy, even if the input describes + an incomplete one. 
+ """ + + hierarchy_spec = { + "group": GroupMetadata(attributes={"path": "group"}, zarr_format=zarr_format), + "group/array_0": meta_from_array( + np.arange(3), attributes={"path": "group/array_0"}, zarr_format=zarr_format + ), + "group/subgroup/array_0": meta_from_array( + np.arange(4), attributes={"path": "group/subgroup/array_0"}, zarr_format=zarr_format + ), + } + pre_existing_nodes = { + "group/extra": GroupMetadata(zarr_format=zarr_format, attributes={"path": "group/extra"}), + "": GroupMetadata(zarr_format=zarr_format, attributes={"name": "root"}), + } + # we expect create_hierarchy to insert a group that was missing from the hierarchy spec + expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} + + # initialize the group with some nodes + _ = dict(sync_group.create_nodes(store=store, nodes=pre_existing_nodes)) + + if impl == "sync": + created = dict( + sync_group.create_hierarchy(store=store, nodes=hierarchy_spec, overwrite=overwrite) + ) + elif impl == "async": + created = dict( + [ + a + async for a in create_hierarchy( + store=store, nodes=hierarchy_spec, overwrite=overwrite + ) + ] + ) + else: + raise ValueError(f"Invalid impl: {impl}") + if not overwrite: + extra_group = sync_group.get_node(store=store, path="group/extra", zarr_format=zarr_format) + assert extra_group.metadata.attributes == {"path": "group/extra"} + else: + with pytest.raises(FileNotFoundError): + await get_node(store=store, path="group/extra", zarr_format=zarr_format) + assert expected_meta == {k: v.metadata for k, v in created.items()} + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("extant_node", ["array", "group"]) +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_hierarchy_existing_nodes( + impl: Literal["async", "sync"], + store: Store, + extant_node: Literal["array", "group"], + zarr_format: ZarrFormat, +) -> None: + """ + Test that create_hierarchy with overwrite = False will not overwrite an existing array or group, + and raises an exception instead. + """ + extant_node_path = "node" + + if extant_node == "array": + extant_metadata = meta_from_array( + np.zeros(4), zarr_format=zarr_format, attributes={"extant": True} + ) + new_metadata = meta_from_array(np.zeros(4), zarr_format=zarr_format) + err_cls = ContainsArrayError + else: + extant_metadata = GroupMetadata(zarr_format=zarr_format, attributes={"extant": True}) + new_metadata = GroupMetadata(zarr_format=zarr_format) + err_cls = ContainsGroupError + + # write the extant metadata + tuple(sync_group.create_nodes(store=store, nodes={extant_node_path: extant_metadata})) + + msg = f"{extant_node} exists in store {store!r} at path {extant_node_path!r}." 
+
+
+@pytest.mark.parametrize("store", ["memory"], indirect=True)
+@pytest.mark.parametrize("extant_node", ["array", "group"])
+@pytest.mark.parametrize("impl", ["async", "sync"])
+async def test_create_hierarchy_existing_nodes(
+    impl: Literal["async", "sync"],
+    store: Store,
+    extant_node: Literal["array", "group"],
+    zarr_format: ZarrFormat,
+) -> None:
+    """
+    Test that create_hierarchy with overwrite = False will not overwrite an existing array or group,
+    and raises an exception instead.
+    """
+    extant_node_path = "node"
+
+    if extant_node == "array":
+        extant_metadata = meta_from_array(
+            np.zeros(4), zarr_format=zarr_format, attributes={"extant": True}
+        )
+        new_metadata = meta_from_array(np.zeros(4), zarr_format=zarr_format)
+        err_cls = ContainsArrayError
+    else:
+        extant_metadata = GroupMetadata(zarr_format=zarr_format, attributes={"extant": True})
+        new_metadata = GroupMetadata(zarr_format=zarr_format)
+        err_cls = ContainsGroupError
+
+    # write the extant metadata
+    tuple(sync_group.create_nodes(store=store, nodes={extant_node_path: extant_metadata}))
+
+    msg = f"{extant_node} exists in store {store!r} at path {extant_node_path!r}."
+    # ensure that we cannot invoke create_hierarchy with overwrite=False here
+    if impl == "sync":
+        with pytest.raises(err_cls, match=re.escape(msg)):
+            tuple(
+                sync_group.create_hierarchy(
+                    store=store, nodes={"node": new_metadata}, overwrite=False
+                )
+            )
+    elif impl == "async":
+        with pytest.raises(err_cls, match=re.escape(msg)):
+            tuple(
+                [
+                    x
+                    async for x in create_hierarchy(
+                        store=store, nodes={"node": new_metadata}, overwrite=False
+                    )
+                ]
+            )
+    else:
+        raise ValueError(f"Invalid impl: {impl}")
+
+    # ensure that the extant metadata was not overwritten
+    assert (
+        await get_node(store=store, path=extant_node_path, zarr_format=zarr_format)
+    ).metadata.attributes == {"extant": True}
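As the test above exercises, re-creating an existing node with ``overwrite=False`` raises ``ContainsArrayError`` or ``ContainsGroupError`` rather than silently clobbering metadata. An illustrative sketch (the traceback details are elided):

    >>> from zarr.core.group import GroupMetadata
    >>> from zarr.core.sync_group import create_hierarchy
    >>> from zarr.storage import MemoryStore
    >>> store = MemoryStore()
    >>> _ = dict(create_hierarchy(store=store, nodes={"node": GroupMetadata()}))
    >>> dict(create_hierarchy(store=store, nodes={"node": GroupMetadata()}, overwrite=False))
    Traceback (most recent call last):
      ...
    zarr.errors.ContainsGroupError: ...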
" + ): + _ = dict(g.create_hierarchy(tree, overwrite=overwrite)) + + +class TestParseHierarchyDict: + """ + Tests for the function that parses dicts of str : Metadata pairs, ensuring that the output models a + valid Zarr hierarchy + """ + + @staticmethod + def test_normed_keys() -> None: + """ + Test that keys get normalized properly + """ + + nodes = { + "a": GroupMetadata(), + "/b": GroupMetadata(), + "": GroupMetadata(), + "/a//c////": GroupMetadata(), + } + observed = _parse_hierarchy_dict(data=nodes) + expected = {normalize_path(k): v for k, v in nodes.items()} + assert observed == expected + + @staticmethod + def test_empty() -> None: + """ + Test that an empty dict passes through + """ + assert _parse_hierarchy_dict(data={}) == {} + + @staticmethod + def test_implicit_groups() -> None: + """ + Test that implicit groups were added as needed. + """ + requested = {"a/b/c": GroupMetadata()} + expected = requested | { + "": ImplicitGroupMarker(), + "a": ImplicitGroupMarker(), + "a/b": ImplicitGroupMarker(), + } + observed = _parse_hierarchy_dict(data=requested) + assert observed == expected + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_group_create_hierarchy_invalid_mixed_zarr_format( + store: Store, zarr_format: ZarrFormat +) -> None: + """ + Test that ``Group.create_hierarchy`` will raise an error if the zarr_format of the nodes is + different from the parent group. + """ + other_format = 2 if zarr_format == 3 else 3 + g = Group.from_store(store, zarr_format=other_format) + tree = { + "a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), + "a/b": meta_from_array(np.zeros(5), zarr_format=zarr_format, attributes={"name": "a/c"}), + } + + msg = "The zarr_format of the nodes must be the same as the parent group." + with pytest.raises(ValueError, match=msg): + _ = tuple(g.create_hierarchy(tree)) + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("defect", ["array/array", "array/group"]) +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_hierarchy_invalid_nested( + impl: Literal["async", "sync"], store: Store, defect: tuple[str, str], zarr_format: ZarrFormat +) -> None: + """ + Test that create_hierarchy will not create a Zarr array that contains a Zarr group + or Zarr array. + """ + if defect == "array/array": + hierarchy_spec = { + "array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), + "array_0/subarray": meta_from_array(np.arange(4), zarr_format=zarr_format), + } + elif defect == "array/group": + hierarchy_spec = { + "array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), + "array_0/subgroup": GroupMetadata(attributes={"foo": 10}, zarr_format=zarr_format), + } + + msg = "Only Zarr groups can contain other nodes." + if impl == "sync": + with pytest.raises(ValueError, match=msg): + tuple(sync_group.create_hierarchy(store=store, nodes=hierarchy_spec)) + elif impl == "async": + with pytest.raises(ValueError, match=msg): + await _collect_aiterator(create_hierarchy(store=store, nodes=hierarchy_spec)) + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_hierarchy_invalid_mixed_format( + impl: Literal["async", "sync"], store: Store +) -> None: + """ + Test that create_hierarchy will not create a Zarr group that contains a both Zarr v2 and + Zarr v3 nodes. + """ + msg = ( + "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. 
" + "The following keys map to Zarr v2 nodes: ['v2']. " + "The following keys map to Zarr v3 nodes: ['v3']." + "Ensure that all nodes have the same Zarr format." + ) + nodes = { + "v2": GroupMetadata(zarr_format=2), + "v3": GroupMetadata(zarr_format=3), + } + if impl == "sync": + with pytest.raises(ValueError, match=re.escape(msg)): + tuple( + sync_group.create_hierarchy( + store=store, + nodes=nodes, + ) + ) + elif impl == "async": + with pytest.raises(ValueError, match=re.escape(msg)): + await _collect_aiterator( + create_hierarchy( + store=store, + nodes=nodes, + ) + ) + else: + raise ValueError(f"Invalid impl: {impl}") + + +@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) +@pytest.mark.parametrize("zarr_format", [2, 3]) +@pytest.mark.parametrize("root_key", ["", "root"]) +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_rooted_hierarchy_group( + impl: Literal["async", "sync"], store: Store, zarr_format, root_key: str +) -> None: + """ + Test that the _create_rooted_hierarchy can create a group. + """ + root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} + group_names = ["a", "a/b"] + array_names = ["a/b/c", "a/b/d"] + + # just to ensure that we don't use the same name twice in tests + assert set(group_names) & set(array_names) == set() + + groups_expected_meta = { + _join_paths([root_key, node_name]): GroupMetadata( + zarr_format=zarr_format, attributes={"path": node_name} + ) + for node_name in group_names + } + + arrays_expected_meta = { + _join_paths([root_key, node_name]): meta_from_array(np.zeros(4), zarr_format=zarr_format) + for node_name in array_names + } + + nodes_create = root_meta | groups_expected_meta | arrays_expected_meta + if impl == "sync": + g = sync_group.create_rooted_hierarchy(store=store, nodes=nodes_create) + assert isinstance(g, Group) + members = g.members(max_depth=None) + elif impl == "async": + g = await create_rooted_hierarchy(store=store, nodes=nodes_create) + assert isinstance(g, AsyncGroup) + members = await _collect_aiterator(g.members(max_depth=None)) + else: + raise ValueError(f"Unknown implementation: {impl}") + + assert g.metadata.attributes == {"path": root_key} + + members_observed_meta = {k: v.metadata for k, v in members} + members_expected_meta_relative = { + k.removeprefix(root_key).lstrip("/"): v + for k, v in (groups_expected_meta | arrays_expected_meta).items() + } + assert members_observed_meta == members_expected_meta_relative + + +@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) +@pytest.mark.parametrize("zarr_format", [2, 3]) +@pytest.mark.parametrize("root_key", ["", "root"]) +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_rooted_hierarchy_array( + impl: Literal["async", "sync"], store: Store, zarr_format, root_key: str +) -> None: + """ + Test that _create_rooted_hierarchy can create an array. 
+ """ + + root_meta = { + root_key: meta_from_array( + np.arange(3), zarr_format=zarr_format, attributes={"path": root_key} + ) + } + nodes_create = root_meta + + if impl == "sync": + a = sync_group.create_rooted_hierarchy(store=store, nodes=nodes_create, overwrite=True) + assert isinstance(a, Array) + elif impl == "async": + a = await create_rooted_hierarchy(store=store, nodes=nodes_create, overwrite=True) + assert isinstance(a, AsyncArray) + else: + raise ValueError(f"Invalid impl: {impl}") + assert a.metadata.attributes == {"path": root_key} + + +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_rooted_hierarchy_invalid(impl: Literal["async", "sync"]) -> None: + """ + Ensure _create_rooted_hierarchy will raise a ValueError if the input does not contain + a root node. + """ + zarr_format = 3 + nodes = { + "a": GroupMetadata(zarr_format=zarr_format), + "b": GroupMetadata(zarr_format=zarr_format), + } + msg = "The input does not specify a root node. " + if impl == "sync": + with pytest.raises(ValueError, match=msg): + sync_group.create_rooted_hierarchy(store=store, nodes=nodes) + elif impl == "async": + with pytest.raises(ValueError, match=msg): + await create_rooted_hierarchy(store=store, nodes=nodes) + else: + raise ValueError(f"Invalid impl: {impl}") + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_group_members_performance(store: Store) -> None: """ Test that the execution time of Group.members is less than the number of members times the latency for accessing each member. @@ -1505,3 +2018,35 @@ def test_group_members_concurrency_limit(store: MemoryStore) -> None: elapsed = time.time() - start assert elapsed > num_groups * get_latency + + +@pytest.mark.parametrize("option", ["array", "group", "invalid"]) +def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None: + """ + Test that _build_metadata_v3 returns the correct metadata for a v3 array or group + """ + match option: + case "array": + metadata_dict = meta_from_array(np.arange(10), zarr_format=3).to_dict() + assert _build_metadata_v3(metadata_dict) == ArrayV3Metadata.from_dict(metadata_dict) + case "group": + metadata_dict = GroupMetadata(attributes={"foo": 10}, zarr_format=3).to_dict() + assert _build_metadata_v3(metadata_dict) == GroupMetadata.from_dict(metadata_dict) + case "invalid": + metadata_dict = GroupMetadata(zarr_format=3).to_dict() + metadata_dict.pop("node_type") + # TODO: fix the error message + msg = "Invalid value for 'node_type'. Expected 'array or group'. Got 'nothing (the key is missing)'." 
+
+
+@pytest.mark.parametrize("store", ["memory"], indirect=True)
+def test_group_members_performance(store: Store) -> None:
     """
     Test that the execution time of Group.members is less than the number of members times the
     latency for accessing each member.
@@ -1505,3 +2018,35 @@ def test_group_members_concurrency_limit(store: MemoryStore) -> None:
         elapsed = time.time() - start
 
         assert elapsed > num_groups * get_latency
+
+
+@pytest.mark.parametrize("option", ["array", "group", "invalid"])
+def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None:
+    """
+    Test that _build_metadata_v3 returns the correct metadata for a v3 array or group
+    """
+    match option:
+        case "array":
+            metadata_dict = meta_from_array(np.arange(10), zarr_format=3).to_dict()
+            assert _build_metadata_v3(metadata_dict) == ArrayV3Metadata.from_dict(metadata_dict)
+        case "group":
+            metadata_dict = GroupMetadata(attributes={"foo": 10}, zarr_format=3).to_dict()
+            assert _build_metadata_v3(metadata_dict) == GroupMetadata.from_dict(metadata_dict)
+        case "invalid":
+            metadata_dict = GroupMetadata(zarr_format=3).to_dict()
+            metadata_dict.pop("node_type")
+            # TODO: fix the error message
+            msg = "Invalid value for 'node_type'. Expected 'array or group'. Got 'nothing (the key is missing)'."
+            with pytest.raises(MetadataValidationError, match=re.escape(msg)):
+                _build_metadata_v3(metadata_dict)
+
+
+@pytest.mark.parametrize("roots", [("",), ("a", "b")])
+def test_get_roots(roots: tuple[str, ...]) -> None:
+    root_nodes = {k: GroupMetadata(attributes={"name": k}) for k in roots}
+    child_nodes = {
+        _join_paths([k, "foo"]): GroupMetadata(attributes={"name": _join_paths([k, "foo"])})
+        for k in roots
+    }
+    data = root_nodes | child_nodes
+    assert set(_get_roots(data)) == set(roots)
diff --git a/tests/test_store/test_core.py b/tests/test_store/test_core.py
index 726da06a52..bce582a746 100644
--- a/tests/test_store/test_core.py
+++ b/tests/test_store/test_core.py
@@ -8,7 +8,7 @@
 from zarr.core.common import AccessModeLiteral, ZarrFormat
 from zarr.storage import FsspecStore, LocalStore, MemoryStore, StoreLike, StorePath
 from zarr.storage._common import contains_array, contains_group, make_store_path
-from zarr.storage._utils import normalize_path
+from zarr.storage._utils import _join_paths, _normalize_path_keys, _normalize_paths, normalize_path
 
 
 @pytest.mark.parametrize("path", ["foo", "foo/bar"])
@@ -174,3 +174,48 @@ def test_normalize_path_invalid(path: str):
     with pytest.raises(ValueError):
         normalize_path(path)
+
+
+@pytest.mark.parametrize("paths", [("", "foo"), ("foo", "bar")])
+def test_join_paths(paths: tuple[str, str]) -> None:
+    """
+    Test that _join_paths joins paths in a way that is robust to an empty string
+    """
+    observed = _join_paths(paths)
+    if paths[0] == "":
+        assert observed == paths[1]
+    else:
+        assert observed == "/".join(paths)
+
+
+class TestNormalizePaths:
+    @staticmethod
+    def test_valid() -> None:
+        """
+        Test that path normalization works as expected
+        """
+        paths = ["a", "b", "c", "d", "", "//a///b//"]
+        assert _normalize_paths(paths) == tuple([normalize_path(p) for p in paths])
+
+    @staticmethod
+    @pytest.mark.parametrize("paths", [("", "/"), ("///a", "a")])
+    def test_invalid(paths: tuple[str, str]) -> None:
+        """
+        Test that name collisions after normalization raise a ``ValueError``
+        """
+        msg = (
+            f"After normalization, the value '{paths[1]}' collides with '{paths[0]}'. "
+            f"Both '{paths[1]}' and '{paths[0]}' normalize to the same value: '{normalize_path(paths[0])}'. "
+            f"You should use either '{paths[1]}' or '{paths[0]}', but not both."
+        )
+        with pytest.raises(ValueError, match=msg):
+            _normalize_paths(paths)
+
+
+def test_normalize_path_keys() -> None:
+    """
+    Test that ``_normalize_path_keys`` just applies the normalize_path function to each key of its
+    input
+    """
+    data = {"a": 10, "//b": 10}
+    assert _normalize_path_keys(data) == {normalize_path(k): v for k, v in data.items()}
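The two private helpers exercised above underpin the key handling in ``create_hierarchy``: ``_join_paths`` concatenates path segments without producing a spurious leading slash when the root prefix is empty, and ``_normalize_path_keys`` applies ``normalize_path`` to every key of a node dict. Their behavior, per the tests:

    >>> from zarr.storage._utils import _join_paths, _normalize_path_keys
    >>> _join_paths(["", "foo"])  # robust to the empty root prefix
    'foo'
    >>> _join_paths(["foo", "bar"])
    'foo/bar'
    >>> _normalize_path_keys({"//b": 10})
    {'b': 10}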