Skip to content

Commit 16675fa

Browse files
committed
Add refresh_attributes() and implement cache_attrs for Group, Array
Should resolve #3178
1 parent 111d765 commit 16675fa

File tree

5 files changed

+184
-10
lines changed

5 files changed

+184
-10
lines changed

src/zarr/api/synchronous.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,6 @@ def group(
441441
store=store,
442442
overwrite=overwrite,
443443
chunk_store=chunk_store,
444-
cache_attrs=cache_attrs,
445444
synchronizer=synchronizer,
446445
path=path,
447446
zarr_version=zarr_version,
@@ -450,7 +449,8 @@ def group(
450449
attributes=attributes,
451450
storage_options=storage_options,
452451
)
453-
)
452+
),
453+
cache_attrs=cache_attrs,
454454
)
455455

456456

@@ -536,7 +536,6 @@ def open_group(
536536
async_api.open_group(
537537
store=store,
538538
mode=mode,
539-
cache_attrs=cache_attrs,
540539
synchronizer=synchronizer,
541540
path=path,
542541
chunk_store=chunk_store,
@@ -547,7 +546,8 @@ def open_group(
547546
attributes=attributes,
548547
use_consolidated=use_consolidated,
549548
)
550-
)
549+
),
550+
cache_attrs=cache_attrs,
551551
)
552552

553553

@@ -559,6 +559,7 @@ def create_group(
559559
overwrite: bool = False,
560560
attributes: dict[str, Any] | None = None,
561561
storage_options: dict[str, Any] | None = None,
562+
cache_attrs: bool | None = None,
562563
) -> Group:
563564
"""Create a group.
564565
@@ -595,7 +596,8 @@ def create_group(
595596
zarr_format=zarr_format,
596597
attributes=attributes,
597598
)
598-
)
599+
),
600+
cache_attrs=cache_attrs,
599601
)
600602

601603

@@ -730,7 +732,6 @@ def create(
730732
chunk_store=chunk_store,
731733
filters=filters,
732734
cache_metadata=cache_metadata,
733-
cache_attrs=cache_attrs,
734735
read_only=read_only,
735736
object_codec=object_codec,
736737
dimension_separator=dimension_separator,
@@ -747,7 +748,8 @@ def create(
747748
config=config,
748749
**kwargs,
749750
)
750-
)
751+
),
752+
cache_attrs=cache_attrs,
751753
)
752754

753755

@@ -773,6 +775,7 @@ def create_array(
773775
overwrite: bool = False,
774776
config: ArrayConfigLike | None = None,
775777
write_data: bool = True,
778+
cache_attrs: bool | None = None,
776779
) -> Array:
777780
"""Create an array.
778781
@@ -872,6 +875,10 @@ def create_array(
872875
then ``write_data`` determines whether the values in that array-like object should be
873876
written to the Zarr array created by this function. If ``write_data`` is ``False``, then the
874877
array will be left empty.
878+
cache_attrs : bool, optional
879+
If True (default), user attributes will be cached for attribute read
880+
operations. If False, user attributes are reloaded from the store prior
881+
to all attribute read operations.
875882
876883
Returns
877884
-------
@@ -914,7 +921,8 @@ def create_array(
914921
config=config,
915922
write_data=write_data,
916923
)
917-
)
924+
),
925+
cache_attrs=cache_attrs,
918926
)
919927

920928

src/zarr/core/array.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,6 +1677,22 @@ async def update_attributes(self, new_attributes: dict[str, JSON]) -> Self:
16771677

16781678
return self
16791679

1680+
async def refresh_attributes(self) -> Self:
1681+
"""Reload the attributes of this array from the store.
1682+
1683+
Returns
1684+
-------
1685+
AsyncArray
1686+
The array updated with the newest attributes from storage."""
1687+
1688+
metadata = await get_array_metadata(self.store_path, self.metadata.zarr_format)
1689+
reparsed_metadata = parse_array_metadata(metadata)
1690+
1691+
self.metadata.attributes.clear()
1692+
self.metadata.attributes.update(reparsed_metadata.attributes)
1693+
1694+
return self
1695+
16801696
def __repr__(self) -> str:
16811697
return f"<AsyncArray {self.store_path} shape={self.shape} dtype={self.dtype}>"
16821698

@@ -1768,6 +1784,7 @@ class Array:
17681784
"""
17691785

17701786
_async_array: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
1787+
cache_attrs: bool | None = field(default=None)
17711788

17721789
@classmethod
17731790
@deprecated("Use zarr.create_array instead.")
@@ -2105,6 +2122,8 @@ def attrs(self) -> Attributes:
21052122
-----
21062123
Note that attribute values must be JSON serializable.
21072124
"""
2125+
if self.cache_attrs is False:
2126+
self.refresh_attributes()
21082127
return Attributes(self)
21092128

