Skip to content

Commit d8f9b01

Browse files
committed
Ensure parents are created when creating a node
This updates our Array and Group creation methods to ensure that parents implicitly defined through a nested path are also created. To accomplish this semi-safely and efficiently, we require a new setdefault method on the Store class.
1 parent f0443db commit d8f9b01

File tree

12 files changed

+194
-11
lines changed

12 files changed

+194
-11
lines changed

src/zarr/abc/store.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ async def set(self, key: str, value: Buffer) -> None:
173173
"""
174174
...
175175

176+
@abstractmethod
async def setdefault(self, key: str, default: Buffer) -> None:
    """
    Store a key with a value of ``default`` if the key is not already present.

    Unlike ``MutableMapping.setdefault``, this method returns nothing, so it
    does not provide any way to know whether ``default`` was actually set.

    Parameters
    ----------
    key : str
        Key to create when it is missing.
    default : Buffer
        Value to store under ``key`` only if ``key`` is not already present.
    """
    ...
185+
176186
async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
177187
"""
178188
Insert multiple (key, value) pairs into storage.
@@ -298,9 +308,18 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
298308

299309
async def delete(self) -> None:
    """Interface stub: the body is ``...`` and carries no behaviour of its own."""
    ...
300310

311+
async def setdefault(self, default: Buffer) -> None:
    """Interface stub taking a ``default`` buffer; the body is ``...`` only."""
    ...
312+
301313

302314
async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None:
    """
    Write ``value`` through ``byte_setter``, or delete when there is no value.

    Parameters
    ----------
    byte_setter : ByteSetter
        Object whose ``set`` / ``delete`` coroutine is awaited.
    value : Buffer or None
        ``None`` triggers ``byte_setter.delete()``; anything else is passed
        to ``byte_setter.set``.
    """
    if value is not None:
        await byte_setter.set(value)
        return
    await byte_setter.delete()
319+
320+
321+
async def setdefault(byte_setter: ByteSetter, value: Buffer | None) -> None:
    """
    Store ``value`` through ``byte_setter`` only if nothing is stored there yet.

    The previous body was a copy of ``set_or_delete``: it overwrote existing
    data unconditionally and deleted the key when ``value`` was ``None``,
    which is the opposite of "set default" semantics.

    Parameters
    ----------
    byte_setter : ByteSetter
        Object whose ``setdefault`` coroutine is awaited.
    value : Buffer or None
        Default to store. ``None`` means there is no default to write, so the
        call is a no-op.
    """
    if value is None:
        # Nothing to write; never delete here, since that would destroy
        # data that already exists.
        return
    await byte_setter.setdefault(value)

src/zarr/codecs/sharding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
9898
async def delete(self) -> None:
    """Remove the entry keyed by ``self.chunk_coords`` from ``self.shard_dict``."""
    coords = self.chunk_coords
    del self.shard_dict[coords]
100100

101+
async def setdefault(self, default: Buffer) -> None:
    """Insert ``default`` at ``self.chunk_coords`` unless ``self.shard_dict`` already has it."""
    shard = self.shard_dict
    shard.setdefault(self.chunk_coords, default)
103+
101104

102105
class _ShardIndex(NamedTuple):
103106
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)

