Skip to content

Commit ed67cc3

Browse files
committed
arrays proxy
1 parent 77115ff commit ed67cc3

File tree

2 files changed

+79
-8
lines changed

2 files changed

+79
-8
lines changed

src/zarr/core/group.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,36 @@ def flatten(
294294
return metadata
295295

296296

297+
class ArraysProxy:
298+
"""
299+
Proxy for arrays in a group.
300+
301+
Used to implement the `Group.arrays` property
302+
"""
303+
304+
def __init__(self, group: Group) -> None:
305+
self._group = group
306+
307+
def __getitem__(self, key: str) -> Array:
308+
obj = self._group[key]
309+
if isinstance(obj, Array):
310+
return obj
311+
raise KeyError(key)
312+
313+
def __setitem__(self, key: str, value: npt.ArrayLike) -> None:
314+
"""
315+
Set an array in the group.
316+
"""
317+
self._group._sync(self._group._async_group.set_array(key, value))
318+
319+
def __iter__(self) -> Generator[tuple[str, Array], None]:
320+
for name, async_array in self._group._sync_iter(self._group._async_group.arrays()):
321+
yield name, Array(async_array)
322+
323+
def __call__(self) -> Generator[tuple[str, Array], None]:
324+
return iter(self)
325+
326+
297327
@dataclass(frozen=True)
298328
class GroupMetadata(Metadata):
299329
attributes: dict[str, Any] = field(default_factory=dict)
@@ -596,7 +626,16 @@ def from_dict(
596626
store_path=store_path,
597627
)
598628

599-
async def setitem(self, key: str, value: Any) -> None:
629+
async def set_array(self, key: str, value: Any) -> None:
630+
"""fastpath for creating a new array
631+
632+
Parameters
633+
----------
634+
key : str
635+
Array name
636+
value : array-like
637+
Array data
638+
"""
600639
path = self.store_path / key
601640
await async_api.save_array(
602641
store=path, arr=value, zarr_format=self.metadata.zarr_format, exists_ok=True
@@ -1374,9 +1413,14 @@ def __iter__(self) -> Iterator[str]:
13741413
def __len__(self) -> int:
13751414
return self.nmembers()
13761415

1416+
@deprecated("Use Group.arrays setter instead.")
13771417
def __setitem__(self, key: str, value: Any) -> None:
1378-
"""Create a new array"""
1379-
self._sync(self._async_group.setitem(key, value))
1418+
"""Create a new array
1419+
1420+
.. deprecated:: 3.0.0
1421+
Use Group.arrays.setter instead.
1422+
"""
1423+
self._sync(self._async_group.set_array(key, value))
13801424

13811425
def __repr__(self) -> str:
13821426
return f"<Group {self.store_path}>"
@@ -1473,9 +1517,9 @@ def group_values(self) -> Generator[Group, None]:
14731517
for _, group in self.groups():
14741518
yield group
14751519

1476-
def arrays(self) -> Generator[tuple[str, Array], None]:
1477-
for name, async_array in self._sync_iter(self._async_group.arrays()):
1478-
yield name, Array(async_array)
1520+
@property
1521+
def arrays(self) -> ArraysProxy:
1522+
return ArraysProxy(self)
14791523

14801524
def array_keys(self) -> Generator[str, None]:
14811525
for name, _ in self.arrays():

tests/v3/test_group.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,8 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None:
392392
"""
393393
group = Group.from_store(store, zarr_format=zarr_format)
394394
arr = np.ones((2, 4))
395-
group["key"] = arr
395+
with pytest.warns(DeprecationWarning):
396+
group["key"] = arr
396397
assert group["key"].shape == (2, 4)
397398
np.testing.assert_array_equal(group["key"][:], arr)
398399

@@ -405,7 +406,8 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None:
405406

406407
# overwrite with another array
407408
arr = np.zeros((3, 5))
408-
group[key] = arr
409+
with pytest.warns(DeprecationWarning):
410+
group[key] = arr
409411
assert group[key].shape == (3, 5)
410412
np.testing.assert_array_equal(group[key], arr)
411413

@@ -416,6 +418,31 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None:
416418
# assert group["key"][:] == 1
417419

418420

421+
def test_group_arrays_setter(store: Store, zarr_format: ZarrFormat) -> None:
422+
"""
423+
Test the `Group.__setitem__` method.
424+
"""
425+
group = Group.from_store(store, zarr_format=zarr_format)
426+
arr = np.ones((2, 4))
427+
group.arrays["key"] = arr
428+
assert group["key"].shape == (2, 4)
429+
np.testing.assert_array_equal(group["key"][:], arr)
430+
431+
if store.supports_deletes:
432+
key = "key"
433+
else:
434+
# overwriting with another array requires deletes
435+
# for stores that don't support this, we just use a new key
436+
key = "key2"
437+
438+
# overwrite with another array
439+
arr = np.zeros((3, 5))
440+
with pytest.warns(DeprecationWarning):
441+
group[key] = arr
442+
assert group[key].shape == (3, 5)
443+
np.testing.assert_array_equal(group[key], arr)
444+
445+
419446
def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None:
420447
"""
421448
Test the `Group.__contains__` method

0 commit comments

Comments
 (0)