From d42e0e67fe3ca1a429156bc06982a08a7218b9e2 Mon Sep 17 00:00:00 2001 From: Aniket Singh Rawat Date: Mon, 13 Jan 2025 03:10:14 +0530 Subject: [PATCH 1/3] add enter and exit methods to groups. --- src/zarr/core/group.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 57d9c5cd8d..bca06a0207 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -7,7 +7,7 @@ import warnings from collections import defaultdict from dataclasses import asdict, dataclass, field, fields, replace -from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload +from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload, Self import numpy as np import numpy.typing as npt @@ -1752,6 +1752,12 @@ def open( obj = sync(AsyncGroup.open(store, zarr_format=zarr_format)) return cls(obj) + def __enter__(self) -> Self: + return self + + def __exit__(self) -> None: # noqa: PYI036 + self.store.close() + def __getitem__(self, path: str) -> Array | Group: """Obtain a group member. From d8f7d686a15e3a62e87fb29f8d2288a08f4118ad Mon Sep 17 00:00:00 2001 From: Aniket Singh Rawat Date: Tue, 14 Jan 2025 02:25:22 +0530 Subject: [PATCH 2/3] add test --- src/zarr/core/group.py | 7 +++++-- tests/test_group.py | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index bca06a0207..8dd3b19e54 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -7,7 +7,7 @@ import warnings from collections import defaultdict from dataclasses import asdict, dataclass, field, fields, replace -from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload, Self +from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload import numpy as np import numpy.typing as npt @@ -55,6 +55,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable, Iterator + from types import TracebackType from typing import Any from zarr.core.array_spec import ArrayConfig, ArrayConfigLike @@ -1755,7 +1756,9 @@ def open( def __enter__(self) -> Self: return self - def __exit__(self) -> None: # noqa: PYI036 + def __exit__( + self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None + ) -> None: self.store.close() def __getitem__(self, path: str) -> Array | Group: diff --git a/tests/test_group.py b/tests/test_group.py index 788e81e603..da3c83e8c0 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -428,6 +428,28 @@ def test_group_len(store: Store, zarr_format: ZarrFormat) -> None: assert len(group) == 0 +def test_group_with_context_manager(store: Store, zarr_format: ZarrFormat, overwrite: bool) -> None: + spath = StorePath(store) + + # attempt to open a group that does not exist. + with pytest.raises(FileNotFoundError): + with Group.open(store) as store: + pass + + attrs = {"path": "foo"} + + with Group.from_store( + store, attributes=attrs, zarr_format=zarr_format, overwrite=overwrite + ) as group: + assert store._is_open + assert group.attrs == attrs + assert group.metadata.zarr_format == zarr_format + assert group.store_path == spath + + # Check if store was closed after exit. + assert not store._is_open + + def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: """ Test the `Group.__setitem__` method. From 3aee7b00112896df8f871cc73bfa4120c4612b79 Mon Sep 17 00:00:00 2001 From: Aniket Singh Rawat Date: Fri, 7 Mar 2025 22:54:27 +0530 Subject: [PATCH 3/3] use public apis --- src/zarr/core/group.py | 3 ++- tests/test_group.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index ea25300e63..5f8d6e3186 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -63,8 +63,9 @@ Coroutine, Generator, Iterable, - Iterator + Iterator, ) + from types import TracebackType from typing import Any from zarr.core.array_spec import ArrayConfig, ArrayConfigLike diff --git a/tests/test_group.py b/tests/test_group.py index e37c167a7d..c598c46d38 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -455,12 +455,12 @@ def test_group_with_context_manager(store: Store, zarr_format: ZarrFormat, overw # attempt to open a group that does not exist. with pytest.raises(FileNotFoundError): - with Group.open(store) as store: + with zarr.open_group(store, mode="r") as group: pass attrs = {"path": "foo"} - with Group.from_store( + with zarr.create_group( store, attributes=attrs, zarr_format=zarr_format, overwrite=overwrite ) as group: assert store._is_open