21102129
@property
@@ -3703,6 +3722,19 @@ def update_attributes(self, new_attributes: dict[str, JSON]) -> Array:
37033722
_new_array = cast("AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]", new_array)
37043723
return type(self)(_new_array)
37053724

3725+
def refresh_attributes(self) -> Array:
3726+
"""Reload the attributes of this array from the store.
3727+
3728+
Returns
3729+
-------
3730+
Array
3731+
The array with the updated attributes."""
3732+
# TODO: remove this cast when type inference improves
3733+
new_array = sync(self._async_array.refresh_attributes())
3734+
# TODO: remove this cast when type inference improves
3735+
_new_array = cast("AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]", new_array)
3736+
return type(self)(_new_array)
3737+
37063738
def __repr__(self) -> str:
37073739
return f"<Array {self.store_path} shape={self.shape} dtype={self.dtype}>"
37083740

src/zarr/core/group.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,26 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
12731273

12741274
return self
12751275

1276+
async def refresh_attributes(self) -> AsyncGroup:
1277+
"""Reload the attributes of this group from the store.
1278+
1279+
Returns
1280+
-------
1281+
self : AsyncGroup
1282+
The group updated with the newest attributes from storage.
1283+
"""
1284+
1285+
reparsed_metadata = await _read_group_metadata(
1286+
store=self.store_path.store,
1287+
path=self.store_path.path,
1288+
zarr_format=self.metadata.zarr_format,
1289+
)
1290+
1291+
self.metadata.attributes.clear()
1292+
self.metadata.attributes.update(reparsed_metadata.attributes)
1293+
1294+
return self
1295+
12761296
def __repr__(self) -> str:
12771297
return f"<AsyncGroup {self.store_path}>"
12781298

@@ -1774,6 +1794,7 @@ class Group(SyncMixin):
17741794
"""
17751795

17761796
_async_group: AsyncGroup
1797+
cache_attrs: bool | None = field(default=None)
17771798

17781799
@classmethod
17791800
def from_store(
@@ -1783,6 +1804,7 @@ def from_store(
17831804
attributes: dict[str, Any] | None = None,
17841805
zarr_format: ZarrFormat = 3,
17851806
overwrite: bool = False,
1807+
cache_attrs: bool | None = None,
17861808
) -> Group:
17871809
"""Instantiate a group from an initialized store.
17881810
@@ -1796,6 +1818,10 @@ def from_store(
17961818
Zarr storage format version.
17971819
overwrite : bool, optional
17981820
If True, do not raise an error if the group already exists.
1821+
cache_attrs : bool, optional
1822+
If True (default), user attributes will be cached for attribute read
1823+
operations. If False, user attributes are reloaded from the store prior
1824+
to all attribute read operations.
17991825
18001826
Returns
18011827
-------
@@ -1816,13 +1842,15 @@ def from_store(
18161842
),
18171843
)
18181844

1819-
return cls(obj)
1845+
return cls(obj, cache_attrs=cache_attrs)
18201846

18211847
@classmethod
18221848
def open(
18231849
cls,
18241850
store: StoreLike,
18251851
zarr_format: ZarrFormat | None = 3,
1852+
*,
1853+
cache_attrs: bool | None = None,
18261854
) -> Group:
18271855
"""Open a group from an initialized store.
18281856
@@ -1832,14 +1860,18 @@ def open(
18321860
Store containing the Group.
18331861
zarr_format : {2, 3, None}, optional
18341862
Zarr storage format version.
1863+
cache_attrs : bool, optional
1864+
If True (default), user attributes will be cached for attribute read
1865+
operations. If False, user attributes are reloaded from the store prior
1866+
to all attribute read operations.
18351867
18361868
Returns
18371869
-------
18381870
Group
18391871
Group instantiated from the store.
18401872
"""
18411873
obj = sync(AsyncGroup.open(store, zarr_format=zarr_format))
1842-
return cls(obj)
1874+
return cls(obj, cache_attrs=cache_attrs)
18431875

18441876
def __getitem__(self, path: str) -> Array | Group:
18451877
"""Obtain a group member.
@@ -2024,6 +2056,8 @@ def basename(self) -> str:
20242056
@property
20252057
def attrs(self) -> Attributes:
20262058
"""Attributes of this Group"""
2059+
if self.cache_attrs is False:
2060+
self.refresh_attributes()
20272061
return Attributes(self)
20282062

20292063
@property
@@ -2090,6 +2124,11 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group:
20902124
self._sync(self._async_group.update_attributes(new_attributes))
20912125
return self
20922126

2127+
def refresh_attributes(self) -> Group:
2128+
"""Reload the attributes of this Group from the store."""
2129+
self._sync(self._async_group.refresh_attributes())
2130+
return self
2131+
20932132
def nmembers(self, max_depth: int | None = 0) -> int:
20942133
"""Count the number of members in this group.
20952134

