Skip to content

Commit bfba0e5

Browse files
committed
use common save_metadata function for groups and arrays
1 parent 382362d commit bfba0e5

File tree

4 files changed

+115
-95
lines changed

4 files changed

+115
-95
lines changed

src/zarr/core/array.py

Lines changed: 3 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import zarr
2626
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
2727
from zarr.abc.numcodec import Numcodec, _is_numcodec
28-
from zarr.abc.store import Store, set_or_delete
2928
from zarr.codecs._v2 import V2Codec
3029
from zarr.codecs.bytes import BytesCodec
3130
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
@@ -109,6 +108,7 @@
109108
ArrayV3MetadataDict,
110109
T_ArrayMetadata,
111110
)
111+
from zarr.core.metadata.io import save_metadata
112112
from zarr.core.metadata.v2 import (
113113
CompressorLikev2,
114114
get_object_codec_id,
@@ -119,7 +119,6 @@
119119
from zarr.core.sync import sync
120120
from zarr.errors import (
121121
ArrayNotFoundError,
122-
ContainsArrayError,
123122
MetadataValidationError,
124123
ZarrDeprecationWarning,
125124
ZarrUserWarning,
@@ -140,9 +139,9 @@
140139
import numpy.typing as npt
141140

142141
from zarr.abc.codec import CodecPipeline
142+
from zarr.abc.store import Store
143143
from zarr.codecs.sharding import ShardingCodecIndexLocation
144144
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar
145-
from zarr.core.group import AsyncGroup
146145
from zarr.storage import StoreLike
147146

148147

@@ -1496,42 +1495,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
14961495
"""
14971496
Asynchronously save the array metadata.
14981497
"""
1499-
to_save = metadata.to_buffer_dict(cpu_buffer_prototype)
1500-
set_awaitables = [
1501-
set_or_delete(self.store_path / key, value) for key, value in to_save.items()
1502-
]
1503-
1504-
if ensure_parents:
1505-
# To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups.
1506-
parents = _build_parents(self)
1507-
ensure_array_awaitables = []
1508-
1509-
for parent in parents:
1510-
# Error if an array already exists at any parent location. Only groups can have child nodes.
1511-
ensure_array_awaitables.append(
1512-
ensure_no_existing_node(
1513-
parent.store_path, metadata.zarr_format, node_type="array"
1514-
)
1515-
)
1516-
set_awaitables.extend(
1517-
[
1518-
(parent.store_path / key).set_if_not_exists(value)
1519-
for key, value in parent.metadata.to_buffer_dict(
1520-
cpu_buffer_prototype
1521-
).items()
1522-
]
1523-
)
1524-
1525-
# Checks for parent arrays must happen first, before any metadata is modified
1526-
try:
1527-
await gather(*ensure_array_awaitables)
1528-
except ContainsArrayError as e:
1529-
set_awaitables = [] # clear awaitables to avoid printed RuntimeWarning: coroutine was never awaited
1530-
raise ValueError(
1531-
f"A parent of {self.store_path} is an array - only groups may have child nodes."
1532-
) from e
1533-
1534-
await gather(*set_awaitables)
1498+
await save_metadata(self.store_path, metadata, ensure_parents=ensure_parents)
15351499

15361500
async def _set_selection(
15371501
self,
@@ -3899,37 +3863,6 @@ async def chunks_initialized(
38993863
)
39003864

39013865

3902-
def _build_parents(
3903-
node: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup,
3904-
) -> list[AsyncGroup]:
3905-
from zarr.core.group import AsyncGroup, GroupMetadata
3906-
3907-
store = node.store_path.store
3908-
path = node.store_path.path
3909-
if not path:
3910-
return []
3911-
3912-
required_parts = path.split("/")[:-1]
3913-
parents = [
3914-
# the root group
3915-
AsyncGroup(
3916-
metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
3917-
store_path=StorePath(store=store, path=""),
3918-
)
3919-
]
3920-
3921-
for i, part in enumerate(required_parts):
3922-
p = "/".join(required_parts[:i] + [part])
3923-
parents.append(
3924-
AsyncGroup(
3925-
metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
3926-
store_path=StorePath(store=store, path=p),
3927-
)
3928-
)
3929-
3930-
return parents
3931-
3932-
39333866
FiltersLike: TypeAlias = (
39343867
Iterable[dict[str, JSON] | ArrayArrayCodec | Numcodec]
39353868
| ArrayArrayCodec

src/zarr/core/group.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
FiltersLike,
2929
SerializerLike,
3030
ShardsLike,
31-
_build_parents,
3231
_parse_deprecated_compressor,
3332
create_array,
3433
)
@@ -49,6 +48,7 @@
4948
)
5049
from zarr.core.config import config
5150
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
51+
from zarr.core.metadata.io import save_metadata
5252
from zarr.core.sync import SyncMixin, sync
5353
from zarr.errors import (
5454
ContainsArrayError,
@@ -808,22 +808,7 @@ async def get(
808808
return default
809809

810810
async def _save_metadata(self, ensure_parents: bool = False) -> None:
811-
to_save = self.metadata.to_buffer_dict(default_buffer_prototype())
812-
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
813-
814-
if ensure_parents:
815-
parents = _build_parents(self)
816-
for parent in parents:
817-
awaitables.extend(
818-
[
819-
(parent.store_path / key).set_if_not_exists(value)
820-
for key, value in parent.metadata.to_buffer_dict(
821-
default_buffer_prototype()
822-
).items()
823-
]
824-
)
825-
826-
await asyncio.gather(*awaitables)
811+
await save_metadata(self.store_path, self.metadata, ensure_parents=ensure_parents)
827812

828813
@property
829814
def path(self) -> str:

src/zarr/core/metadata/io.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import TYPE_CHECKING
5+
6+
from zarr.abc.store import set_or_delete
7+
from zarr.core.buffer.core import default_buffer_prototype
8+
from zarr.errors import ContainsArrayError
9+
from zarr.storage._common import StorePath, ensure_no_existing_node
10+
11+
if TYPE_CHECKING:
12+
from zarr.core.common import ZarrFormat
13+
from zarr.core.group import AsyncGroup, GroupMetadata
14+
from zarr.core.metadata import ArrayMetadata
15+
16+
17+
def _build_parents(store_path: StorePath, zarr_format: ZarrFormat) -> list[AsyncGroup]:
    """Construct in-memory group nodes for every ancestor of ``store_path``.

    Nothing is read from or written to the store here; the returned groups
    carry default ``GroupMetadata`` for the requested ``zarr_format``.

    Parameters
    ----------
    store_path : StorePath
        Location of the node whose ancestors are wanted.
    zarr_format : ZarrFormat
        Zarr format used for the metadata of every ancestor group.

    Returns
    -------
    list[AsyncGroup]
        Empty when ``store_path.path`` is empty (the node is the root).
        Otherwise the root group first, followed by each intermediate
        group from shallowest to deepest.
    """
    # Imported here rather than at module level, as in the original
    # (presumably to avoid a circular import with zarr.core.group).
    from zarr.core.group import AsyncGroup, GroupMetadata

    node_path = store_path.path
    if not node_path:
        return []

    store = store_path.store
    # Drop the last component: that is the node itself, not an ancestor.
    ancestor_names = node_path.split("/")[:-1]

    # "" addresses the root group; every further entry is a cumulative prefix.
    ancestor_paths = [""]
    ancestor_paths.extend(
        "/".join(ancestor_names[: depth + 1]) for depth in range(len(ancestor_names))
    )

    return [
        AsyncGroup(
            metadata=GroupMetadata(zarr_format=zarr_format),
            store_path=StorePath(store=store, path=ancestor_path),
        )
        for ancestor_path in ancestor_paths
    ]
44+
45+
46+
async def save_metadata(
    store_path: StorePath, metadata: ArrayMetadata | GroupMetadata, ensure_parents: bool = False
) -> None:
    """Asynchronously save the array or group metadata.

    Parameters
    ----------
    store_path : StorePath
        Location to save metadata.
    metadata : ArrayMetadata | GroupMetadata
        Metadata to save.
    ensure_parents : bool, optional
        Whether to create any missing parent groups.

    Raises
    ------
    ValueError
        If ``ensure_parents`` is true and an array already exists at one of
        the parent locations. Only groups may have child nodes. No metadata
        is written in that case.
    """
    prototype = default_buffer_prototype()
    to_save = metadata.to_buffer_dict(prototype)

    # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups.
    parents = _build_parents(store_path, metadata.zarr_format) if ensure_parents else []

    if parents:
        # Checks for parent arrays must happen first, before any metadata is modified.
        # The write coroutines are deliberately not created yet: if a check fails
        # there is then nothing left un-awaited (no "coroutine was never awaited"
        # RuntimeWarning), whatever exception type the check raised.
        try:
            await asyncio.gather(
                *(
                    # Error if an array already exists at any parent location.
                    ensure_no_existing_node(
                        parent.store_path, metadata.zarr_format, node_type="array"
                    )
                    for parent in parents
                )
            )
        except ContainsArrayError as e:
            raise ValueError(
                f"A parent of {store_path} is an array - only groups may have child nodes."
            ) from e

    # The node's own metadata documents are written via set_or_delete.
    set_awaitables = [set_or_delete(store_path / key, value) for key, value in to_save.items()]

    # Parent group documents go through set_if_not_exists (presumably leaving
    # any metadata already stored at a parent location untouched).
    for parent in parents:
        set_awaitables.extend(
            (parent.store_path / key).set_if_not_exists(value)
            for key, value in parent.metadata.to_buffer_dict(prototype).items()
        )

    await asyncio.gather(*set_awaitables)

src/zarr/storage/_common.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -427,14 +427,21 @@ async def ensure_no_existing_node(
427427
elif zarr_format == 3:
428428
extant_node = await _contains_node_v3(store_path)
429429

430-
if extant_node == "array" and node_type != "group":
431-
raise ContainsArrayError(store_path.store, store_path.path)
432-
elif extant_node == "group" and node_type != "array":
433-
raise ContainsGroupError(store_path.store, store_path.path)
434-
elif extant_node == "nothing":
435-
return
436-
msg = f"Invalid value for extant_node: {extant_node}"
437-
raise ValueError(msg)
430+
match extant_node:
431+
case "array":
432+
if node_type != "group":
433+
raise ContainsArrayError(store_path.store, store_path.path)
434+
435+
case "group":
436+
if node_type != "array":
437+
raise ContainsGroupError(store_path.store, store_path.path)
438+
439+
case "nothing":
440+
return
441+
442+
case _:
443+
msg = f"Invalid value for extant_node: {extant_node}" # type: ignore[unreachable]
444+
raise ValueError(msg)
438445

439446

440447
async def _contains_node_v3(store_path: StorePath) -> Literal["array", "group", "nothing"]:

0 commit comments

Comments
 (0)