Skip to content

Commit bb29be4

Browse files
authored
Merge branch 'v3' into fix/modes
2 parents a4dc888 + 2edc548 commit bb29be4

File tree

13 files changed

+222
-11
lines changed

13 files changed

+222
-11
lines changed

src/zarr/abc/store.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,22 @@ async def set(self, key: str, value: Buffer) -> None:
197197
"""
198198
...
199199

200+
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
    """
    Write ``value`` under ``key`` unless ``key`` is already present.

    Parameters
    ----------
    key : str
        Key to check and, when absent, to write.
    value : Buffer
        Data written under ``key`` when the key is absent.
    """
    # Note for implementers: this default implementation is a plain
    # check-then-write and is not safe for concurrent writers. Between the
    # `exists` check and the `set`, another writer could set some value at
    # `key` or delete `key`.
    if await self.exists(key):
        return
    await self.set(key, value)
215+
200216
async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
201217
"""
202218
Insert multiple (key, value) pairs into storage.
@@ -322,6 +338,8 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
322338

323339
async def delete(self) -> None: ...
324340

341+
# Stub only: the body is `...`, so the behaviour comes from implementers.
# NOTE(review): judging by the name and the `default` parameter, implementers
# are expected to store `default` only when nothing is stored yet - confirm.
async def set_if_not_exists(self, default: Buffer) -> None: ...
342+
325343

326344
async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None:
327345
if value is None:

src/zarr/codecs/sharding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
9797
async def delete(self) -> None:
9898
del self.shard_dict[self.chunk_coords]
9999

100+
async def set_if_not_exists(self, default: Buffer) -> None:
    """Insert ``default`` at this setter's chunk coordinates unless the shard dict already has an entry there."""
    shard, coords = self.shard_dict, self.chunk_coords
    shard.setdefault(coords, default)
102+
100103

101104
class _ShardIndex(NamedTuple):
102105
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)