src/zarr/core/array.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from zarr.codecs import BytesCodec
1414
from zarr.codecs._v2 import V2Compressor, V2Filters
1515
from zarr.core.attributes import Attributes
16-
from zarr.core.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
16+
from zarr.core.buffer import (
17+
BufferPrototype,
18+
NDArrayLike,
19+
NDBuffer,
20+
default_buffer_prototype,
21+
)
1722
from zarr.core.chunk_grids import RegularChunkGrid, _guess_chunks
1823
from zarr.core.chunk_key_encodings import (
1924
ChunkKeyEncoding,
@@ -71,6 +76,7 @@
7176
from collections.abc import Iterable, Iterator, Sequence
7277

7378
from zarr.abc.codec import Codec, CodecPipeline
79+
from zarr.core.group import AsyncGroup
7480
from zarr.core.metadata.common import ArrayMetadata
7581

7682
# Array and AsyncArray are defined in the base ``zarr`` namespace
@@ -276,7 +282,7 @@ async def _create_v3(
276282
)
277283

278284
array = cls(metadata=metadata, store_path=store_path)
279-
await array._save_metadata(metadata)
285+
await array._save_metadata(metadata, ensure_parents=True)
280286
return array
281287

282288
@classmethod
@@ -315,7 +321,7 @@ async def _create_v2(
315321
attributes=attributes,
316322
)
317323
array = cls(metadata=metadata, store_path=store_path)
318-
await array._save_metadata(metadata)
324+
await array._save_metadata(metadata, ensure_parents=True)
319325
return array
320326

321327
@classmethod
@@ -603,9 +609,24 @@ async def getitem(
603609
)
604610
return await self._get_selection(indexer, prototype=prototype)
605611

606-
async def _save_metadata(self, metadata: ArrayMetadata) -> None:
612+
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
    """
    Write the documents produced by ``metadata.to_buffer_dict`` under ``self.store_path``.

    Parameters
    ----------
    metadata : ArrayMetadata
        Metadata whose buffer dict is written with ``set_or_delete``.
    ensure_parents : bool
        When true, group metadata for every parent returned by
        ``_build_parents(self)`` is also written via ``setdefault``.
    """
    own_docs = metadata.to_buffer_dict(default_buffer_prototype())
    awaitables = [
        set_or_delete(self.store_path / key, value) for key, value in own_docs.items()
    ]

    if ensure_parents:
        # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediates.
        for parent in _build_parents(self):
            parent_docs = parent.metadata.to_buffer_dict(default_buffer_prototype())
            awaitables.extend(
                (parent.store_path / key).setdefault(value)
                for key, value in parent_docs.items()
            )

    # All writes (own metadata and parent defaults) are awaited together.
    await gather(*awaitables)
610631

611632
async def _set_selection(
@@ -2336,3 +2357,21 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
23362357
out.append(chunk_key)
23372358

23382359
return tuple(out)
2360+
2361+
2362+
def _build_parents(node: AsyncArray | AsyncGroup) -> list[AsyncGroup]:
    """
    Build an ``AsyncGroup`` for every intermediate path above ``node``.

    For a node at ``"a/b/c"`` the result covers ``"a"`` and ``"a/b"``, in that
    order. Each group gets a fresh ``GroupMetadata`` with the node's
    ``zarr_format`` and shares the node's store. Nothing is written to the
    store here.

    NOTE(review): the root path ``""`` is never included in the result;
    confirm that is intended.
    """
    # Imported here rather than at module level -- presumably to avoid a
    # circular import with zarr.core.group; confirm.
    from zarr.core.group import AsyncGroup, GroupMetadata

    # Everything but the last path component is an ancestor of the node.
    ancestors = node.store_path.path.split("/")[:-1]

    return [
        AsyncGroup(
            metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
            store_path=StorePath(
                store=node.store_path.store,
                path="/".join(ancestors[:depth]),
            ),
        )
        for depth in range(1, len(ancestors) + 1)
    ]

src/zarr/core/group.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import zarr.api.asynchronous as async_api
1414
from zarr.abc.metadata import Metadata
1515
from zarr.abc.store import Store, set_or_delete
16-
from zarr.core.array import Array, AsyncArray
16+
from zarr.core.array import Array, AsyncArray, _build_parents
1717
from zarr.core.attributes import Attributes
1818
from zarr.core.buffer import default_buffer_prototype
1919
from zarr.core.common import (
@@ -144,7 +144,7 @@ async def from_store(
144144
metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format),
145145
store_path=store_path,
146146
)
147-
await group._save_metadata()
147+
await group._save_metadata(ensure_parents=True)
148148
return group
149149

150150
@classmethod
@@ -279,9 +279,22 @@ async def delitem(self, key: str) -> None:
279279
else:
280280
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
281281

282-
async def _save_metadata(self) -> None:
282+
async def _save_metadata(self, ensure_parents: bool = False) -> None:
    """
    Write the documents produced by ``self.metadata.to_buffer_dict`` under ``self.store_path``.

    Parameters
    ----------
    ensure_parents : bool
        When true, group metadata for every parent returned by
        ``_build_parents(self)`` is also written via ``setdefault``.
    """
    own_docs = self.metadata.to_buffer_dict(default_buffer_prototype())
    awaitables = [
        set_or_delete(self.store_path / key, value) for key, value in own_docs.items()
    ]

    if ensure_parents:
        for parent in _build_parents(self):
            parent_docs = parent.metadata.to_buffer_dict(default_buffer_prototype())
            awaitables.extend(
                (parent.store_path / key).setdefault(value)
                for key, value in parent_docs.items()
            )

    # All writes (own metadata and parent defaults) are awaited together.
    await asyncio.gather(*awaitables)
286299

287300
@property

src/zarr/store/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
5151
async def delete(self) -> None:
    """Await ``self.store.delete`` for ``self.path``."""
    store, path = self.store, self.path
    await store.delete(path)
5353

54+
async def setdefault(self, default: Buffer) -> None:
    """Await ``self.store.setdefault`` with ``self.path`` as the key and ``default`` as the value."""
    store, path = self.store, self.path
    await store.setdefault(path, default)
56+
5457
async def exists(self) -> bool:
    """Return whatever ``self.store.exists`` reports for ``self.path``."""
    found = await self.store.exists(self.path)
    return found
5659

src/zarr/store/local.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def _put(
6060
path: Path,
6161
value: Buffer,
6262
start: int | None = None,
63+
exclusive: bool = False,
6364
) -> int | None:
6465
path.parent.mkdir(parents=True, exist_ok=True)
6566
if start is not None:
@@ -68,7 +69,13 @@ def _put(
6869
f.write(value.as_numpy_array().tobytes())
6970
return None
7071
else:
71-
return path.write_bytes(value.as_numpy_array().tobytes())
72+
view = memoryview(value.as_numpy_array().tobytes())
73+
if exclusive:
74+
mode = "xb"
75+
else:
76+
mode = "wb"
77+
with path.open(mode=mode) as f:
78+
return f.write(view)
7279

7380

7481
class LocalStore(Store):
@@ -152,14 +159,23 @@ async def get_partial_values(
152159
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit
153160

154161
async def set(self, key: str, value: Buffer) -> None:
    """Store ``value`` under ``key`` by delegating to ``self._set`` with its default flags."""
    outcome = await self._set(key, value)
    return outcome
163+
164+
async def setdefault(self, key: str, value: Buffer) -> None:
    """
    Store ``value`` under ``key`` unless the write reports the key already exists.

    Delegates to ``self._set`` with ``exclusive=True``. A ``FileExistsError``
    from that call is swallowed, so this method returns ``None`` in that case.
    """
    # NOTE(review): the abstract ``setdefault`` names this parameter
    # ``default``; here it is ``value``.
    try:
        outcome = await self._set(key, value, exclusive=True)
    except FileExistsError:
        # Key already present: nothing to do.
        return None
    return outcome
169+
170+
async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
    """
    Shared write path behind ``set`` and ``setdefault``.

    Opens the store if it is not open yet, checks it is writable, validates
    the arguments, and runs ``_put`` through ``to_thread``.

    Parameters
    ----------
    key : str
        Key, joined onto ``self.root`` to form the target path.
    value : Buffer
        Data to write.
    exclusive : bool
        Forwarded to ``_put`` unchanged.

    Raises
    ------
    TypeError
        If ``value`` is not a ``Buffer``.
    """
    if not self._is_open:
        await self._open()
    self._check_writable()
    assert isinstance(key, str)
    if isinstance(value, Buffer):
        target = self.root / key
        await to_thread(_put, target, value, start=None, exclusive=exclusive)
    else:
        raise TypeError("LocalStore.set(): `value` must a Buffer instance")
163179

164180
async def set_partial_values(
165181
self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]]

src/zarr/store/memory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,8 @@ async def exists(self, key: str) -> bool:
8585
return key in self._store_dict
8686

8787
async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
88-
if not self._is_open:
89-
await self._open()
9088
self._check_writable()
89+
await self._ensure_open()
9190
assert isinstance(key, str)
9291
if not isinstance(value, Buffer):
9392
raise TypeError(f"Expected Buffer. Got {type(value)}.")
@@ -99,6 +98,11 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
9998
else:
10099
self._store_dict[key] = value
101100

101+
async def setdefault(self, key: str, default: Buffer) -> None:
    """
    Insert ``default`` under ``key`` in ``self._store_dict`` unless the key exists.

    The writable check runs first, then the store is opened via
    ``self._ensure_open`` before the dict is touched.
    """
    self._check_writable()
    await self._ensure_open()
    contents = self._store_dict
    contents.setdefault(key, default)
105+
102106
async def delete(self, key: str) -> None:
103107
self._check_writable()
104108
try:

src/zarr/store/remote.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import fsspec
66

77
from zarr.abc.store import ByteRangeRequest, Store
8+
from zarr.core.buffer import Buffer
89
from zarr.store.common import _dereference_path
910

1011
if TYPE_CHECKING:
@@ -208,6 +209,11 @@ async def set_partial_values(
208209
) -> None:
209210
raise NotImplementedError
210211

212+
async def setdefault(self, key: str, default: Buffer) -> None:
    """
    Write ``default`` under ``key`` when the filesystem reports no such path.

    The existence check (``self.fs._exists``) and the write (``self.set``)
    are two separate awaits.
    """
    # this isn't safe for concurrent writers, but that's probably unavoidable.
    target = _dereference_path(self.path, key)
    if await self.fs._exists(target):
        return
    await self.set(key, default)
216+
211217
async def list(self) -> AsyncGenerator[str, None]:
212218
allfiles = await self.fs._find(self.path, detail=False, withdirs=False)
213219
for onefile in (a.replace(self.path + "/", "") for a in allfiles):

src/zarr/store/zip.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,13 @@ async def set(self, key: str, value: Buffer) -> None:
191191
async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None:
    """Not supported here: always raises ``NotImplementedError``."""
    raise NotImplementedError
193193

194+
async def setdefault(self, key: str, default: Buffer) -> None:
    """
    Write ``default`` under ``key`` unless the archive already lists ``key``.

    The membership check against ``self._zf.namelist()`` and the write via
    ``self._set`` both happen while ``self._lock`` is held.
    """
    self._check_writable()
    with self._lock:
        if key not in self._zf.namelist():
            self._set(key, default)
200+
194201
async def delete(self, key: str) -> None:
    """Not supported here: always raises ``NotImplementedError``."""
    raise NotImplementedError
196203

src/zarr/testing/store.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,20 @@ async def test_list_dir(self, store: S) -> None:
273273

274274
keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
275275
assert sorted(keys_expected) == sorted(keys_observed)
276+
277+
async def test_setdefault(self, store: S) -> None:
    """
    Check that ``store.setdefault`` keeps existing values and fills missing keys.

    An existing key must keep its original contents, and a missing key must
    end up holding the supplied default.
    """
    key = "k"
    data_buf = self.buffer_cls.from_bytes(b"0000")
    self.set(store, key, data_buf)

    new = self.buffer_cls.from_bytes(b"1111")
    await store.setdefault(key, new)  # no error

    # The pre-existing value must survive the setdefault call.
    result = await store.get(key, default_buffer_prototype())
    assert result == data_buf

    await store.setdefault("k2", new)  # no error

    # A previously missing key now holds the default.
    result = await store.get("k2", default_buffer_prototype())
    assert result == new

0 commit comments

Comments
 (0)