Skip to content

Commit 0cdc29a

Browse files
authored
Merge branch 'v3' into fix/remote-store-empty-speedup-walk
2 parents d3e8733 + 8f4ef26 commit 0cdc29a

File tree

7 files changed

+91
-24
lines changed

7 files changed

+91
-24
lines changed

src/zarr/core/group.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import logging
66
from dataclasses import asdict, dataclass, field, fields, replace
7-
from typing import TYPE_CHECKING, Literal, cast, overload
7+
from typing import TYPE_CHECKING, Literal, TypeVar, cast, overload
88

99
import numpy as np
1010
import numpy.typing as npt
@@ -42,6 +42,8 @@
4242

4343
logger = logging.getLogger("zarr.group")
4444

45+
DefaultT = TypeVar("DefaultT")
46+
4547

4648
def parse_zarr_format(data: Any) -> ZarrFormat:
4749
if data in (2, 3):
@@ -290,6 +292,28 @@ async def delitem(self, key: str) -> None:
290292
else:
291293
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
292294

295+
async def get(
296+
self, key: str, default: DefaultT | None = None
297+
) -> AsyncArray | AsyncGroup | DefaultT | None:
298+
"""Obtain a group member, returning default if not found.
299+
300+
Parameters
301+
----------
302+
key : string
303+
Group member name.
304+
default : object
305+
Default value to return if key is not found (default: None).
306+
307+
Returns
308+
-------
309+
object
310+
Group member (AsyncArray or AsyncGroup) or default if not found.
311+
"""
312+
try:
313+
return await self.getitem(key)
314+
except KeyError:
315+
return default
316+
293317
async def _save_metadata(self, ensure_parents: bool = False) -> None:
294318
to_save = self.metadata.to_buffer_dict(default_buffer_prototype())
295319
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
@@ -828,6 +852,26 @@ def __getitem__(self, path: str) -> Array | Group:
828852
else:
829853
return Group(obj)
830854

855+
def get(self, path: str, default: DefaultT | None = None) -> Array | Group | DefaultT | None:
856+
"""Obtain a group member, returning default if not found.
857+
858+
Parameters
859+
----------
860+
path : string
861+
Group member name.
862+
default : object
863+
Default value to return if key is not found (default: None).
864+
865+
Returns
866+
-------
867+
object
868+
Group member (Array or Group) or default if not found.
869+
"""
870+
try:
871+
return self[path]
872+
except KeyError:
873+
return default
874+
831875
def __delitem__(self, key: str) -> None:
832876
self._sync(self._async_group.delitem(key))
833877

src/zarr/testing/store.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ class StoreTests(Generic[S, B]):
2121
store_cls: type[S]
2222
buffer_cls: type[B]
2323

24-
def set(self, store: S, key: str, value: Buffer) -> None:
24+
async def set(self, store: S, key: str, value: Buffer) -> None:
2525
"""
2626
Insert a value into a storage backend, with a specific key.
2727
This should not use any store methods. Bypassing the store methods allows them to be
2828
tested.
2929
"""
3030
raise NotImplementedError
3131

