diff --git a/changes/2582.bugfix.rst b/changes/2582.bugfix.rst new file mode 100644 index 0000000000..1c8c7fab5c --- /dev/null +++ b/changes/2582.bugfix.rst @@ -0,0 +1,2 @@ +Prevents creation of groups (.create_group) or arrays (.create_array) as children +of an existing array. diff --git a/docs/user-guide/arrays.rst b/docs/user-guide/arrays.rst index 257fac450c..a498cb44a3 100644 --- a/docs/user-guide/arrays.rst +++ b/docs/user-guide/arrays.rst @@ -567,11 +567,12 @@ Any combination of integer and slice can be used for block indexing:: >>> >>> root = zarr.create_group('data/example-19.zarr') >>> foo = root.create_array(name='foo', shape=(1000, 100), chunks=(10, 10), dtype='float32') - >>> bar = root.create_array(name='foo/bar', shape=(100,), dtype='int32') + >>> bar = root.create_array(name='bar', shape=(100,), dtype='int32') >>> foo[:, :] = np.random.random((1000, 100)) >>> bar[:] = np.arange(100) >>> root.tree() / + ├── bar (100,) int32 └── foo (1000, 100) float32 diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index e5fa451914..793d1a034b 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -25,7 +25,6 @@ import zarr from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.abc.numcodec import Numcodec, _is_numcodec -from zarr.abc.store import Store, set_or_delete from zarr.codecs._v2 import V2Codec from zarr.codecs.bytes import BytesCodec from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec @@ -110,6 +109,7 @@ ArrayV3MetadataDict, T_ArrayMetadata, ) +from zarr.core.metadata.io import save_metadata from zarr.core.metadata.v2 import ( CompressorLikev2, get_object_codec_id, @@ -140,9 +140,9 @@ import numpy.typing as npt from zarr.abc.codec import CodecPipeline + from zarr.abc.store import Store from zarr.codecs.sharding import ShardingCodecIndexLocation from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar - from zarr.core.group import AsyncGroup from zarr.storage import StoreLike @@ -1639,24 +1639,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F """ Asynchronously save the array metadata. """ - to_save = metadata.to_buffer_dict(cpu_buffer_prototype) - awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] - - if ensure_parents: - # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups. - parents = _build_parents(self) - - for parent in parents: - awaitables.extend( - [ - (parent.store_path / key).set_if_not_exists(value) - for key, value in parent.metadata.to_buffer_dict( - cpu_buffer_prototype - ).items() - ] - ) - - await gather(*awaitables) + await save_metadata(self.store_path, metadata, ensure_parents=ensure_parents) async def _set_selection( self, @@ -4121,37 +4104,6 @@ async def _shards_initialized( ) -def _build_parents( - node: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup, -) -> list[AsyncGroup]: - from zarr.core.group import AsyncGroup, GroupMetadata - - store = node.store_path.store - path = node.store_path.path - if not path: - return [] - - required_parts = path.split("/")[:-1] - parents = [ - # the root group - AsyncGroup( - metadata=GroupMetadata(zarr_format=node.metadata.zarr_format), - store_path=StorePath(store=store, path=""), - ) - ] - - for i, part in enumerate(required_parts): - p = "/".join(required_parts[:i] + [part]) - parents.append( - AsyncGroup( - metadata=GroupMetadata(zarr_format=node.metadata.zarr_format), - store_path=StorePath(store=store, path=p), - ) - ) - - return parents - - FiltersLike: TypeAlias = ( Iterable[dict[str, JSON] | ArrayArrayCodec | Numcodec] | ArrayArrayCodec diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index f5bb14c48e..1c41a8a4a8 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -28,7 +28,6 @@ FiltersLike, SerializerLike, ShardsLike, - _build_parents, _parse_deprecated_compressor, create_array, ) @@ -49,6 +48,7 @@ ) from zarr.core.config import config from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata +from zarr.core.metadata.io import save_metadata from zarr.core.sync import SyncMixin, sync from zarr.errors import ( ContainsArrayError, @@ -818,22 +818,7 @@ async def get( return default async def _save_metadata(self, ensure_parents: bool = False) -> None: - to_save = self.metadata.to_buffer_dict(default_buffer_prototype()) - awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] - - if ensure_parents: - parents = _build_parents(self) - for parent in parents: - awaitables.extend( - [ - (parent.store_path / key).set_if_not_exists(value) - for key, value in parent.metadata.to_buffer_dict( - default_buffer_prototype() - ).items() - ] - ) - - await asyncio.gather(*awaitables) + await save_metadata(self.store_path, self.metadata, ensure_parents=ensure_parents) @property def path(self) -> str: diff --git a/src/zarr/core/metadata/io.py b/src/zarr/core/metadata/io.py new file mode 100644 index 0000000000..7b63f5493b --- /dev/null +++ b/src/zarr/core/metadata/io.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +from zarr.abc.store import set_or_delete +from zarr.core.buffer.core import default_buffer_prototype +from zarr.errors import ContainsArrayError +from zarr.storage._common import StorePath, ensure_no_existing_node + +if TYPE_CHECKING: + from zarr.core.common import ZarrFormat + from zarr.core.group import GroupMetadata + from zarr.core.metadata import ArrayMetadata + + +def _build_parents(store_path: StorePath, zarr_format: ZarrFormat) -> dict[str, GroupMetadata]: + from zarr.core.group import GroupMetadata + + path = store_path.path + if not path: + return {} + + required_parts = path.split("/")[:-1] + + # the root group + parents = {"": GroupMetadata(zarr_format=zarr_format)} + + for i, part in enumerate(required_parts): + parent_path = "/".join(required_parts[:i] + [part]) + parents[parent_path] = GroupMetadata(zarr_format=zarr_format) + + return parents + + +async def save_metadata( + store_path: StorePath, metadata: ArrayMetadata | GroupMetadata, ensure_parents: bool = False +) -> None: + """Asynchronously save the array or group metadata. + + Parameters + ---------- + store_path : StorePath + Location to save metadata. + metadata : ArrayMetadata | GroupMetadata + Metadata to save. + ensure_parents : bool, optional + Create any missing parent groups, and check no existing parents are arrays. + + Raises + ------ + ValueError + """ + to_save = metadata.to_buffer_dict(default_buffer_prototype()) + set_awaitables = [set_or_delete(store_path / key, value) for key, value in to_save.items()] + + if ensure_parents: + # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups. + parents = _build_parents(store_path, metadata.zarr_format) + ensure_array_awaitables = [] + + for parent_path, parent_metadata in parents.items(): + parent_store_path = StorePath(store_path.store, parent_path) + + # Error if an array already exists at any parent location. Only groups can have child nodes. + ensure_array_awaitables.append( + ensure_no_existing_node( + parent_store_path, parent_metadata.zarr_format, node_type="array" + ) + ) + set_awaitables.extend( + [ + (parent_store_path / key).set_if_not_exists(value) + for key, value in parent_metadata.to_buffer_dict( + default_buffer_prototype() + ).items() + ] + ) + + # Checks for parent arrays must happen first, before any metadata is modified + try: + await asyncio.gather(*ensure_array_awaitables) + except ContainsArrayError as e: + # clear awaitables to avoid RuntimeWarning: coroutine was never awaited + for awaitable in set_awaitables: + awaitable.close() + + raise ValueError( + f"A parent of {store_path} is an array - only groups may have child nodes." + ) from e + + await asyncio.gather(*set_awaitables) diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 6febb08281..4b1f5e4ae3 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -435,7 +435,11 @@ def _is_fsspec_uri(uri: str) -> bool: return "://" in uri or ("::" in uri and "local://" not in uri) -async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat) -> None: +async def ensure_no_existing_node( + store_path: StorePath, + zarr_format: ZarrFormat, + node_type: Literal["array", "group"] | None = None, +) -> None: """ Check if a store_path is safe for array / group creation. Returns `None` or raises an exception. @@ -446,6 +450,8 @@ async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat The storage location to check. zarr_format : ZarrFormat The Zarr format to check. + node_type : str | None, optional + Raise an error if an "array", or "group" exists. By default (when None), raises an error for either. Raises ------ @@ -456,16 +462,23 @@ async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat elif zarr_format == 3: extant_node = await _contains_node_v3(store_path) - if extant_node == "array": - msg = f"An array exists in store {store_path.store!r} at path {store_path.path!r}." - raise ContainsArrayError(msg) - elif extant_node == "group": - msg = f"An array exists in store {store_path.store!r} at path {store_path.path!r}." - raise ContainsGroupError(msg) - elif extant_node == "nothing": - return - msg = f"Invalid value for extant_node: {extant_node}" # type: ignore[unreachable] - raise ValueError(msg) + match extant_node: + case "array": + if node_type != "group": + msg = f"An array exists in store {store_path.store!r} at path {store_path.path!r}." + raise ContainsArrayError(msg) + + case "group": + if node_type != "array": + msg = f"A group exists in store {store_path.store!r} at path {store_path.path!r}." + raise ContainsGroupError(msg) + + case "nothing": + return + + case _: + msg = f"Invalid value for extant_node: {extant_node}" # type: ignore[unreachable] + raise ValueError(msg) async def _contains_node_v3(store_path: StorePath) -> Literal["array", "group", "nothing"]: diff --git a/tests/test_group.py b/tests/test_group.py index e7ce2bad16..6f1f4e68fa 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -761,6 +761,24 @@ def test_group_create_array( assert np.array_equal(array[:], data) +@pytest.mark.parametrize("method", ["create_array", "create_group"]) +def test_create_with_parent_array(store: Store, zarr_format: ZarrFormat, method: str): + """Test that groups/arrays cannot be created under a parent array.""" + + # create a group with a child array + group = Group.from_store(store, zarr_format=zarr_format) + group.create_array(name="arr_1", shape=(10, 10), dtype="uint8") + + error_msg = r"A parent of .* is an array - only groups may have child nodes." + if method == "create_array": + with pytest.raises(ValueError, match=error_msg): + group.create_array("arr_1/group_1/group_2/arr_2", shape=(10, 10), dtype="uint8") + + else: + with pytest.raises(ValueError, match=error_msg): + group.create_group("arr_1/group_1/group_2/group_3") + + LikeMethodName = Literal["zeros_like", "ones_like", "empty_like", "full_like"] diff --git a/tests/test_indexing.py b/tests/test_indexing.py index c6a792b0e5..609db6cdce 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -2110,14 +2110,14 @@ async def test_async_oindex(self, store, indexer, expected): @pytest.mark.asyncio async def test_async_oindex_with_zarr_array(self, store): - z1 = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8") + group = zarr.create_group(store=store, zarr_format=3) + + z1 = group.create_array(name="z1", shape=(2, 2), chunks=(1, 1), dtype="i8") z1[...] = np.array([[1, 2], [3, 4]]) async_zarr = z1._async_array # create boolean zarr array to index with - z2 = zarr.create_array( - store=store, name="z2", shape=(2,), chunks=(1,), zarr_format=3, dtype="?" - ) + z2 = group.create_array(name="z2", shape=(2,), chunks=(1,), dtype="?") z2[...] = np.array([True, False]) result = await async_zarr.oindex.getitem(z2) @@ -2143,14 +2143,14 @@ async def test_async_vindex(self, store, indexer, expected): @pytest.mark.asyncio async def test_async_vindex_with_zarr_array(self, store): - z1 = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8") + group = zarr.create_group(store=store, zarr_format=3) + + z1 = group.create_array(name="z1", shape=(2, 2), chunks=(1, 1), dtype="i8") z1[...] = np.array([[1, 2], [3, 4]]) async_zarr = z1._async_array # create boolean zarr array to index with - z2 = zarr.create_array( - store=store, name="z2", shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="?" - ) + z2 = group.create_array(name="z2", shape=(2, 2), chunks=(1, 1), dtype="?") z2[...] = np.array([[False, True], [False, True]]) result = await async_zarr.vindex.getitem(z2)