Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ Unreleased
methods with V3 stores.
By :user:`Ryan Abernathey <rabernat>` :issue:`1228`.

* Add support for setting user-defined attributes at array / group creation time.
By :user:`Davis Bennett <d-v-b>` :issue:`538`.

.. _release_2.13.2:

Maintenance
Expand Down
14 changes: 13 additions & 1 deletion zarr/_storage/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _prefix_to_group_key(store: StoreLike, prefix: str) -> str:
return key


def _prefix_to_attrs_key(store: StoreLike, prefix: str) -> str:
def _prefix_to_array_attrs_key(store: StoreLike, prefix: str) -> str:
if getattr(store, "_store_version", 2) == 3:
# for v3, attributes are stored in the array metadata
sfx = _get_metadata_suffix(store) # type: ignore
Expand All @@ -453,3 +453,15 @@ def _prefix_to_attrs_key(store: StoreLike, prefix: str) -> str:
else:
key = prefix + attrs_key
return key


def _prefix_to_group_attrs_key(store: StoreLike, prefix: str) -> str:
    """Build the store key used to access a group's attributes.

    Parameters
    ----------
    store : StoreLike
        Store whose ``_store_version`` attribute selects the key layout;
        a store without that attribute is treated as version 2.
    prefix : str
        Key prefix of the group. An empty prefix selects the
        prefix-less form of the key.

    Returns
    -------
    str
        For version 3 stores, a key built from ``meta_root``, the prefix
        with trailing slashes removed, ``.group`` and the store's metadata
        suffix. For any other version, ``prefix + attrs_key``.
    """
    # Anything that is not a v3 store uses the plain attributes key.
    if getattr(store, "_store_version", 2) != 3:
        return prefix + attrs_key

    suffix = _get_metadata_suffix(store)  # type: ignore
    if not prefix:
        # No prefix: drop the last character of meta_root (presumably its
        # trailing slash -- confirm against the meta_root definition).
        return meta_root[:-1] + '.group' + suffix
    return meta_root + prefix.rstrip('/') + ".group" + suffix
55 changes: 31 additions & 24 deletions zarr/attrs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, Dict
import warnings
from collections.abc import MutableMapping
from typing import MutableMapping

from zarr._storage.store import Store, StoreV3
from zarr.util import json_dumps


class Attributes(MutableMapping):
class Attributes(MutableMapping[str, Any]):
"""Class providing access to user attributes on an array or group. Should not be
instantiated directly, will be available via the `.attrs` property of an array or
group.
Expand All @@ -25,10 +26,16 @@ class Attributes(MutableMapping):

"""

def __init__(self, store, key='.zattrs', read_only=False, cache=True,
synchronizer=None):
def __init__(
self,
store,
key: str = ".zattrs",
read_only: bool = False,
cache: bool = True,
synchronizer=None,
):