32-
def get(self, store: S, key: str) -> Buffer:
32+
async def get(self, store: S, key: str) -> Buffer:
3333
"""
3434
Retrieve a value from a storage backend, by key.
3535
This should not use any store methods. Bypassing the store methods allows them to be
@@ -106,7 +106,7 @@ async def test_get(
106106
Ensure that data can be read from the store using the store.get method.
107107
"""
108108
data_buf = self.buffer_cls.from_bytes(data)
109-
self.set(store, key, data_buf)
109+
await self.set(store, key, data_buf)
110110
observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range)
111111
start, length = _normalize_interval_index(data_buf, interval=byte_range)
112112
expected = data_buf[start : start + length]
@@ -119,7 +119,7 @@ async def test_get_many(self, store: S) -> None:
119119
keys = tuple(map(str, range(10)))
120120
values = tuple(f"{k}".encode() for k in keys)
121121
for k, v in zip(keys, values, strict=False):
122-
self.set(store, k, self.buffer_cls.from_bytes(v))
122+
await self.set(store, k, self.buffer_cls.from_bytes(v))
123123
observed_buffers = await _collect_aiterator(
124124
store._get_many(
125125
zip(
@@ -143,7 +143,7 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
143143
assert not store.mode.readonly
144144
data_buf = self.buffer_cls.from_bytes(data)
145145
await store.set(key, data_buf)
146-
observed = self.get(store, key)
146+
observed = await self.get(store, key)
147147
assert_bytes_equal(observed, data_buf)
148148

149149
async def test_set_many(self, store: S) -> None:
@@ -156,7 +156,7 @@ async def test_set_many(self, store: S) -> None:
156156
store_dict = dict(zip(keys, data_buf, strict=True))
157157
await store._set_many(store_dict.items())
158158
for k, v in store_dict.items():
159-
assert self.get(store, k).to_bytes() == v.to_bytes()
159+
assert (await self.get(store, k)).to_bytes() == v.to_bytes()
160160

161161
@pytest.mark.parametrize(
162162
"key_ranges",
@@ -172,7 +172,7 @@ async def test_get_partial_values(
172172
) -> None:
173173
# put all of the data
174174
for key, _ in key_ranges:
175-
self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))
175+
await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))
176176

