Skip to content

Commit 2edc548

Browse files
Ensure parents are created when creating a node (#2262)
* Ensure parents are created when creating a node. This updates our Array and Group creation methods to ensure that parents implicitly defined through a nested path are also created. To accomplish this semi-safely and efficiently, we require a new setdefault-like method (`set_if_not_exists`) on the Store class. * use the API * fixed logging store * Update src/zarr/testing/store.py * fixes * fixup * fixes * pre-commit --------- Co-authored-by: Joe Hamman <[email protected]>
1 parent 8e2c660 commit 2edc548

File tree

14 files changed

+221
-12
lines changed

14 files changed

+221
-12
lines changed

src/zarr/abc/store.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,22 @@ async def set(self, key: str, value: Buffer) -> None:
172172
"""
173173
...
174174

175+
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
    """
    Store ``value`` under ``key`` unless the key is already present.

    Parameters
    ----------
    key : str
    value : Buffer

    Notes
    -----
    This default implementation checks ``exists`` and then calls ``set`` as
    two separate awaited steps, so it is not atomic.
    """
    # Note for implementers: this default is not safe for concurrent writers.
    # Between the `exists` check and the `set` call another writer could set
    # some value at `key` or delete `key`.
    if await self.exists(key):
        return
    await self.set(key, value)
190+
175191
async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
176192
"""
177193
Insert multiple (key, value) pairs into storage.
@@ -297,6 +313,8 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
297313

298314
async def delete(self) -> None: ...
299315

316+
async def set_if_not_exists(self, default: Buffer) -> None:
    """Interface stub: write ``default`` only when no value is stored yet.

    NOTE(review): the semantics are taken from the implementations visible in
    this change; the stub itself has no body. Confirm against the enclosing
    class definition.
    """
    ...
317+
300318

301319
async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None:
302320
if value is None:

src/zarr/codecs/sharding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
9797
async def delete(self) -> None:
9898
del self.shard_dict[self.chunk_coords]
9999

100+
async def set_if_not_exists(self, default: Buffer) -> None:
    """Insert ``default`` for this chunk unless the shard already has an entry.

    Delegates to ``shard_dict.setdefault`` keyed by ``chunk_coords``, so an
    existing entry is left untouched.
    """
    coords = self.chunk_coords
    self.shard_dict.setdefault(coords, default)
102+
100103

101104
class _ShardIndex(NamedTuple):
102105
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)