self._version = getattr(store, '_store_version', 2)
self._version: int = getattr(store, "_store_version", 2)
_Store = Store if self._version == 2 else StoreV3
self.store = _Store._ensure_store(store)
self.key = key
Expand All @@ -39,11 +46,11 @@ def __init__(self, store, key='.zattrs', read_only=False, cache=True,

def _get_nosync(self):
try:
data = self.store[self.key]
data: bytes = self.store[self.key]
except KeyError:
d = dict()
if self._version > 2:
d['attributes'] = {}
d["attributes"] = {}
else:
d = self.store._metadata_class.parse_metadata(data)
return d
Expand All @@ -54,7 +61,7 @@ def asdict(self):
return self._cached_asdict
d = self._get_nosync()
if self._version == 3:
d = d['attributes']
d = d["attributes"]
if self.cache:
self._cached_asdict = d
return d
Expand All @@ -65,7 +72,7 @@ def refresh(self):
if self._version == 2:
self._cached_asdict = self._get_nosync()
else:
self._cached_asdict = self._get_nosync()['attributes']
self._cached_asdict = self._get_nosync()["attributes"]

def __contains__(self, x):
return x in self.asdict()
Expand All @@ -77,7 +84,7 @@ def _write_op(self, f, *args, **kwargs):

# guard condition
if self.read_only:
raise PermissionError('attributes are read-only')
raise PermissionError("attributes are read-only")

# synchronization
if self.synchronizer is None:
Expand All @@ -89,7 +96,7 @@ def _write_op(self, f, *args, **kwargs):
def __setitem__(self, item, value):
self._write_op(self._setitem_nosync, item, value)

def _setitem_nosync(self, item, value):
def _setitem_nosync(self, item: str, value):

# load existing data
d = self._get_nosync()
Expand All @@ -98,15 +105,15 @@ def _setitem_nosync(self, item, value):
if self._version == 2:
d[item] = value
else:
d['attributes'][item] = value
d["attributes"][item] = value

# _put modified data
self._put_nosync(d)

def __delitem__(self, item):
def __delitem__(self, item: str):
self._write_op(self._delitem_nosync, item)

def _delitem_nosync(self, key):
def _delitem_nosync(self, key: str):

# load existing data
d = self._get_nosync()
Expand All @@ -115,20 +122,20 @@ def _delitem_nosync(self, key):
if self._version == 2:
del d[key]
else:
del d['attributes'][key]
del d["attributes"][key]

# _put modified data
self._put_nosync(d)

def put(self, d):
def put(self, d: Dict[str, Any]):
"""Overwrite all attributes with the key/value pairs in the provided dictionary
`d` in a single operation."""
if self._version == 2:
self._write_op(self._put_nosync, d)
else:
self._write_op(self._put_nosync, dict(attributes=d))

def _put_nosync(self, d):
def _put_nosync(self, d: Dict[str, Any]):

d_to_check = d if self._version == 2 else d["attributes"]
if not all(isinstance(item, str) for item in d_to_check):
Expand All @@ -137,8 +144,8 @@ def _put_nosync(self, d):
warnings.warn(
"only attribute keys of type 'string' will be allowed in the future",
DeprecationWarning,
stacklevel=2
)
stacklevel=2,
)

try:
d_to_check = {str(k): v for k, v in d_to_check.items()}
Expand All @@ -163,15 +170,15 @@ def _put_nosync(self, d):
# Note: this changes the store.counter result in test_caching_on!

meta = self.store._metadata_class.parse_metadata(self.store[self.key])
if 'attributes' in meta and 'filters' in meta['attributes']:
if "attributes" in meta and "filters" in meta["attributes"]:
# need to preserve any existing "filters" attribute
d['attributes']['filters'] = meta['attributes']['filters']
meta['attributes'] = d['attributes']
d["attributes"]["filters"] = meta["attributes"]["filters"]
meta["attributes"] = d["attributes"]
else:
meta = d
self.store[self.key] = json_dumps(meta)
if self.cache:
self._cached_asdict = d['attributes']
self._cached_asdict = d["attributes"]

# noinspection PyMethodOverriding
def update(self, *args, **kwargs):
Expand All @@ -187,7 +194,7 @@ def _update_nosync(self, *args, **kwargs):
if self._version == 2:
d.update(*args, **kwargs)
else:
d['attributes'].update(*args, **kwargs)
d["attributes"].update(*args, **kwargs)

# _put modified data
self._put_nosync(d)
Expand Down
4 changes: 2 additions & 2 deletions zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from numcodecs.compat import ensure_bytes

from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available
from zarr._storage.store import _prefix_to_array_attrs_key, assert_zarr_v3_api_available
from zarr.attrs import Attributes
from zarr.codecs import AsType, get_codec
from zarr.errors import ArrayNotFoundError, ReadOnlyError, ArrayIndexError
Expand Down Expand Up @@ -215,7 +215,7 @@ def __init__(
self._load_metadata()

# initialize attributes
akey = _prefix_to_attrs_key(self._store, self._key_prefix)
akey = _prefix_to_array_attrs_key(self._store, self._key_prefix)
self._attrs = Attributes(store, key=akey, read_only=read_only,
synchronizer=synchronizer, cache=cache_attrs)

Expand Down
Loading