177177
# read back just part of it
178178
observed_maybe = await store.get_partial_values(
@@ -211,11 +211,15 @@ async def test_delete(self, store: S) -> None:
211211

212212
async def test_empty(self, store: S) -> None:
213213
assert await store.empty()
214-
self.set(store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8")))
214+
await self.set(
215+
store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
216+
)
215217
assert not await store.empty()
216218

217219
async def test_clear(self, store: S) -> None:
218-
self.set(store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8")))
220+
await self.set(
221+
store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
222+
)
219223
await store.clear()
220224
assert await store.empty()
221225

@@ -277,8 +281,8 @@ async def test_list_dir(self, store: S) -> None:
277281

278282
async def test_with_mode(self, store: S) -> None:
279283
data = b"0000"
280-
self.set(store, "key", self.buffer_cls.from_bytes(data))
281-
assert self.get(store, "key").to_bytes() == data
284+
await self.set(store, "key", self.buffer_cls.from_bytes(data))
285+
assert (await self.get(store, "key")).to_bytes() == data
282286

283287
for mode in ["r", "a"]:
284288
mode = cast(AccessModeLiteral, mode)
@@ -294,7 +298,7 @@ async def test_with_mode(self, store: S) -> None:
294298
assert result.to_bytes() == data
295299

296300
# writes to original after with_mode is visible
297-
self.set(store, "key-2", self.buffer_cls.from_bytes(data))
301+
await self.set(store, "key-2", self.buffer_cls.from_bytes(data))
298302
result = await clone.get("key-2", default_buffer_prototype())
299303
assert result is not None
300304
assert result.to_bytes() == data
@@ -313,7 +317,7 @@ async def test_with_mode(self, store: S) -> None:
313317
async def test_set_if_not_exists(self, store: S) -> None:
314318
key = "k"
315319
data_buf = self.buffer_cls.from_bytes(b"0000")
316-
self.set(store, key, data_buf)
320+
await self.set(store, key, data_buf)
317321

318322
new = self.buffer_cls.from_bytes(b"1111")
319323
await store.set_if_not_exists("k", new) # no error

tests/v3/test_group.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,25 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat) -> None:
292292
group["nope"]
293293

294294

295+
def test_group_get_with_default(store: Store, zarr_format: ZarrFormat) -> None:
296+
group = Group.from_store(store, zarr_format=zarr_format)
297+
298+
# default behavior
299+
result = group.get("subgroup")
300+
assert result is None
301+
302+
# custom default
303+
result = group.get("subgroup", 8)
304+
assert result == 8
305+
306+
# now with a group
307+
subgroup = group.require_group("subgroup")
308+
subgroup.attrs["foo"] = "bar"
309+
310+
result = group.get("subgroup", 8)
311+
assert result.attrs["foo"] == "bar"
312+
313+
295314
def test_group_delitem(store: Store, zarr_format: ZarrFormat) -> None:
296315
"""
297316
Test the `Group.__delitem__` method.

tests/v3/test_store/test_local.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ class TestLocalStore(StoreTests[LocalStore, cpu.Buffer]):
1111
store_cls = LocalStore
1212
buffer_cls = cpu.Buffer
1313

14-
def get(self, store: LocalStore, key: str) -> Buffer:
14+
async def get(self, store: LocalStore, key: str) -> Buffer:
1515
return self.buffer_cls.from_bytes((store.root / key).read_bytes())
1616

17-
def set(self, store: LocalStore, key: str, value: Buffer) -> None:
17+
async def set(self, store: LocalStore, key: str, value: Buffer) -> None:
1818
parent = (store.root / key).parent
1919
if not parent.exists():
2020
parent.mkdir(parents=True)

tests/v3/test_store/test_memory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ class TestMemoryStore(StoreTests[MemoryStore, cpu.Buffer]):
1212
store_cls = MemoryStore
1313
buffer_cls = cpu.Buffer
1414

15-
def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
15+
async def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
1616
store._store_dict[key] = value
1717

18-
def get(self, store: MemoryStore, key: str) -> Buffer:
18+
async def get(self, store: MemoryStore, key: str) -> Buffer:
1919
return store._store_dict[key]
2020

2121
@pytest.fixture(params=[None, True])
@@ -52,10 +52,10 @@ class TestGpuMemoryStore(StoreTests[GpuMemoryStore, gpu.Buffer]):
5252
store_cls = GpuMemoryStore
5353
buffer_cls = gpu.Buffer
5454

55-
def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
55+
async def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
5656
store._store_dict[key] = value
5757

58-
def get(self, store: MemoryStore, key: str) -> Buffer:
58+
async def get(self, store: MemoryStore, key: str) -> Buffer:
5959
return store._store_dict[key]
6060

6161
@pytest.fixture(params=[None, True])

tests/v3/test_store/test_remote.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,14 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
117117
def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore:
118118
return self.store_cls(**store_kwargs)
119119

120-
def get(self, store: RemoteStore, key: str) -> Buffer:
120+
async def get(self, store: RemoteStore, key: str) -> Buffer:
121121
# make a new, synchronous instance of the filesystem; the call below is blocking even though this helper is declared async
122122
new_fs = fsspec.filesystem(
123123
"s3", endpoint_url=store.fs.endpoint_url, anon=store.fs.anon, asynchronous=False
124124
)
125125
return self.buffer_cls.from_bytes(new_fs.cat(f"{store.path}/{key}"))
126126

127-
def set(self, store: RemoteStore, key: str, value: Buffer) -> None:
127+
async def set(self, store: RemoteStore, key: str, value: Buffer) -> None:
128128
# make a new, synchronous instance of the filesystem (asynchronous=False) even though this helper is declared async
129129
new_fs = fsspec.filesystem(
130130
"s3", endpoint_url=store.fs.endpoint_url, anon=store.fs.anon, asynchronous=False

tests/v3/test_store/test_zip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
2929

3030
return {"path": temp_path, "mode": "w"}
3131

32-
def get(self, store: ZipStore, key: str) -> Buffer:
32+
async def get(self, store: ZipStore, key: str) -> Buffer:
3333
return store._get(key, prototype=default_buffer_prototype())
3434

35-
def set(self, store: ZipStore, key: str, value: Buffer) -> None:
35+
async def set(self, store: ZipStore, key: str, value: Buffer) -> None:
3636
return store._set(key, value)
3737

3838
def test_store_mode(self, store: ZipStore, store_kwargs: dict[str, Any]) -> None:

0 commit comments

Comments
 (0)