src/zarr/core/array.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from zarr.codecs import BytesCodec
1414
from zarr.codecs._v2 import V2Compressor, V2Filters
1515
from zarr.core.attributes import Attributes
16-
from zarr.core.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
16+
from zarr.core.buffer import (
17+
BufferPrototype,
18+
NDArrayLike,
19+
NDBuffer,
20+
default_buffer_prototype,
21+
)
1722
from zarr.core.chunk_grids import RegularChunkGrid, _guess_chunks
1823
from zarr.core.chunk_key_encodings import (
1924
ChunkKeyEncoding,
@@ -71,6 +76,7 @@
7176
from collections.abc import Iterable, Iterator, Sequence
7277

7378
from zarr.abc.codec import Codec, CodecPipeline
79+
from zarr.core.group import AsyncGroup
7480
from zarr.core.metadata.common import ArrayMetadata
7581

7682
# Array and AsyncArray are defined in the base ``zarr`` namespace
@@ -337,7 +343,7 @@ async def _create_v3(
337343
)
338344

339345
array = cls(metadata=metadata, store_path=store_path)
340-
await array._save_metadata(metadata)
346+
await array._save_metadata(metadata, ensure_parents=True)
341347
return array
342348

343349
@classmethod
@@ -376,7 +382,7 @@ async def _create_v2(
376382
attributes=attributes,
377383
)
378384
array = cls(metadata=metadata, store_path=store_path)
379-
await array._save_metadata(metadata)
385+
await array._save_metadata(metadata, ensure_parents=True)
380386
return array
381387

382388
@classmethod
@@ -621,9 +627,24 @@ async def getitem(
621627
)
622628
return await self._get_selection(indexer, prototype=prototype)
623629

624-
async def _save_metadata(self, metadata: ArrayMetadata) -> None:
630+
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
    """
    Write this array's metadata documents to the store.

    Parameters
    ----------
    metadata : ArrayMetadata
        Metadata to serialise with ``to_buffer_dict`` and write under
        ``self.store_path``.
    ensure_parents : bool, default False
        When true, also write group metadata for every intermediate path
        returned by ``_build_parents``, using ``set_if_not_exists`` so that
        metadata already stored at those paths is not overwritten.
    """
    documents = metadata.to_buffer_dict(default_buffer_prototype())
    pending = [set_or_delete(self.store_path / name, buf) for name, buf in documents.items()]

    if ensure_parents:
        # To enable zarr.create(store, path="a/b/c"), we need to create all the
        # intermediate groups.
        for group in _build_parents(self):
            group_documents = group.metadata.to_buffer_dict(default_buffer_prototype())
            for name, buf in group_documents.items():
                pending.append((group.store_path / name).set_if_not_exists(buf))

    # All writes (own metadata plus any parent groups) are awaited together.
    await gather(*pending)
628649

629650
async def _set_selection(
@@ -2354,3 +2375,21 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
23542375
out.append(chunk_key)
23552376

23562377
return tuple(out)
2378+
2379+
2380+
def _build_parents(node: AsyncArray | AsyncGroup) -> list[AsyncGroup]:
    """
    Build in-memory groups for every intermediate path component above ``node``.

    For a node stored at ``"a/b/c"`` this yields groups at ``"a"`` and ``"a/b"``
    (the final component is the node itself and is excluded). Each group uses
    default ``GroupMetadata`` with the node's ``zarr_format`` and the node's
    store. Nothing is written to the store here.

    Parameters
    ----------
    node : AsyncArray | AsyncGroup
        The node whose ``store_path.path`` is split on ``"/"``.

    Returns
    -------
    list[AsyncGroup]
        Parent groups ordered from the shallowest path to the deepest.
    """
    # Imported locally, as in the original; presumably to avoid a circular
    # import between the array and group modules.
    from zarr.core.group import AsyncGroup, GroupMetadata

    ancestors = node.store_path.path.split("/")[:-1]

    return [
        AsyncGroup(
            metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
            store_path=StorePath(
                store=node.store_path.store,
                path="/".join(ancestors[:depth]),
            ),
        )
        for depth in range(1, len(ancestors) + 1)
    ]

src/zarr/core/group.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import zarr.api.asynchronous as async_api
1414
from zarr.abc.metadata import Metadata
1515
from zarr.abc.store import Store, set_or_delete
16-
from zarr.core.array import Array, AsyncArray
16+
from zarr.core.array import Array, AsyncArray, _build_parents
1717
from zarr.core.attributes import Attributes
1818
from zarr.core.buffer import default_buffer_prototype
1919
from zarr.core.common import (
@@ -144,7 +144,7 @@ async def from_store(
144144
metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format),
145145
store_path=store_path,
146146
)
147-
await group._save_metadata()
147+
await group._save_metadata(ensure_parents=True)
148148
return group
149149

150150
@classmethod
@@ -279,9 +279,22 @@ async def delitem(self, key: str) -> None:
279279
else:
280280
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
281281

282-
async def _save_metadata(self) -> None:
282+
async def _save_metadata(self, ensure_parents: bool = False) -> None:
    """
    Write this group's metadata documents to the store.

    Parameters
    ----------
    ensure_parents : bool, default False
        When true, also write group metadata for every intermediate path
        returned by ``_build_parents``, using ``set_if_not_exists`` so that
        metadata already stored at those paths is not overwritten.
    """
    documents = self.metadata.to_buffer_dict(default_buffer_prototype())
    pending = [set_or_delete(self.store_path / name, buf) for name, buf in documents.items()]

    if ensure_parents:
        for ancestor in _build_parents(self):
            ancestor_documents = ancestor.metadata.to_buffer_dict(default_buffer_prototype())
            for name, buf in ancestor_documents.items():
                pending.append((ancestor.store_path / name).set_if_not_exists(buf))

    # All writes (own metadata plus any parent groups) are awaited together.
    await asyncio.gather(*pending)
286299

287300
@property

src/zarr/store/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
5151
async def delete(self) -> None:
5252
await self.store.delete(self.path)
5353

54+
async def set_if_not_exists(self, default: Buffer) -> None:
    """Ask the underlying store to write ``default`` at this path if the key is absent.

    Delegates to ``self.store.set_if_not_exists`` with ``self.path`` as the key.
    """
    target_store, key = self.store, self.path
    await target_store.set_if_not_exists(key, default)
56+
5457
async def exists(self) -> bool:
5558
return await self.store.exists(self.path)
5659

src/zarr/store/local.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def _put(
6060
path: Path,
6161
value: Buffer,
6262
start: int | None = None,
63+
exclusive: bool = False,
6364
) -> int | None:
6465
path.parent.mkdir(parents=True, exist_ok=True)
6566
if start is not None:
@@ -68,7 +69,13 @@ def _put(
6869
f.write(value.as_numpy_array().tobytes())
6970
return None
7071
else:
71-
return path.write_bytes(value.as_numpy_array().tobytes())
72+
view = memoryview(value.as_numpy_array().tobytes())
73+
if exclusive:
74+
mode = "xb"
75+
else:
76+
mode = "wb"
77+
with path.open(mode=mode) as f:
78+
return f.write(view)
7279

7380

7481
class LocalStore(Store):
@@ -152,14 +159,23 @@ async def get_partial_values(
152159
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit
153160

154161
async def set(self, key: str, value: Buffer) -> None:
    """Write ``value`` under ``key`` via ``_set`` in its default, non-exclusive mode."""
    await self._set(key, value)
163+
164+
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
    """
    Write ``value`` under ``key`` only if no file exists for that key yet.

    The write is requested with ``exclusive=True``; a ``FileExistsError`` from
    that attempt means the key is already present and is silently ignored.
    """
    try:
        await self._set(key, value, exclusive=True)
    except FileExistsError:
        # The key already exists: keep the stored value untouched.
        return None
169+
170+
async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
    """
    Shared implementation behind ``set`` and ``set_if_not_exists``.

    Parameters
    ----------
    key : str
        Key, resolved relative to ``self.root``.
    value : Buffer
        Data to write.
    exclusive : bool, default False
        Forwarded to ``_put``; when true the write is expected to fail with
        ``FileExistsError`` if the target file already exists.

    Raises
    ------
    TypeError
        If ``value`` is not a ``Buffer``.
    """
    # Open the store lazily on first use.
    if not self._is_open:
        await self._open()
    self._check_writable()
    assert isinstance(key, str)
    if not isinstance(value, Buffer):
        raise TypeError("LocalStore.set(): `value` must be a Buffer instance")
    path = self.root / key
    # The blocking file write runs in a worker thread; start=None selects a
    # whole-file write rather than a partial write at an offset.
    await to_thread(_put, path, value, start=None, exclusive=exclusive)
163179

164180
async def set_partial_values(
165181
self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]]

src/zarr/store/logging.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import TYPE_CHECKING
99

1010
from zarr.abc.store import AccessMode, ByteRangeRequest, Store
11+
from zarr.core.buffer import Buffer
1112

1213
if TYPE_CHECKING:
1314
from collections.abc import AsyncGenerator, Generator, Iterable
@@ -138,6 +139,10 @@ async def set(self, key: str, value: Buffer) -> None:
138139
with self.log():
139140
return await self._store.set(key=key, value=value)
140141

142+
async def set_if_not_exists(self, key: str, default: Buffer) -> None:
    """
    Forward ``set_if_not_exists`` to the wrapped store inside a ``self.log()`` context.

    Parameters
    ----------
    key : str
    default : Buffer
        Value to store when ``key`` is absent.
    """
    with self.log():
        # Pass arguments positionally: store implementations do not agree on
        # the name of the second parameter (``value`` vs ``default``), so a
        # keyword call would raise TypeError for some wrapped stores.
        return await self._store.set_if_not_exists(key, default)
145+
141146
async def delete(self, key: str) -> None:
142147
with self.log():
143148
return await self._store.delete(key=key)

src/zarr/store/memory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,8 @@ async def exists(self, key: str) -> bool:
8585
return key in self._store_dict
8686

8787
async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
88-
if not self._is_open:
89-
await self._open()
9088
self._check_writable()
89+
await self._ensure_open()
9190
assert isinstance(key, str)
9291
if not isinstance(value, Buffer):
9392
raise TypeError(f"Expected Buffer. Got {type(value)}.")
@@ -99,6 +98,11 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
9998
else:
10099
self._store_dict[key] = value
101100

101+
async def set_if_not_exists(self, key: str, default: Buffer) -> None:
    """
    Store ``default`` under ``key`` unless the key is already in the backing dict.

    The writability check and ``_ensure_open`` run first; the insert itself is
    a single ``setdefault`` call on ``_store_dict``.
    """
    self._check_writable()
    await self._ensure_open()
    backing = self._store_dict
    backing.setdefault(key, default)
105+
102106
async def delete(self, key: str) -> None:
103107
self._check_writable()
104108
try:

src/zarr/store/remote.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import fsspec
66

77
from zarr.abc.store import ByteRangeRequest, Store
8+
from zarr.core.buffer import Buffer
89
from zarr.store.common import _dereference_path
910

1011
if TYPE_CHECKING:

src/zarr/store/zip.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,13 @@ async def set(self, key: str, value: Buffer) -> None:
188188
async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None:
189189
raise NotImplementedError
190190

191+
async def set_if_not_exists(self, key: str, default: Buffer) -> None:
    """
    Write ``default`` under ``key`` unless the archive already lists that name.

    The membership check and the write both happen while holding ``self._lock``.
    """
    self._check_writable()
    with self._lock:
        if key in self._zf.namelist():
            # Already present in the archive: nothing to do.
            return
        self._set(key, default)
197+
191198
async def delete(self, key: str) -> None:
192199
raise NotImplementedError
193200

0 commit comments

Comments
 (0)