src/zarr/core/array.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from zarr.codecs import BytesCodec
1414
from zarr.codecs._v2 import V2Compressor, V2Filters
1515
from zarr.core.attributes import Attributes
16-
from zarr.core.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
16+
from zarr.core.buffer import (
17+
BufferPrototype,
18+
NDArrayLike,
19+
NDBuffer,
20+
default_buffer_prototype,
21+
)
1722
from zarr.core.chunk_grids import RegularChunkGrid, _guess_chunks
1823
from zarr.core.chunk_key_encodings import (
1924
ChunkKeyEncoding,
@@ -71,6 +76,7 @@
7176
from collections.abc import Iterable, Iterator, Sequence
7277

7378
from zarr.abc.codec import Codec, CodecPipeline
79+
from zarr.core.group import AsyncGroup
7480
from zarr.core.metadata.common import ArrayMetadata
7581

7682
# Array and AsyncArray are defined in the base ``zarr`` namespace
@@ -337,7 +343,7 @@ async def _create_v3(
337343
)
338344

339345
array = cls(metadata=metadata, store_path=store_path)
340-
await array._save_metadata(metadata)
346+
await array._save_metadata(metadata, ensure_parents=True)
341347
return array
342348

343349
@classmethod
@@ -376,7 +382,7 @@ async def _create_v2(
376382
attributes=attributes,
377383
)
378384
array = cls(metadata=metadata, store_path=store_path)
379-
await array._save_metadata(metadata)
385+
await array._save_metadata(metadata, ensure_parents=True)
380386
return array
381387

382388
@classmethod
@@ -621,9 +627,24 @@ async def getitem(
621627
)
622628
return await self._get_selection(indexer, prototype=prototype)
623629

624-
async def _save_metadata(self, metadata: ArrayMetadata) -> None:
630+
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
    """
    Write the metadata documents of this array to its store path.

    Parameters
    ----------
    metadata : ArrayMetadata
        Metadata whose ``to_buffer_dict`` output is written (or deleted)
        key by key via ``set_or_delete``.
    ensure_parents : bool, default False
        When True, also write group metadata for every group returned by
        ``_build_parents(self)``, using ``set_if_not_exists`` so existing
        documents are left untouched.
    """
    own_docs = metadata.to_buffer_dict(default_buffer_prototype())
    pending = [set_or_delete(self.store_path / name, buf) for name, buf in own_docs.items()]

    if ensure_parents:
        # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups.
        for ancestor in _build_parents(self):
            ancestor_docs = ancestor.metadata.to_buffer_dict(default_buffer_prototype())
            pending += [
                (ancestor.store_path / name).set_if_not_exists(buf)
                for name, buf in ancestor_docs.items()
            ]

    # Own documents and ancestor documents are awaited together.
    await gather(*pending)
628649

629650
async def _set_selection(
@@ -2354,3 +2375,21 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
23542375
out.append(chunk_key)
23552376

23562377
return tuple(out)
2378+
2379+
2380+
def _build_parents(node: AsyncArray | AsyncGroup) -> list[AsyncGroup]:
    """
    Build (without writing anything) the ancestor groups of ``node``.

    Each ancestor is an ``AsyncGroup`` carrying default ``GroupMetadata`` in
    the same ``zarr_format`` as ``node``. Callers decide how to persist them.

    Parameters
    ----------
    node : AsyncArray | AsyncGroup
        Node whose ``store_path`` determines the ancestors.

    Returns
    -------
    list[AsyncGroup]
        The root group (path ``""``) first, then one group per intermediate
        path prefix, shallowest first. Empty when ``node`` itself sits at the
        root path, because it has no ancestors.
    """
    # Imported lazily: zarr.core.group imports this module at import time.
    from zarr.core.group import AsyncGroup, GroupMetadata

    store = node.store_path.store
    path = node.store_path.path
    if not path:
        # The node *is* the root. Emitting a root group here would target the
        # same metadata location as the node itself.
        return []

    zarr_format = node.metadata.zarr_format
    segments = path.split("/")[:-1]

    # "" is the root group; it was previously omitted, which left nested
    # nodes without a root group above them. The remaining prefixes are
    # "a", "a/b", ... for a node at "a/b/c".
    prefixes = [""] + ["/".join(segments[: depth + 1]) for depth in range(len(segments))]

    return [
        AsyncGroup(
            metadata=GroupMetadata(zarr_format=zarr_format),
            store_path=StorePath(store=store, path=prefix),
        )
        for prefix in prefixes
    ]

src/zarr/core/group.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import zarr.api.asynchronous as async_api
1414
from zarr.abc.metadata import Metadata
1515
from zarr.abc.store import Store, set_or_delete
16-
from zarr.core.array import Array, AsyncArray
16+
from zarr.core.array import Array, AsyncArray, _build_parents
1717
from zarr.core.attributes import Attributes
1818
from zarr.core.buffer import default_buffer_prototype
1919
from zarr.core.common import (
@@ -144,7 +144,7 @@ async def from_store(
144144
metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format),
145145
store_path=store_path,
146146
)
147-
await group._save_metadata()
147+
await group._save_metadata(ensure_parents=True)
148148
return group
149149

150150
@classmethod
@@ -279,9 +279,22 @@ async def delitem(self, key: str) -> None:
279279
else:
280280
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
281281

282-
async def _save_metadata(self) -> None:
282+
async def _save_metadata(self, ensure_parents: bool = False) -> None:
    """
    Write the metadata documents of this group to its store path.

    Parameters
    ----------
    ensure_parents : bool, default False
        When True, also write group metadata for every group returned by
        ``_build_parents(self)``, using ``set_if_not_exists`` so existing
        documents are left untouched.
    """
    own_docs = self.metadata.to_buffer_dict(default_buffer_prototype())
    pending = [set_or_delete(self.store_path / name, buf) for name, buf in own_docs.items()]

    if ensure_parents:
        for ancestor in _build_parents(self):
            ancestor_docs = ancestor.metadata.to_buffer_dict(default_buffer_prototype())
            pending += [
                (ancestor.store_path / name).set_if_not_exists(buf)
                for name, buf in ancestor_docs.items()
            ]

    # Own documents and ancestor documents are awaited together.
    await asyncio.gather(*pending)
286299

287300
@property

src/zarr/store/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
5151
async def delete(self) -> None:
5252
await self.store.delete(self.path)
5353

54+
async def set_if_not_exists(self, default: Buffer) -> None:
    """Delegate to the store's ``set_if_not_exists``, using this object's ``path`` as the key."""
    store, key = self.store, self.path
    await store.set_if_not_exists(key, default)
56+
5457
async def exists(self) -> bool:
5558
return await self.store.exists(self.path)
5659

src/zarr/store/local.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def _put(
6060
path: Path,
6161
value: Buffer,
6262
start: int | None = None,
63+
exclusive: bool = False,
6364
) -> int | None:
6465
path.parent.mkdir(parents=True, exist_ok=True)
6566
if start is not None:
@@ -68,7 +69,13 @@ def _put(
6869
f.write(value.as_numpy_array().tobytes())
6970
return None
7071
else:
71-
return path.write_bytes(value.as_numpy_array().tobytes())
72+
view = memoryview(value.as_numpy_array().tobytes())
73+
if exclusive:
74+
mode = "xb"
75+
else:
76+
mode = "wb"
77+
with path.open(mode=mode) as f:
78+
return f.write(view)
7279

7380

7481
class LocalStore(Store):
@@ -155,14 +162,23 @@ async def get_partial_values(
155162
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit
156163

157164
async def set(self, key: str, value: Buffer) -> None:
    """Write ``value`` at ``key`` via ``_set`` with its default, non-exclusive mode."""
    await self._set(key, value)
166+
167+
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
    """
    Write ``value`` at ``key`` only if no file exists there yet.

    The write goes through ``_set`` with ``exclusive=True``. A
    ``FileExistsError`` from that call is swallowed, leaving the existing
    file untouched.
    """
    try:
        await self._set(key, value, exclusive=True)
    except FileExistsError:
        # The key is already present; keeping it is the intended outcome.
        return
172+
173+
async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
    """
    Shared write path behind ``set`` and ``set_if_not_exists``.

    Parameters
    ----------
    key : str
        Key to write; joined onto ``self.root`` to form the target path.
    value : Buffer
        Data to write.
    exclusive : bool, default False
        Forwarded to ``_put``. When True the file is created exclusively,
        so an already existing file raises ``FileExistsError``.

    Raises
    ------
    TypeError
        If ``value`` is not a ``Buffer`` instance.
    """
    if not self._is_open:
        await self._open()
    self._check_writable()
    assert isinstance(key, str)
    if not isinstance(value, Buffer):
        # Message kept under the public method's name; grammar fixed
        # ("must a" -> "must be a").
        raise TypeError("LocalStore.set(): `value` must be a Buffer instance")
    path = self.root / key
    # start=None selects the whole-file write branch of `_put`.
    await to_thread(_put, path, value, start=None, exclusive=exclusive)
166182

167183
async def set_partial_values(
168184
self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]]

src/zarr/store/logging.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import TYPE_CHECKING, Self
99

1010
from zarr.abc.store import AccessMode, ByteRangeRequest, Store
11+
from zarr.core.buffer import Buffer
1112

1213
if TYPE_CHECKING:
1314
from collections.abc import AsyncGenerator, Generator, Iterable
@@ -149,6 +150,10 @@ async def set(self, key: str, value: Buffer) -> None:
149150
with self.log():
150151
return await self._store.set(key=key, value=value)
151152

153+
async def set_if_not_exists(self, key: str, default: Buffer) -> None:
    """
    Forward ``set_if_not_exists`` to the wrapped store inside a ``self.log()`` context.

    Parameters
    ----------
    key : str
        Key passed through to the wrapped store.
    default : Buffer
        Value passed through to the wrapped store.
    """
    with self.log():
        # Forward positionally: store implementations disagree on the name of
        # the second parameter (`value` on the base class, `default` on some
        # concrete stores), so keyword forwarding raises TypeError for some.
        return await self._store.set_if_not_exists(key, default)
156+
152157
async def delete(self, key: str) -> None:
153158
with self.log():
154159
return await self._store.delete(key=key)

src/zarr/store/memory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,8 @@ async def exists(self, key: str) -> bool:
8888
return key in self._store_dict
8989

9090
async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
91-
if not self._is_open:
92-
await self._open()
9391
self._check_writable()
92+
await self._ensure_open()
9493
assert isinstance(key, str)
9594
if not isinstance(value, Buffer):
9695
raise TypeError(f"Expected Buffer. Got {type(value)}.")
@@ -102,6 +101,11 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
102101
else:
103102
self._store_dict[key] = value
104103

104+
async def set_if_not_exists(self, key: str, default: Buffer) -> None:
    """Store ``default`` under ``key`` in the backing dict unless ``key`` is already there."""
    self._check_writable()
    await self._ensure_open()
    entries = self._store_dict
    entries.setdefault(key, default)
108+
105109
async def delete(self, key: str) -> None:
106110
self._check_writable()
107111
try:

src/zarr/store/remote.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import fsspec
66

77
from zarr.abc.store import ByteRangeRequest, Store
8+
from zarr.core.buffer import Buffer
89
from zarr.store.common import _dereference_path
910

1011
if TYPE_CHECKING:

src/zarr/store/zip.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,13 @@ async def set(self, key: str, value: Buffer) -> None:
191191
async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None:
192192
raise NotImplementedError
193193

194+
async def set_if_not_exists(self, key: str, default: Buffer) -> None:
    """
    Write ``default`` under ``key`` unless the archive already lists ``key``.

    The membership check and the write both happen while holding
    ``self._lock``.
    """
    self._check_writable()
    with self._lock:
        already_present = key in self._zf.namelist()
        if not already_present:
            self._set(key, default)
200+
194201
async def delete(self, key: str) -> None:
195202
raise NotImplementedError
196203

0 commit comments

Comments
 (0)