tests/test_array.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,68 @@ def test_update_attrs(zarr_format: ZarrFormat) -> None:
494494
assert arr2.attrs["foo"] == "bar"
495495

496496

497+
@pytest.mark.parametrize("zarr_format", [2, 3])
498+
def test_refresh_attrs(zarr_format: ZarrFormat) -> None:
499+
"""
500+
Test the behavior of `Array.refresh_attributes`
501+
"""
502+
store = MemoryStore()
503+
attrs: dict[str, JSON] = {"foo": 100}
504+
arr = zarr.create_array(
505+
store=store, shape=(5,), chunks=(5,), dtype="f8", attributes=attrs, zarr_format=zarr_format
506+
)
507+
assert arr.attrs.asdict() == attrs
508+
509+
new_attrs: dict[str, JSON] = {"bar": 50}
510+
arr2 = zarr.create_array(
511+
store=store,
512+
shape=(5,),
513+
chunks=(5,),
514+
dtype="f8",
515+
attributes=new_attrs,
516+
zarr_format=zarr_format,
517+
overwrite=True,
518+
)
519+
assert arr2.attrs.asdict() == new_attrs
520+
521+
assert arr.attrs.asdict() == attrs
522+
arr.refresh_attributes()
523+
assert arr.attrs.asdict() == new_attrs
524+
525+
526+
@pytest.mark.parametrize("zarr_format", [2, 3])
527+
def test_cache_attrs(zarr_format: ZarrFormat) -> None:
528+
"""
529+
Test the behavior of `Array.cache_attrs`
530+
"""
531+
store = MemoryStore()
532+
attrs: dict[str, JSON] = {"foo": 100}
533+
arr = zarr.create_array(
534+
store=store,
535+
shape=(5,),
536+
chunks=(5,),
537+
dtype="f8",
538+
attributes=attrs,
539+
zarr_format=zarr_format,
540+
cache_attrs=False,
541+
)
542+
assert arr.attrs.asdict() == attrs
543+
544+
new_attrs: dict[str, JSON] = {"bar": 50}
545+
arr2 = zarr.create_array(
546+
store=store,
547+
shape=(5,),
548+
chunks=(5,),
549+
dtype="f8",
550+
attributes=new_attrs,
551+
zarr_format=zarr_format,
552+
overwrite=True,
553+
)
554+
555+
assert arr2.attrs.asdict() == new_attrs
556+
assert arr.attrs.asdict() == new_attrs
557+
558+
497559
@pytest.mark.parametrize(("chunks", "shards"), [((2, 2), None), ((2, 2), (4, 4))])
498560
class TestInfo:
499561
def test_info_v2(self, chunks: tuple[int, int], shards: tuple[int, int] | None) -> None:

tests/test_group.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,39 @@ async def test_group_update_attributes_async(store: Store, zarr_format: ZarrForm
627627
assert new_group.attrs == new_attrs
628628

629629

630+
def test_group_refresh_attributes(store: Store, zarr_format: ZarrFormat) -> None:
631+
"""
632+
Test the behavior of `Group.refresh_attributes`
633+
"""
634+
attrs = {"foo": 100}
635+
group = Group.from_store(store, zarr_format=zarr_format, attributes=attrs)
636+
assert group.attrs == attrs
637+
new_attrs = {"foo": 50}
638+
group2 = Group.open(store, zarr_format=zarr_format)
639+
group2.update_attributes(new_attrs)
640+
assert group2.attrs == new_attrs
641+
642+
assert group.attrs == attrs
643+
new_group = group.refresh_attributes()
644+
assert new_group.attrs == new_attrs
645+
646+
647+
def test_group_cache_attrs(store: Store, zarr_format: ZarrFormat) -> None:
648+
"""
649+
Test the behavior of `Group.cache_attrs`
650+
"""
651+
attrs = {"foo": 100}
652+
group = Group.from_store(store, zarr_format=zarr_format, attributes=attrs, cache_attrs=False)
653+
assert group.attrs == attrs
654+
655+
new_attrs = {"foo": 50}
656+
group2 = Group.open(store, zarr_format=zarr_format)
657+
group2.update_attributes(new_attrs)
658+
assert group2.attrs == new_attrs
659+
660+
assert group.attrs == new_attrs
661+
662+
630663
@pytest.mark.parametrize("method", ["create_array", "array"])
631664
@pytest.mark.parametrize("name", ["a", "/a"])
632665
def test_group_create_array(

0 commit comments

Comments
 (0)