Skip to content

Commit 382362d

Browse files
committed
restrict arrays as parents of other arrays
1 parent c019a5f commit 382362d

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

src/zarr/core/array.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
from zarr.core.sync import sync
120120
from zarr.errors import (
121121
ArrayNotFoundError,
122+
ContainsArrayError,
122123
MetadataValidationError,
123124
ZarrDeprecationWarning,
124125
ZarrUserWarning,
@@ -1496,14 +1497,23 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
14961497
Asynchronously save the array metadata.
14971498
"""
14981499
to_save = metadata.to_buffer_dict(cpu_buffer_prototype)
1499-
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
1500+
set_awaitables = [
1501+
set_or_delete(self.store_path / key, value) for key, value in to_save.items()
1502+
]
15001503

15011504
if ensure_parents:
15021505
# To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups.
15031506
parents = _build_parents(self)
1507+
ensure_array_awaitables = []
15041508

15051509
for parent in parents:
1506-
awaitables.extend(
1510+
# Error if an array already exists at any parent location. Only groups can have child nodes.
1511+
ensure_array_awaitables.append(
1512+
ensure_no_existing_node(
1513+
parent.store_path, metadata.zarr_format, node_type="array"
1514+
)
1515+
)
1516+
set_awaitables.extend(
15071517
[
15081518
(parent.store_path / key).set_if_not_exists(value)
15091519
for key, value in parent.metadata.to_buffer_dict(
@@ -1512,7 +1522,16 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
15121522
]
15131523
)
15141524

1515-
await gather(*awaitables)
1525+
# Checks for parent arrays must happen first, before any metadata is modified
1526+
try:
1527+
await gather(*ensure_array_awaitables)
1528+
except ContainsArrayError as e:
1529+
set_awaitables = [] # clear awaitables to avoid printed RuntimeWarning: coroutine was never awaited
1530+
raise ValueError(
1531+
f"A parent of {self.store_path} is an array - only groups may have child nodes."
1532+
) from e
1533+
1534+
await gather(*set_awaitables)
15161535

15171536
async def _set_selection(
15181537
self,

src/zarr/storage/_common.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,11 @@ def _is_fsspec_uri(uri: str) -> bool:
400400
return "://" in uri or ("::" in uri and "local://" not in uri)
401401

402402

403-
async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat) -> None:
403+
async def ensure_no_existing_node(
404+
store_path: StorePath,
405+
zarr_format: ZarrFormat,
406+
node_type: Literal["array", "group"] | None = None,
407+
) -> None:
404408
"""
405409
Check if a store_path is safe for array / group creation.
406410
Returns `None` or raises an exception.
@@ -411,6 +415,8 @@ async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat
411415
The storage location to check.
412416
zarr_format : ZarrFormat
413417
The Zarr format to check.
418+
node_type : Literal["array", "group"] | None, optional
419+
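The kind of node to check for: raise an error only if a node of this type ("array" or "group") exists. By default (when None), raises an error for either. NOTE(review): as the checks below are written, an existing node of the other type matches no branch and reaches the final "Invalid value for extant_node" ValueError instead of being ignored - confirm this is intended.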
414420
415421
Raises
416422
------
@@ -421,13 +427,13 @@ async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat
421427
elif zarr_format == 3:
422428
extant_node = await _contains_node_v3(store_path)
423429

424-
if extant_node == "array":
430+
if extant_node == "array" and node_type != "group":
425431
raise ContainsArrayError(store_path.store, store_path.path)
426-
elif extant_node == "group":
432+
elif extant_node == "group" and node_type != "array":
427433
raise ContainsGroupError(store_path.store, store_path.path)
428434
elif extant_node == "nothing":
429435
return
430-
msg = f"Invalid value for extant_node: {extant_node}" # type: ignore[unreachable]
436+
msg = f"Invalid value for extant_node: {extant_node}"
431437
raise ValueError(msg)
432438

433439

0 commit comments

Comments
 (0)