diff --git a/docs/release.rst b/docs/release.rst
index e15132b60b..b42ac764d2 100644
--- a/docs/release.rst
+++ b/docs/release.rst
@@ -18,6 +18,9 @@ Unreleased
   methods with V3 stores.
   By :user:`Ryan Abernathey <rabernat>` :issue:`1228`.
 
+* Add support for setting user-defined attributes at array / group creation time.
+  By :user:`Davis Bennett <d-v-b>` :issue:`538`.
+
 .. _release_2.13.2:
 
 Maintenance
diff --git a/zarr/_storage/store.py b/zarr/_storage/store.py
index 9e265cf383..3ca049ca1c 100644
--- a/zarr/_storage/store.py
+++ b/zarr/_storage/store.py
@@ -442,7 +442,7 @@ def _prefix_to_group_key(store: StoreLike, prefix: str) -> str:
     return key
 
 
-def _prefix_to_attrs_key(store: StoreLike, prefix: str) -> str:
+def _prefix_to_array_attrs_key(store: StoreLike, prefix: str) -> str:
     if getattr(store, "_store_version", 2) == 3:
         # for v3, attributes are stored in the array metadata
         sfx = _get_metadata_suffix(store)  # type: ignore
@@ -453,3 +453,15 @@ def _prefix_to_attrs_key(store: StoreLike, prefix: str) -> str:
     else:
         key = prefix + attrs_key
     return key
+
+
+def _prefix_to_group_attrs_key(store: StoreLike, prefix: str) -> str:
+    if getattr(store, "_store_version", 2) == 3:
+        sfx = _get_metadata_suffix(store)  # type: ignore
+        if prefix:
+            key = meta_root + prefix.rstrip("/") + ".group" + sfx
+        else:
+            key = meta_root[:-1] + ".group" + sfx
+    else:
+        key = prefix + attrs_key
+    return key
diff --git a/zarr/attrs.py b/zarr/attrs.py
index 60dd7f1d79..3c80b541f2 100644
--- a/zarr/attrs.py
+++ b/zarr/attrs.py
@@ -1,11 +1,11 @@
 import warnings
-from collections.abc import MutableMapping
+from typing import Any, Dict, MutableMapping
 
 from zarr._storage.store import Store, StoreV3
 from zarr.util import json_dumps
 
 
-class Attributes(MutableMapping):
+class Attributes(MutableMapping[str, Any]):
     """Class providing access to user attributes on an array or group. Should not be
     instantiated directly, will be available via the `.attrs` property of an array or
     group.
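For context before the mechanical changes below, a minimal sketch of the usage this
patch enables (illustrative, not part of the diff; assumes the patch is applied and
uses the default in-memory store):

    import zarr

    # previously, attributes could only be attached after creation,
    # costing a second metadata write:
    z = zarr.create(shape=(100, 100), chunks=(10, 10), dtype="f8")
    z.attrs["units"] = "meters"

    # with this patch, attributes are persisted as part of creation:
    z = zarr.create(
        shape=(100, 100), chunks=(10, 10), dtype="f8", attrs={"units": "meters"}
    )
    grp = zarr.group(attrs={"project": "demo"})
    sub = grp.create_group("raw", attrs={"stage": "ingest"})

    assert z.attrs["units"] == "meters"
    assert sub.attrs["stage"] == "ingest"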
@@ -25,10 +26,16 @@ class Attributes(MutableMapping): """ - def __init__(self, store, key='.zattrs', read_only=False, cache=True, - synchronizer=None): + def __init__( + self, + store, + key: str = ".zattrs", + read_only: bool = False, + cache: bool = True, + synchronizer=None, + ): - self._version = getattr(store, '_store_version', 2) + self._version: int = getattr(store, "_store_version", 2) _Store = Store if self._version == 2 else StoreV3 self.store = _Store._ensure_store(store) self.key = key @@ -39,11 +46,11 @@ def __init__(self, store, key='.zattrs', read_only=False, cache=True, def _get_nosync(self): try: - data = self.store[self.key] + data: bytes = self.store[self.key] except KeyError: d = dict() if self._version > 2: - d['attributes'] = {} + d["attributes"] = {} else: d = self.store._metadata_class.parse_metadata(data) return d @@ -54,7 +61,7 @@ def asdict(self): return self._cached_asdict d = self._get_nosync() if self._version == 3: - d = d['attributes'] + d = d["attributes"] if self.cache: self._cached_asdict = d return d @@ -65,7 +72,7 @@ def refresh(self): if self._version == 2: self._cached_asdict = self._get_nosync() else: - self._cached_asdict = self._get_nosync()['attributes'] + self._cached_asdict = self._get_nosync()["attributes"] def __contains__(self, x): return x in self.asdict() @@ -77,7 +84,7 @@ def _write_op(self, f, *args, **kwargs): # guard condition if self.read_only: - raise PermissionError('attributes are read-only') + raise PermissionError("attributes are read-only") # synchronization if self.synchronizer is None: @@ -89,7 +96,7 @@ def _write_op(self, f, *args, **kwargs): def __setitem__(self, item, value): self._write_op(self._setitem_nosync, item, value) - def _setitem_nosync(self, item, value): + def _setitem_nosync(self, item: str, value): # load existing data d = self._get_nosync() @@ -98,15 +105,15 @@ def _setitem_nosync(self, item, value): if self._version == 2: d[item] = value else: - d['attributes'][item] = value + d["attributes"][item] = value # _put modified data self._put_nosync(d) - def __delitem__(self, item): + def __delitem__(self, item: str): self._write_op(self._delitem_nosync, item) - def _delitem_nosync(self, key): + def _delitem_nosync(self, key: str): # load existing data d = self._get_nosync() @@ -115,12 +122,12 @@ def _delitem_nosync(self, key): if self._version == 2: del d[key] else: - del d['attributes'][key] + del d["attributes"][key] # _put modified data self._put_nosync(d) - def put(self, d): + def put(self, d: Dict[str, Any]): """Overwrite all attributes with the key/value pairs in the provided dictionary `d` in a single operation.""" if self._version == 2: @@ -128,7 +135,7 @@ def put(self, d): else: self._write_op(self._put_nosync, dict(attributes=d)) - def _put_nosync(self, d): + def _put_nosync(self, d: Dict[str, Any]): d_to_check = d if self._version == 2 else d["attributes"] if not all(isinstance(item, str) for item in d_to_check): @@ -137,8 +144,8 @@ def _put_nosync(self, d): warnings.warn( "only attribute keys of type 'string' will be allowed in the future", DeprecationWarning, - stacklevel=2 - ) + stacklevel=2, + ) try: d_to_check = {str(k): v for k, v in d_to_check.items()} @@ -163,15 +170,15 @@ def _put_nosync(self, d): # Note: this changes the store.counter result in test_caching_on! 
meta = self.store._metadata_class.parse_metadata(self.store[self.key]) - if 'attributes' in meta and 'filters' in meta['attributes']: + if "attributes" in meta and "filters" in meta["attributes"]: # need to preserve any existing "filters" attribute - d['attributes']['filters'] = meta['attributes']['filters'] - meta['attributes'] = d['attributes'] + d["attributes"]["filters"] = meta["attributes"]["filters"] + meta["attributes"] = d["attributes"] else: meta = d self.store[self.key] = json_dumps(meta) if self.cache: - self._cached_asdict = d['attributes'] + self._cached_asdict = d["attributes"] # noinspection PyMethodOverriding def update(self, *args, **kwargs): @@ -187,7 +194,7 @@ def _update_nosync(self, *args, **kwargs): if self._version == 2: d.update(*args, **kwargs) else: - d['attributes'].update(*args, **kwargs) + d["attributes"].update(*args, **kwargs) # _put modified data self._put_nosync(d) diff --git a/zarr/core.py b/zarr/core.py index e5b2045160..29941950f6 100644 --- a/zarr/core.py +++ b/zarr/core.py @@ -10,7 +10,7 @@ import numpy as np from numcodecs.compat import ensure_bytes -from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available +from zarr._storage.store import _prefix_to_array_attrs_key, assert_zarr_v3_api_available from zarr.attrs import Attributes from zarr.codecs import AsType, get_codec from zarr.errors import ArrayNotFoundError, ReadOnlyError, ArrayIndexError @@ -215,7 +215,7 @@ def __init__( self._load_metadata() # initialize attributes - akey = _prefix_to_attrs_key(self._store, self._key_prefix) + akey = _prefix_to_array_attrs_key(self._store, self._key_prefix) self._attrs = Attributes(store, key=akey, read_only=read_only, synchronizer=synchronizer, cache=cache_attrs) diff --git a/zarr/creation.py b/zarr/creation.py index 00d2c40030..dcaaf06762 100644 --- a/zarr/creation.py +++ b/zarr/creation.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Literal, Optional, Sequence, Union from warnings import warn import numpy as np @@ -11,18 +11,42 @@ ContainsArrayError, ContainsGroupError, ) -from zarr.storage import (contains_array, contains_group, default_compressor, - init_array, normalize_storage_path, - normalize_store_arg) +from zarr.storage import ( + contains_array, + contains_group, + default_compressor, + init_array, + normalize_storage_path, + normalize_store_arg, +) from zarr.util import normalize_dimension_separator -def create(shape, chunks=True, dtype=None, compressor='default', - fill_value: Optional[int] = 0, order='C', store=None, synchronizer=None, - overwrite=False, path=None, chunk_store=None, filters=None, - cache_metadata=True, cache_attrs=True, read_only=False, - object_codec=None, dimension_separator=None, write_empty_chunks=True, - *, zarr_version=None, meta_array=None, **kwargs): +def create( + shape: Union[int, Sequence[int]], + chunks=True, + dtype=None, + compressor="default", + fill_value: Optional[int] = 0, + order: Literal["C", "F"] = "C", + store=None, + synchronizer=None, + overwrite: bool = False, + path: Optional[str] = None, + chunk_store=None, + filters=None, + cache_metadata: bool = True, + cache_attrs: bool = True, + read_only: bool = False, + object_codec=None, + dimension_separator: Optional[Literal["/", "."]] = None, + write_empty_chunks: bool = True, + attrs: Dict[str, Any] = {}, + *, + zarr_version=None, + meta_array=None, + **kwargs, +): """Create an array. 
Parameters @@ -71,6 +95,8 @@ def create(shape, chunks=True, dtype=None, compressor='default', A codec to encode object arrays, only needed if dtype=object. dimension_separator : {'.', '/'}, optional Separator placed between the dimensions of a chunk. + attrs : JSON-serializable dict. + User attributes for the array. Defaults to {}. .. versionadded:: 2.8 @@ -142,11 +168,11 @@ def create(shape, chunks=True, dtype=None, compressor='default', """ if zarr_version is None and store is None: - zarr_version = getattr(chunk_store, '_store_version', DEFAULT_ZARR_VERSION) + zarr_version = getattr(chunk_store, "_store_version", DEFAULT_ZARR_VERSION) # handle polymorphic store arg store = normalize_store_arg(store, zarr_version=zarr_version) - zarr_version = getattr(store, '_store_version', DEFAULT_ZARR_VERSION) + zarr_version = getattr(store, "_store_version", DEFAULT_ZARR_VERSION) # API compatibility with h5py compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs) @@ -160,22 +186,43 @@ def create(shape, chunks=True, dtype=None, compressor='default', raise ValueError( f"Specified dimension_separator: {dimension_separator}" f"conflicts with store's separator: " - f"{store_separator}") + f"{store_separator}" + ) dimension_separator = normalize_dimension_separator(dimension_separator) if zarr_version > 2 and path is None: - path = '/' + path = "/" # initialize array metadata - init_array(store, shape=shape, chunks=chunks, dtype=dtype, compressor=compressor, - fill_value=fill_value, order=order, overwrite=overwrite, path=path, - chunk_store=chunk_store, filters=filters, object_codec=object_codec, - dimension_separator=dimension_separator) + init_array( + store, + shape=shape, + attrs=attrs, + chunks=chunks, + dtype=dtype, + compressor=compressor, + fill_value=fill_value, + order=order, + overwrite=overwrite, + path=path, + chunk_store=chunk_store, + filters=filters, + object_codec=object_codec, + dimension_separator=dimension_separator, + ) # instantiate array - z = Array(store, path=path, chunk_store=chunk_store, synchronizer=synchronizer, - cache_metadata=cache_metadata, cache_attrs=cache_attrs, read_only=read_only, - write_empty_chunks=write_empty_chunks, meta_array=meta_array) + z = Array( + store, + path=path, + chunk_store=chunk_store, + synchronizer=synchronizer, + cache_metadata=cache_metadata, + cache_attrs=cache_attrs, + read_only=read_only, + write_empty_chunks=write_empty_chunks, + meta_array=meta_array, + ) return z @@ -185,7 +232,7 @@ def _kwargs_compat(compressor, fill_value, kwargs): # to be compatible with h5py, as well as backwards-compatible with Zarr # 1.x, accept 'compression' and 'compression_opts' keyword arguments - if compressor != 'default': + if compressor != "default": # 'compressor' overrides 'compression' if "compression" in kwargs: warn( @@ -200,14 +247,14 @@ def _kwargs_compat(compressor, fill_value, kwargs): ) del kwargs["compression_opts"] - elif 'compression' in kwargs: - compression = kwargs.pop('compression') - compression_opts = kwargs.pop('compression_opts', None) + elif "compression" in kwargs: + compression = kwargs.pop("compression") + compression_opts = kwargs.pop("compression_opts", None) - if compression is None or compression == 'none': + if compression is None or compression == "none": compressor = None - elif compression == 'default': + elif compression == "default": compressor = default_compressor elif isinstance(compression, str): @@ -225,21 +272,21 @@ def _kwargs_compat(compressor, fill_value, kwargs): compressor = 
codec_cls(compression_opts) # be lenient here if user gives compressor as 'compression' - elif hasattr(compression, 'get_config'): + elif hasattr(compression, "get_config"): compressor = compression else: - raise ValueError('bad value for compression: %r' % compression) + raise ValueError("bad value for compression: %r" % compression) # handle 'fillvalue' - if 'fillvalue' in kwargs: + if "fillvalue" in kwargs: # to be compatible with h5py, accept 'fillvalue' instead of # 'fill_value' - fill_value = kwargs.pop('fillvalue') + fill_value = kwargs.pop("fillvalue") # ignore other keyword arguments for k in kwargs: - warn('ignoring keyword argument %r' % k) + warn("ignoring keyword argument %r" % k) return compressor, fill_value @@ -326,16 +373,17 @@ def _get_shape_chunks(a): shape = None chunks = None - if hasattr(a, 'shape') and \ - isinstance(a.shape, tuple): + if hasattr(a, "shape") and isinstance(a.shape, tuple): shape = a.shape - if hasattr(a, 'chunks') and \ - isinstance(a.chunks, tuple) and \ - (len(a.chunks) == len(a.shape)): + if ( + hasattr(a, "chunks") + and isinstance(a.chunks, tuple) + and (len(a.chunks) == len(a.shape)) + ): chunks = a.chunks - elif hasattr(a, 'chunklen'): + elif hasattr(a, "chunklen"): # bcolz carray chunks = (a.chunklen,) + a.shape[1:] @@ -360,27 +408,27 @@ def array(data, **kwargs): """ # ensure data is array-like - if not hasattr(data, 'shape') or not hasattr(data, 'dtype'): + if not hasattr(data, "shape") or not hasattr(data, "dtype"): data = np.asanyarray(data) # setup dtype - kw_dtype = kwargs.get('dtype') + kw_dtype = kwargs.get("dtype") if kw_dtype is None: - kwargs['dtype'] = data.dtype + kwargs["dtype"] = data.dtype else: - kwargs['dtype'] = kw_dtype + kwargs["dtype"] = kw_dtype # setup shape and chunks data_shape, data_chunks = _get_shape_chunks(data) - kwargs['shape'] = data_shape - kw_chunks = kwargs.get('chunks') + kwargs["shape"] = data_shape + kw_chunks = kwargs.get("chunks") if kw_chunks is None: - kwargs['chunks'] = data_chunks + kwargs["chunks"] = data_chunks else: - kwargs['chunks'] = kw_chunks + kwargs["chunks"] = kw_chunks # pop read-only to apply after storing the data - read_only = kwargs.pop('read_only', False) + read_only = kwargs.pop("read_only", False) # instantiate array z = create(**kwargs) @@ -413,10 +461,11 @@ def open_array( storage_options=None, partial_decompress=False, write_empty_chunks=True, + attrs: Dict[str, Any] = {}, *, zarr_version=None, dimension_separator=None, - **kwargs + **kwargs, ): """Open an array using file-mode-like semantics. @@ -478,6 +527,8 @@ def open_array( is deleted. This setting enables sparser storage, as only chunks with non-fill-value data are stored, at the expense of overhead associated with checking the data of each chunk. + attrs : JSON-serializable dict. + User attributes for the array. Defaults to {}. .. 
versionadded:: 2.11 @@ -525,27 +576,30 @@ def open_array( # a : read/write if exists, create otherwise (default) if zarr_version is None and store is None: - zarr_version = getattr(chunk_store, '_store_version', DEFAULT_ZARR_VERSION) + zarr_version = getattr(chunk_store, "_store_version", DEFAULT_ZARR_VERSION) # handle polymorphic store arg - store = normalize_store_arg(store, storage_options=storage_options, - mode=mode, zarr_version=zarr_version) - zarr_version = getattr(store, '_store_version', DEFAULT_ZARR_VERSION) + store = normalize_store_arg( + store, storage_options=storage_options, mode=mode, zarr_version=zarr_version + ) + zarr_version = getattr(store, "_store_version", DEFAULT_ZARR_VERSION) if chunk_store is not None: - chunk_store = normalize_store_arg(chunk_store, - storage_options=storage_options, - mode=mode, - zarr_version=zarr_version) + chunk_store = normalize_store_arg( + chunk_store, + storage_options=storage_options, + mode=mode, + zarr_version=zarr_version, + ) # respect the dimension separator specified in a store, if present if dimension_separator is None: - if hasattr(store, '_dimension_separator'): + if hasattr(store, "_dimension_separator"): dimension_separator = store._dimension_separator else: - dimension_separator = '.' if zarr_version == 2 else '/' + dimension_separator = "." if zarr_version == 2 else "/" if zarr_version == 3 and path is None: - path = 'array' # TODO: raise ValueError instead? + path = "array" # TODO: raise ValueError instead? path = normalize_storage_path(path) @@ -557,49 +611,87 @@ def open_array( fill_value = np.array(fill_value, dtype=dtype)[()] # ensure store is initialized - - if mode in ['r', 'r+']: + # TODO: warning when creation kwargs (dtype, shape) are provided but mode is not w + if mode in ["r", "r+"]: if not contains_array(store, path=path): if contains_group(store, path=path): raise ContainsGroupError(path) raise ArrayNotFoundError(path) - elif mode == 'w': - init_array(store, shape=shape, chunks=chunks, dtype=dtype, - compressor=compressor, fill_value=fill_value, - order=order, filters=filters, overwrite=True, path=path, - object_codec=object_codec, chunk_store=chunk_store, - dimension_separator=dimension_separator) - - elif mode == 'a': + elif mode == "w": + init_array( + store, + shape=shape, + attrs=attrs, + chunks=chunks, + dtype=dtype, + compressor=compressor, + fill_value=fill_value, + order=order, + filters=filters, + overwrite=True, + path=path, + object_codec=object_codec, + chunk_store=chunk_store, + dimension_separator=dimension_separator, + ) + + elif mode == "a": if not contains_array(store, path=path): if contains_group(store, path=path): raise ContainsGroupError(path) - init_array(store, shape=shape, chunks=chunks, dtype=dtype, - compressor=compressor, fill_value=fill_value, - order=order, filters=filters, path=path, - object_codec=object_codec, chunk_store=chunk_store, - dimension_separator=dimension_separator) + init_array( + store, + shape=shape, + attrs=attrs, + chunks=chunks, + dtype=dtype, + compressor=compressor, + fill_value=fill_value, + order=order, + filters=filters, + path=path, + object_codec=object_codec, + chunk_store=chunk_store, + dimension_separator=dimension_separator + ) - elif mode in ['w-', 'x']: + elif mode in ["w-", "x"]: if contains_group(store, path=path): raise ContainsGroupError(path) elif contains_array(store, path=path): raise ContainsArrayError(path) else: - init_array(store, shape=shape, chunks=chunks, dtype=dtype, - compressor=compressor, fill_value=fill_value, - order=order, 
filters=filters, path=path, - object_codec=object_codec, chunk_store=chunk_store, - dimension_separator=dimension_separator) + init_array( + store, + shape=shape, + attrs=attrs, + chunks=chunks, + dtype=dtype, + compressor=compressor, + fill_value=fill_value, + order=order, + filters=filters, + path=path, + object_codec=object_codec, + chunk_store=chunk_store, + dimension_separator=dimension_separator + ) # determine read only status - read_only = mode == 'r' + read_only = mode == "r" # instantiate array - z = Array(store, read_only=read_only, synchronizer=synchronizer, - cache_metadata=cache_metadata, cache_attrs=cache_attrs, path=path, - chunk_store=chunk_store, write_empty_chunks=write_empty_chunks) + z = Array( + store, + read_only=read_only, + synchronizer=synchronizer, + cache_metadata=cache_metadata, + cache_attrs=cache_attrs, + path=path, + chunk_store=chunk_store, + write_empty_chunks=write_empty_chunks, + ) return z @@ -608,21 +700,21 @@ def _like_args(a, kwargs): shape, chunks = _get_shape_chunks(a) if shape is not None: - kwargs.setdefault('shape', shape) + kwargs.setdefault("shape", shape) if chunks is not None: - kwargs.setdefault('chunks', chunks) + kwargs.setdefault("chunks", chunks) - if hasattr(a, 'dtype'): - kwargs.setdefault('dtype', a.dtype) + if hasattr(a, "dtype"): + kwargs.setdefault("dtype", a.dtype) if isinstance(a, Array): - kwargs.setdefault('compressor', a.compressor) - kwargs.setdefault('order', a.order) - kwargs.setdefault('filters', a.filters) - kwargs.setdefault('zarr_version', a._version) + kwargs.setdefault("compressor", a.compressor) + kwargs.setdefault("order", a.order) + kwargs.setdefault("filters", a.filters) + kwargs.setdefault("zarr_version", a._version) else: - kwargs.setdefault('compressor', 'default') - kwargs.setdefault('order', 'C') + kwargs.setdefault("compressor", "default") + kwargs.setdefault("order", "C") def empty_like(a, **kwargs): @@ -647,7 +739,7 @@ def full_like(a, **kwargs): """Create a filled array like `a`.""" _like_args(a, kwargs) if isinstance(a, Array): - kwargs.setdefault('fill_value', a.fill_value) + kwargs.setdefault("fill_value", a.fill_value) return full(**kwargs) @@ -655,5 +747,5 @@ def open_like(a, path, **kwargs): """Open a persistent array like `a`.""" _like_args(a, kwargs) if isinstance(a, Array): - kwargs.setdefault('fill_value', a.fill_value) + kwargs.setdefault("fill_value", a.fill_value) return open_array(path, **kwargs) diff --git a/zarr/hierarchy.py b/zarr/hierarchy.py index 0dae921500..e99140a885 100644 --- a/zarr/hierarchy.py +++ b/zarr/hierarchy.py @@ -1,14 +1,30 @@ from collections.abc import MutableMapping from itertools import islice +from typing import Any, Dict, List, Optional, Tuple import numpy as np -from zarr._storage.store import (_get_metadata_suffix, data_root, meta_root, - DEFAULT_ZARR_VERSION, assert_zarr_v3_api_available) +from zarr._storage.store import ( + _get_metadata_suffix, + data_root, + meta_root, + DEFAULT_ZARR_VERSION, + assert_zarr_v3_api_available, +) from zarr.attrs import Attributes from zarr.core import Array -from zarr.creation import (array, create, empty, empty_like, full, full_like, - ones, ones_like, zeros, zeros_like) +from zarr.creation import ( + array, + create, + empty, + empty_like, + full, + full_like, + ones, + ones_like, + zeros, + zeros_like, +) from zarr.errors import ( ContainsArrayError, ContainsGroupError, @@ -120,32 +136,43 @@ class Group(MutableMapping): """ - def __init__(self, store, path=None, read_only=False, chunk_store=None, - cache_attrs=True, 
synchronizer=None, zarr_version=None, *, - meta_array=None): + def __init__( + self, + store, + path: Optional[str] = None, + read_only: bool = False, + chunk_store=None, + cache_attrs: bool = True, + synchronizer=None, + zarr_version: Optional[int] = None, + *, + meta_array=None + ): store: BaseStore = _normalize_store_arg(store, zarr_version=zarr_version) if zarr_version is None: - zarr_version = getattr(store, '_store_version', DEFAULT_ZARR_VERSION) + zarr_version = getattr(store, "_store_version", DEFAULT_ZARR_VERSION) if zarr_version != 2: assert_zarr_v3_api_available() if chunk_store is not None: - chunk_store: BaseStore = _normalize_store_arg(chunk_store, zarr_version=zarr_version) + chunk_store: BaseStore = _normalize_store_arg( + chunk_store, zarr_version=zarr_version + ) self._store = store self._chunk_store = chunk_store self._path = normalize_storage_path(path) if self._path: - self._key_prefix = self._path + '/' + self._key_prefix = self._path + "/" else: - self._key_prefix = '' + self._key_prefix = "" self._read_only = read_only self._synchronizer = synchronizer if meta_array is not None: self._meta_array = np.empty_like(meta_array, shape=()) else: self._meta_array = np.empty(()) - self._version = zarr_version + self._version: int = zarr_version if self._version == 3: self._data_key_prefix = data_root + self._key_prefix self._data_path = data_root + self._path @@ -161,7 +188,7 @@ def __init__(self, store, path=None, read_only=False, chunk_store=None, try: mkey = _prefix_to_group_key(self._store, self._key_prefix) assert not mkey.endswith("root/.group") - meta_bytes = store[mkey] + meta_bytes: bytes = store[mkey] except KeyError: if self._version == 2: raise GroupNotFoundError(path) @@ -182,8 +209,13 @@ def __init__(self, store, path=None, read_only=False, chunk_store=None, # Note: mkey doesn't actually exist for implicit groups, but the # object can still be created. 
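For orientation at this point in the diff: `akey` ends up pointing at wherever user
attributes live for the given store version. A sketch of the keys the new helper pair
from `zarr/_storage/store.py` resolves to (illustrative, not part of the patch;
`v2_store` / `v3_store` are hypothetical store instances, and the v3 keys assume the
default `.json` metadata suffix):

    from zarr._storage.store import (
        _prefix_to_array_attrs_key,
        _prefix_to_group_attrs_key,
    )

    # v2: arrays and groups both keep attributes in a sibling .zattrs document
    _prefix_to_array_attrs_key(v2_store, "foo/")  # -> "foo/.zattrs"
    _prefix_to_group_attrs_key(v2_store, "foo/")  # -> "foo/.zattrs"

    # v3: attributes are nested inside the metadata document itself, e.g.
    # {"attributes": {...}, ...}, so the attrs key *is* the metadata key
    _prefix_to_array_attrs_key(v3_store, "foo/")  # -> "meta/root/foo.array.json"
    _prefix_to_group_attrs_key(v3_store, "foo/")  # -> "meta/root/foo.group.json"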
akey = mkey - self._attrs = Attributes(store, key=akey, read_only=read_only, - cache=cache_attrs, synchronizer=synchronizer) + self._attrs = Attributes( + store, + key=akey, + read_only=read_only, + cache=cache_attrs, + synchronizer=synchronizer, + ) # setup info self._info = InfoReporter(self) @@ -204,15 +236,15 @@ def name(self): if self._path: # follow h5py convention: add leading slash name = self._path - if name[0] != '/': - name = '/' + name + if name[0] != "/": + name = "/" + name return name - return '/' + return "/" @property def basename(self): """Final component of name.""" - return self.name.split('/')[-1] + return self.name.split("/")[-1] @property def read_only(self): @@ -250,12 +282,12 @@ def meta_array(self): """ return self._meta_array - def __eq__(self, other): + def __eq__(self, other: Any): return ( - isinstance(other, Group) and - self._store == other.store and - self._read_only == other.read_only and - self._path == other.path + isinstance(other, Group) + and self._store == other.store + and self._read_only == other.read_only + and self._path == other.path # N.B., no need to compare attributes, should be covered by # store comparison ) @@ -279,11 +311,12 @@ def __iter__(self): quux """ - if getattr(self._store, '_store_version', 2) == 2: + if getattr(self._store, "_store_version", 2) == 2: for key in sorted(listdir(self._store, self._path)): path = self._key_prefix + key - if (contains_array(self._store, path) or - contains_group(self._store, path)): + if contains_array(self._store, path) or contains_group( + self._store, path + ): yield key else: # TODO: Should this iterate over data folders and/or metadata @@ -296,15 +329,15 @@ def __iter__(self): # yield any groups or arrays sfx = self._metadata_key_suffix for key in keys: - len_suffix = len('.group') + len(sfx) # same for .array - if key.endswith(('.group' + sfx, '.array' + sfx)): + len_suffix = len(".group") + len(sfx) # same for .array + if key.endswith((".group" + sfx, ".array" + sfx)): yield key[name_start:-len_suffix] # also yield any implicit groups for prefix in prefixes: - prefix = prefix.rstrip('/') + prefix = prefix.rstrip("/") # only implicit if there is no .group.sfx file - if not prefix + '.group' + sfx in self._store: + if not prefix + ".group" + sfx in self._store: yield prefix[name_start:] # Note: omit data/root/ to avoid duplicate listings @@ -316,12 +349,12 @@ def __len__(self): def __repr__(self): t = type(self) - r = '<{}.{}'.format(t.__module__, t.__name__) + r = "<{}.{}".format(t.__module__, t.__name__) if self.name: - r += ' %r' % self.name + r += " %r" % self.name if self._read_only: - r += ' read-only' - r += '>' + r += " read-only" + r += ">" return r def __enter__(self): @@ -333,39 +366,38 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.store.close() def info_items(self): - def typestr(o): - return '{}.{}'.format(type(o).__module__, type(o).__name__) + return "{}.{}".format(type(o).__module__, type(o).__name__) items = [] # basic info if self.name is not None: - items += [('Name', self.name)] + items += [("Name", self.name)] items += [ - ('Type', typestr(self)), - ('Read-only', str(self.read_only)), + ("Type", typestr(self)), + ("Read-only", str(self.read_only)), ] # synchronizer if self._synchronizer is not None: - items += [('Synchronizer type', typestr(self._synchronizer))] + items += [("Synchronizer type", typestr(self._synchronizer))] # storage info - items += [('Store type', typestr(self._store))] + items += [("Store type", typestr(self._store))] if self._chunk_store is not 
None: - items += [('Chunk store type', typestr(self._chunk_store))] + items += [("Chunk store type", typestr(self._chunk_store))] # members - items += [('No. members', len(self))] + items += [("No. members", len(self))] array_keys = sorted(self.array_keys()) group_keys = sorted(self.group_keys()) - items += [('No. arrays', len(array_keys))] - items += [('No. groups', len(group_keys))] + items += [("No. arrays", len(array_keys))] + items += [("No. groups", len(group_keys))] if array_keys: - items += [('Arrays', ', '.join(array_keys))] + items += [("Arrays", ", ".join(array_keys))] if group_keys: - items += [('Groups', ', '.join(group_keys))] + items += [("Groups", ", ".join(group_keys))] return items @@ -384,14 +416,14 @@ def __getstate__(self): def __setstate__(self, state): self.__init__(**state) - def _item_path(self, item): - absolute = isinstance(item, str) and item and item[0] == '/' + def _item_path(self, item: str): + absolute = isinstance(item, str) and item and item[0] == "/" path = normalize_storage_path(item) if not absolute and self._path: path = self._key_prefix + path return path - def __contains__(self, item): + def __contains__(self, item: str): """Test for group membership. Examples @@ -409,10 +441,11 @@ def __contains__(self, item): """ path = self._item_path(item) - return contains_array(self._store, path) or \ - contains_group(self._store, path, explicit_only=False) + return contains_array(self._store, path) or contains_group( + self._store, path, explicit_only=False + ) - def __getitem__(self, item): + def __getitem__(self, item: str): """Obtain a group member. Parameters @@ -435,43 +468,62 @@ def __getitem__(self, item): """ path = self._item_path(item) if contains_array(self._store, path): - return Array(self._store, read_only=self._read_only, path=path, - chunk_store=self._chunk_store, - synchronizer=self._synchronizer, cache_attrs=self.attrs.cache, - zarr_version=self._version, meta_array=self._meta_array) + return Array( + self._store, + read_only=self._read_only, + path=path, + chunk_store=self._chunk_store, + synchronizer=self._synchronizer, + cache_attrs=self.attrs.cache, + zarr_version=self._version, + meta_array=self._meta_array, + ) elif contains_group(self._store, path, explicit_only=True): - return Group(self._store, read_only=self._read_only, path=path, - chunk_store=self._chunk_store, cache_attrs=self.attrs.cache, - synchronizer=self._synchronizer, zarr_version=self._version, - meta_array=self._meta_array) + return Group( + self._store, + read_only=self._read_only, + path=path, + chunk_store=self._chunk_store, + cache_attrs=self.attrs.cache, + synchronizer=self._synchronizer, + zarr_version=self._version, + meta_array=self._meta_array, + ) elif self._version == 3: - implicit_group = meta_root + path + '/' + implicit_group = meta_root + path + "/" # non-empty folder in the metadata path implies an implicit group if self._store.list_prefix(implicit_group): - return Group(self._store, read_only=self._read_only, path=path, - chunk_store=self._chunk_store, cache_attrs=self.attrs.cache, - synchronizer=self._synchronizer, zarr_version=self._version, - meta_array=self._meta_array) + return Group( + self._store, + read_only=self._read_only, + path=path, + chunk_store=self._chunk_store, + cache_attrs=self.attrs.cache, + synchronizer=self._synchronizer, + zarr_version=self._version, + meta_array=self._meta_array, + ) else: raise KeyError(item) else: raise KeyError(item) - def __setitem__(self, item, value): + def __setitem__(self, item: str, value): 
self.array(item, value, overwrite=True) - def __delitem__(self, item): + def __delitem__(self, item: str): return self._write_op(self._delitem_nosync, item) def _delitem_nosync(self, item): path = self._item_path(item) - if contains_array(self._store, path) or \ - contains_group(self._store, path, explicit_only=False): + if contains_array(self._store, path) or contains_group( + self._store, path, explicit_only=False + ): rmdir(self._store, path) else: raise KeyError(item) - def __getattr__(self, item): + def __getattr__(self, item: str): # allow access to group members via dot notation try: return self.__getitem__(item) @@ -482,7 +534,7 @@ def __dir__(self): # noinspection PyUnresolvedReferences base = super().__dir__() keys = sorted(set(base + list(self))) - keys = [k for k in keys if is_valid_python_name(k)] + keys: List[str] = [k for k in keys if is_valid_python_name(k)] return keys def _ipython_key_completions_(self): @@ -510,13 +562,13 @@ def group_keys(self): yield key else: dir_name = meta_root + self._path - group_sfx = '.group' + self._metadata_key_suffix + group_sfx = ".group" + self._metadata_key_suffix # The fact that we call sorted means this can't be a streaming generator. # The keys are already in memory. all_keys = sorted(listdir(self._store, dir_name)) for key in all_keys: if key.endswith(group_sfx): - key = key[:-len(group_sfx)] + key = key[: -len(group_sfx)] if key in all_keys: # otherwise we will double count this group continue @@ -555,7 +607,8 @@ def groups(self): chunk_store=self._chunk_store, cache_attrs=self.attrs.cache, synchronizer=self._synchronizer, - zarr_version=self._version) + zarr_version=self._version, + ) else: for key in self.group_keys(): @@ -567,9 +620,10 @@ def groups(self): chunk_store=self._chunk_store, cache_attrs=self.attrs.cache, synchronizer=self._synchronizer, - zarr_version=self._version) + zarr_version=self._version, + ) - def array_keys(self, recurse=False): + def array_keys(self, recurse: bool = False): """Return an iterator over member names for arrays only. Parameters @@ -591,11 +645,9 @@ def array_keys(self, recurse=False): ['baz', 'quux'] """ - return self._array_iter(keys_only=True, - method='array_keys', - recurse=recurse) + return self._array_iter(keys_only=True, method="array_keys", recurse=recurse) - def arrays(self, recurse=False): + def arrays(self, recurse: bool = False): """Return an iterator over (name, value) pairs for arrays only. 
Parameters @@ -619,11 +671,9 @@ def arrays(self, recurse=False): quux """ - return self._array_iter(keys_only=False, - method='arrays', - recurse=recurse) + return self._array_iter(keys_only=False, method="arrays", recurse=recurse) - def _array_iter(self, keys_only, method, recurse): + def _array_iter(self, keys_only: bool, method: str, recurse: bool): if self._version == 2: for key in sorted(listdir(self._store, self._path)): path = self._key_prefix + key @@ -635,12 +685,12 @@ def _array_iter(self, keys_only, method, recurse): yield from getattr(group, method)(recurse=recurse) else: dir_name = meta_root + self._path - array_sfx = '.array' + self._metadata_key_suffix - group_sfx = '.group' + self._metadata_key_suffix + array_sfx = ".array" + self._metadata_key_suffix + group_sfx = ".group" + self._metadata_key_suffix for key in sorted(listdir(self._store, dir_name)): if key.endswith(array_sfx): - key = key[:-len(array_sfx)] + key = key[: -len(array_sfx)] _key = key.rstrip("/") yield _key if keys_only else (_key, self[key]) @@ -794,8 +844,7 @@ def visit(self, func): return self.visitvalues(lambda o: func(o.name[base_len:].lstrip("/"))) def visitkeys(self, func): - """An alias for :py:meth:`~Group.visit`. - """ + """An alias for :py:meth:`~Group.visit`.""" return self.visit(func) @@ -830,7 +879,7 @@ def visititems(self, func): base_len = len(self.name) return self.visitvalues(lambda o: func(o.name[base_len:].lstrip("/"), o)) - def tree(self, expand=False, level=None): + def tree(self, expand: bool = False, level: Optional[int] = None): """Provide a ``print``-able display of the hierarchy. Parameters @@ -894,7 +943,9 @@ def _write_op(self, f, *args, **kwargs): with lock: return f(*args, **kwargs) - def create_group(self, name, overwrite=False): + def create_group( + self, name: str, overwrite: bool = False, attrs: Dict[str, Any] = {} + ): """Create a sub-group. Parameters @@ -918,24 +969,41 @@ def create_group(self, name, overwrite=False): """ - return self._write_op(self._create_group_nosync, name, overwrite=overwrite) + return self._write_op( + self._create_group_nosync, name, overwrite=overwrite, attrs=attrs + ) - def _create_group_nosync(self, name, overwrite=False): + def _create_group_nosync( + self, name: str, overwrite: bool = False, attrs: Dict[str, Any] = {} + ): path = self._item_path(name) # create terminal group - init_group(self._store, path=path, chunk_store=self._chunk_store, - overwrite=overwrite) + init_group( + self._store, + path=path, + chunk_store=self._chunk_store, + overwrite=overwrite, + attrs=attrs, + ) - return Group(self._store, path=path, read_only=self._read_only, - chunk_store=self._chunk_store, cache_attrs=self.attrs.cache, - synchronizer=self._synchronizer, zarr_version=self._version) + return Group( + self._store, + path=path, + read_only=self._read_only, + chunk_store=self._chunk_store, + cache_attrs=self.attrs.cache, + synchronizer=self._synchronizer, + zarr_version=self._version, + ) - def create_groups(self, *names, **kwargs): + def create_groups(self, *names: str, **kwargs): """Convenience method to create multiple groups in a single call.""" return tuple(self.create_group(name, **kwargs) for name in names) - def require_group(self, name, overwrite=False): + def require_group( + self, name: str, overwrite: bool = False, attrs: Dict[str, Any] = {} + ): """Obtain a sub-group, creating one if it doesn't exist. 
        Parameters
@@ -960,27 +1028,39 @@
        """
-        return self._write_op(self._require_group_nosync, name,
-                              overwrite=overwrite)
+        return self._write_op(
+            self._require_group_nosync, name, overwrite=overwrite, attrs=attrs
+        )

-    def _require_group_nosync(self, name, overwrite=False):
+    def _require_group_nosync(
+        self, name: str, overwrite: bool = False, attrs: Dict[str, Any] = {}
+    ):
        path = self._item_path(name)

        # create terminal group if necessary
        if not contains_group(self._store, path):
-            init_group(store=self._store, path=path, chunk_store=self._chunk_store,
-                       overwrite=overwrite)
+            init_group(
+                store=self._store,
+                path=path,
+                chunk_store=self._chunk_store,
+                overwrite=overwrite,
+                attrs=attrs,
+            )

-        return Group(self._store, path=path, read_only=self._read_only,
-                     chunk_store=self._chunk_store, cache_attrs=self.attrs.cache,
-                     synchronizer=self._synchronizer, zarr_version=self._version)
+        return Group(
+            self._store,
+            path=path,
+            read_only=self._read_only,
+            chunk_store=self._chunk_store,
+            cache_attrs=self.attrs.cache,
+            synchronizer=self._synchronizer,
+            zarr_version=self._version,
+        )

-    def require_groups(self, *names):
+    def require_groups(self, *names: str):
        """Convenience method to require multiple groups in a single call."""
        return tuple(self.require_group(name) for name in names)

    # noinspection PyIncorrectDocstring
-    def create_dataset(self, name, **kwargs):
+    def create_dataset(self, name: str, **kwargs):
        """Create an array.

        Arrays are known as "datasets" in HDF5 terminology. For compatibility
@@ -1042,27 +1122,40 @@
        return self._write_op(self._create_dataset_nosync, name, **kwargs)

-    def _create_dataset_nosync(self, name, data=None, **kwargs):
+    def _create_dataset_nosync(self, name: str, data=None, **kwargs):
        assert "mode" not in kwargs
        path = self._item_path(name)

        # determine synchronizer
-        kwargs.setdefault('synchronizer', self._synchronizer)
-        kwargs.setdefault('cache_attrs', self.attrs.cache)
+        kwargs.setdefault("synchronizer", self._synchronizer)
+        kwargs.setdefault("cache_attrs", self.attrs.cache)

        # create array
        if data is None:
-            a = create(store=self._store, path=path, chunk_store=self._chunk_store,
-                       **kwargs)
+            a = create(
+                store=self._store, path=path, chunk_store=self._chunk_store, **kwargs
+            )
        else:
-            a = array(data, store=self._store, path=path, chunk_store=self._chunk_store,
-                      **kwargs)
+            a = array(
+                data,
+                store=self._store,
+                path=path,
+                chunk_store=self._chunk_store,
+                **kwargs
+            )

        return a

-    def require_dataset(self, name, shape, dtype=None, exact=False, **kwargs):
+    def require_dataset(
+        self,
+        name: str,
+        shape: Tuple[int, ...],
+        dtype: Optional[str] = None,
+        exact: bool = False,
+        **kwargs
+    ):
        """Obtain an array, creating if it doesn't exist.

        Arrays are known as "datasets" in HDF5 terminology.
For compatibility @@ -1084,11 +1177,23 @@ def require_dataset(self, name, shape, dtype=None, exact=False, **kwargs): """ - return self._write_op(self._require_dataset_nosync, name, shape=shape, - dtype=dtype, exact=exact, **kwargs) + return self._write_op( + self._require_dataset_nosync, + name, + shape=shape, + dtype=dtype, + exact=exact, + **kwargs + ) - def _require_dataset_nosync(self, name, shape, dtype=None, exact=False, - **kwargs): + def _require_dataset_nosync( + self, + name: str, + shape: Tuple[int, ...], + dtype: Optional[str] = None, + exact: bool = False, + **kwargs + ): path = self._item_path(name) @@ -1096,157 +1201,184 @@ def _require_dataset_nosync(self, name, shape, dtype=None, exact=False, # array already exists at path, validate that it is the right shape and type - synchronizer = kwargs.get('synchronizer', self._synchronizer) - cache_metadata = kwargs.get('cache_metadata', True) - cache_attrs = kwargs.get('cache_attrs', self.attrs.cache) - a = Array(self._store, path=path, read_only=self._read_only, - chunk_store=self._chunk_store, synchronizer=synchronizer, - cache_metadata=cache_metadata, cache_attrs=cache_attrs, - meta_array=self._meta_array) + synchronizer = kwargs.get("synchronizer", self._synchronizer) + cache_metadata = kwargs.get("cache_metadata", True) + cache_attrs = kwargs.get("cache_attrs", self.attrs.cache) + a = Array( + self._store, + path=path, + read_only=self._read_only, + chunk_store=self._chunk_store, + synchronizer=synchronizer, + cache_metadata=cache_metadata, + cache_attrs=cache_attrs, + meta_array=self._meta_array, + ) shape = normalize_shape(shape) if shape != a.shape: - raise TypeError('shape do not match existing array; expected {}, got {}' - .format(a.shape, shape)) + raise TypeError( + "shape do not match existing array; expected {}, got {}".format( + a.shape, shape + ) + ) dtype = np.dtype(dtype) if exact: if dtype != a.dtype: - raise TypeError('dtypes do not match exactly; expected {}, got {}' - .format(a.dtype, dtype)) + raise TypeError( + "dtypes do not match exactly; expected {}, got {}".format( + a.dtype, dtype + ) + ) else: if not np.can_cast(dtype, a.dtype): - raise TypeError('dtypes ({}, {}) cannot be safely cast' - .format(dtype, a.dtype)) + raise TypeError( + "dtypes ({}, {}) cannot be safely cast".format(dtype, a.dtype) + ) return a else: - return self._create_dataset_nosync(name, shape=shape, dtype=dtype, - **kwargs) + return self._create_dataset_nosync(name, shape=shape, dtype=dtype, **kwargs) - def create(self, name, **kwargs): + def create(self, name: str, **kwargs): """Create an array. Keyword arguments as per :func:`zarr.creation.create`.""" return self._write_op(self._create_nosync, name, **kwargs) - def _create_nosync(self, name, **kwargs): + def _create_nosync(self, name: str, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return create(store=self._store, path=path, chunk_store=self._chunk_store, - **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return create( + store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def empty(self, name, **kwargs): + def empty(self, name: str, **kwargs): """Create an array. 
Keyword arguments as per :func:`zarr.creation.empty`.""" return self._write_op(self._empty_nosync, name, **kwargs) - def _empty_nosync(self, name, **kwargs): + def _empty_nosync(self, name: str, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return empty(store=self._store, path=path, chunk_store=self._chunk_store, - **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return empty( + store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def zeros(self, name, **kwargs): + def zeros(self, name: str, **kwargs): """Create an array. Keyword arguments as per :func:`zarr.creation.zeros`.""" return self._write_op(self._zeros_nosync, name, **kwargs) - def _zeros_nosync(self, name, **kwargs): + def _zeros_nosync(self, name: str, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return zeros(store=self._store, path=path, chunk_store=self._chunk_store, - **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return zeros( + store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def ones(self, name, **kwargs): + def ones(self, name: str, **kwargs): """Create an array. Keyword arguments as per :func:`zarr.creation.ones`.""" return self._write_op(self._ones_nosync, name, **kwargs) - def _ones_nosync(self, name, **kwargs): + def _ones_nosync(self, name: str, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return ones(store=self._store, path=path, chunk_store=self._chunk_store, **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return ones( + store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def full(self, name, fill_value, **kwargs): + def full(self, name: str, fill_value, **kwargs): """Create an array. Keyword arguments as per :func:`zarr.creation.full`.""" return self._write_op(self._full_nosync, name, fill_value, **kwargs) - def _full_nosync(self, name, fill_value, **kwargs): + def _full_nosync(self, name: str, fill_value, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return full(store=self._store, path=path, chunk_store=self._chunk_store, - fill_value=fill_value, **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return full( + store=self._store, + path=path, + chunk_store=self._chunk_store, + fill_value=fill_value, + **kwargs + ) - def array(self, name, data, **kwargs): + def array(self, name: str, data, **kwargs): """Create an array. 
Keyword arguments as per :func:`zarr.creation.array`.""" return self._write_op(self._array_nosync, name, data, **kwargs) - def _array_nosync(self, name, data, **kwargs): + def _array_nosync(self, name: str, data, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return array(data, store=self._store, path=path, chunk_store=self._chunk_store, - **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return array( + data, store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def empty_like(self, name, data, **kwargs): + def empty_like(self, name: str, data, **kwargs): """Create an array. Keyword arguments as per :func:`zarr.creation.empty_like`.""" return self._write_op(self._empty_like_nosync, name, data, **kwargs) - def _empty_like_nosync(self, name, data, **kwargs): + def _empty_like_nosync(self, name: str, data, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return empty_like(data, store=self._store, path=path, - chunk_store=self._chunk_store, **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return empty_like( + data, store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def zeros_like(self, name, data, **kwargs): + def zeros_like(self, name: str, data, **kwargs): """Create an array. Keyword arguments as per :func:`zarr.creation.zeros_like`.""" return self._write_op(self._zeros_like_nosync, name, data, **kwargs) - def _zeros_like_nosync(self, name, data, **kwargs): + def _zeros_like_nosync(self, name: str, data, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return zeros_like(data, store=self._store, path=path, - chunk_store=self._chunk_store, **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return zeros_like( + data, store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def ones_like(self, name, data, **kwargs): + def ones_like(self, name: str, data, **kwargs): """Create an array. Keyword arguments as per :func:`zarr.creation.ones_like`.""" return self._write_op(self._ones_like_nosync, name, data, **kwargs) - def _ones_like_nosync(self, name, data, **kwargs): + def _ones_like_nosync(self, name: str, data, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return ones_like(data, store=self._store, path=path, - chunk_store=self._chunk_store, **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return ones_like( + data, store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def full_like(self, name, data, **kwargs): + def full_like(self, name: str, data, **kwargs): """Create an array. 
Keyword arguments as per :func:`zarr.creation.full_like`.""" return self._write_op(self._full_like_nosync, name, data, **kwargs) - def _full_like_nosync(self, name, data, **kwargs): + def _full_like_nosync(self, name: str, data, **kwargs): path = self._item_path(name) - kwargs.setdefault('synchronizer', self._synchronizer) - kwargs.setdefault('cache_attrs', self.attrs.cache) - return full_like(data, store=self._store, path=path, - chunk_store=self._chunk_store, **kwargs) + kwargs.setdefault("synchronizer", self._synchronizer) + kwargs.setdefault("cache_attrs", self.attrs.cache) + return full_like( + data, store=self._store, path=path, chunk_store=self._chunk_store, **kwargs + ) - def _move_nosync(self, path, new_path): + def _move_nosync(self, path: str, new_path: str): rename(self._store, path, new_path) if self._chunk_store is not None: rename(self._chunk_store, path, new_path) - def move(self, source, dest): + def move(self, source: str, dest: str): """Move contents from one path to another relative to the Group. Parameters @@ -1261,11 +1393,14 @@ def move(self, source, dest): dest = self._item_path(dest) # Check that source exists. - if not (contains_array(self._store, source) or - contains_group(self._store, source, explicit_only=False)): + if not ( + contains_array(self._store, source) + or contains_group(self._store, source, explicit_only=False) + ): raise ValueError('The source, "%s", does not exist.' % source) - if (contains_array(self._store, dest) or - contains_group(self._store, dest, explicit_only=False)): + if contains_array(self._store, dest) or contains_group( + self._store, dest, explicit_only=False + ): raise ValueError('The dest, "%s", already exists.' % dest) # Ensure groups needed for `dest` exist. @@ -1275,23 +1410,33 @@ def move(self, source, dest): self._write_op(self._move_nosync, source, dest) -def _normalize_store_arg(store, *, storage_options=None, mode="r", - zarr_version=None): +def _normalize_store_arg( + store, *, storage_options=None, mode: str = "r", zarr_version: Optional[int] = None +): if zarr_version is None: - zarr_version = getattr(store, '_store_version', DEFAULT_ZARR_VERSION) + zarr_version = getattr(store, "_store_version", DEFAULT_ZARR_VERSION) if zarr_version != 2: assert_zarr_v3_api_available() if store is None: return MemoryStore() if zarr_version == 2 else MemoryStoreV3() - return normalize_store_arg(store, - storage_options=storage_options, mode=mode, - zarr_version=zarr_version) - - -def group(store=None, overwrite=False, chunk_store=None, - cache_attrs=True, synchronizer=None, path=None, *, zarr_version=None): + return normalize_store_arg( + store, storage_options=storage_options, mode=mode, zarr_version=zarr_version + ) + + +def group( + store=None, + overwrite: bool = False, + chunk_store=None, + cache_attrs: bool = True, + synchronizer=None, + path: Optional[str] = None, + attrs: Dict[str, Any] = {}, + *, + zarr_version: Optional[int] = None +): """Create a group. 
Parameters @@ -1338,7 +1483,7 @@ def group(store=None, overwrite=False, chunk_store=None, # handle polymorphic store arg store = _normalize_store_arg(store, zarr_version=zarr_version) if zarr_version is None: - zarr_version = getattr(store, '_store_version', DEFAULT_ZARR_VERSION) + zarr_version = getattr(store, "_store_version", DEFAULT_ZARR_VERSION) if zarr_version != 2: assert_zarr_v3_api_available() @@ -1352,16 +1497,34 @@ def group(store=None, overwrite=False, chunk_store=None, requires_init = overwrite or not contains_group(store, path) if requires_init: - init_group(store, overwrite=overwrite, chunk_store=chunk_store, - path=path) - - return Group(store, read_only=False, chunk_store=chunk_store, - cache_attrs=cache_attrs, synchronizer=synchronizer, path=path, - zarr_version=zarr_version) - + init_group( + store, overwrite=overwrite, chunk_store=chunk_store, path=path, attrs=attrs + ) -def open_group(store=None, mode='a', cache_attrs=True, synchronizer=None, path=None, - chunk_store=None, storage_options=None, *, zarr_version=None, meta_array=None): + return Group( + store, + read_only=False, + chunk_store=chunk_store, + cache_attrs=cache_attrs, + synchronizer=synchronizer, + path=path, + zarr_version=zarr_version, + ) + + +def open_group( + store=None, + mode: str = "a", + cache_attrs: bool = True, + synchronizer=None, + path: Optional[str] = None, + chunk_store=None, + storage_options=None, + attrs: Dict[str, Any] = {}, + *, + zarr_version: Optional[int] = None, + meta_array=None +): """Open a group using file-mode-like semantics. Parameters @@ -1414,20 +1577,22 @@ def open_group(store=None, mode='a', cache_attrs=True, synchronizer=None, path=N # handle polymorphic store arg store = _normalize_store_arg( - store, storage_options=storage_options, mode=mode, - zarr_version=zarr_version) + store, storage_options=storage_options, mode=mode, zarr_version=zarr_version + ) if zarr_version is None: - zarr_version = getattr(store, '_store_version', DEFAULT_ZARR_VERSION) + zarr_version = getattr(store, "_store_version", DEFAULT_ZARR_VERSION) if zarr_version != 2: assert_zarr_v3_api_available() if chunk_store is not None: - chunk_store = _normalize_store_arg(chunk_store, - storage_options=storage_options, - mode=mode, - zarr_version=zarr_version) - if getattr(chunk_store, '_store_version', DEFAULT_ZARR_VERSION) != zarr_version: + chunk_store = _normalize_store_arg( + chunk_store, + storage_options=storage_options, + mode=mode, + zarr_version=zarr_version, + ) + if getattr(chunk_store, "_store_version", DEFAULT_ZARR_VERSION) != zarr_version: raise ValueError( # pragma: no cover "zarr_version of store and chunk_store must match" ) @@ -1436,32 +1601,41 @@ def open_group(store=None, mode='a', cache_attrs=True, synchronizer=None, path=N # ensure store is initialized - if mode in ['r', 'r+']: + if mode in ["r", "r+"]: if not contains_group(store, path=path): if contains_array(store, path=path): raise ContainsArrayError(path) raise GroupNotFoundError(path) - elif mode == 'w': - init_group(store, overwrite=True, path=path, chunk_store=chunk_store) + elif mode == "w": + init_group( + store, overwrite=True, path=path, chunk_store=chunk_store, attrs=attrs + ) - elif mode == 'a': + elif mode == "a": if not contains_group(store, path=path): if contains_array(store, path=path): raise ContainsArrayError(path) - init_group(store, path=path, chunk_store=chunk_store) + init_group(store, path=path, chunk_store=chunk_store, attrs=attrs) - elif mode in ['w-', 'x']: + elif mode in ["w-", "x"]: if 
contains_array(store, path=path):
            raise ContainsArrayError(path)
        elif contains_group(store, path=path):
            raise ContainsGroupError(path)
        else:
-            init_group(store, path=path, chunk_store=chunk_store)
+            init_group(store, path=path, chunk_store=chunk_store, attrs=attrs)

    # determine read only status
-    read_only = mode == 'r'
-
-    return Group(store, read_only=read_only, cache_attrs=cache_attrs,
-                 synchronizer=synchronizer, path=path, chunk_store=chunk_store,
-                 zarr_version=zarr_version, meta_array=meta_array)
+    read_only = mode == "r"
+
+    return Group(
+        store,
+        read_only=read_only,
+        cache_attrs=cache_attrs,
+        synchronizer=synchronizer,
+        path=path,
+        chunk_store=chunk_store,
+        zarr_version=zarr_version,
+        meta_array=meta_array,
+    )
diff --git a/zarr/storage.py b/zarr/storage.py
index 4acf637330..98b485c73b 100644
--- a/zarr/storage.py
+++ b/zarr/storage.py
@@ -42,6 +42,7 @@
    ensure_contiguous_ndarray_like
)
 from numcodecs.registry import codec_registry
+from zarr.attrs import Attributes

 from zarr.errors import (
    MetadataError,
@@ -62,6 +63,8 @@
 from zarr._storage.store import (_get_hierarchy_metadata,  # noqa: F401
                                 _get_metadata_suffix,
                                 _listdir_from_keys,
+                                 _prefix_to_array_attrs_key,
+                                 _prefix_to_group_attrs_key,
                                 _rename_from_keys,
                                 _rename_metadata_v3,
                                 _rmdir_from_keys,
@@ -300,7 +304,8 @@ def init_array(
    chunk_store: Optional[StoreLike] = None,
    filters=None,
    object_codec=None,
-    dimension_separator=None,
+    dimension_separator: Optional[str] = None,
+    attrs: Dict[str, Any] = {},
):
    """Initialize an array store with the given configuration. Note that this is a
    low-level function and there should be no need to call this directly from user
    code.
@@ -311,6 +315,8 @@ def init_array(
    store : StoreLike
        A mapping that supports string keys and bytes-like values.
    shape : int or tuple of ints
        Array shape.
+    attrs : JSON-serializable dict, optional
+        User attributes for the array. Defaults to {}.
    chunks : bool, int or tuple of ints, optional
        Chunk shape. If True, will be guessed from `shape` and `dtype`. If
        False, will be set to `shape`, i.e., single chunk for the whole array.
@@ -430,6 +436,28 @@ def init_array(
                         object_codec=object_codec,
                         dimension_separator=dimension_separator)

+    _init_array_attrs(store, path, attrs)
+
+
+def _init_array_attrs(store: StoreLike, path: Optional[str], attrs: Dict[str, Any]):
+    if len(attrs):
+        if path:
+            key_prefix = path + "/"
+        else:
+            key_prefix = ""
+        akey = _prefix_to_array_attrs_key(store, key_prefix)
+        Attributes(store, key=akey, cache=False).put(attrs)
+
+
+def _init_group_attrs(store: StoreLike, path: Optional[str], attrs: Dict[str, Any]):
+    if len(attrs):
+        if path:
+            key_prefix = path + "/"
+        else:
+            key_prefix = ""
+        akey = _prefix_to_group_attrs_key(store, key_prefix)
+        Attributes(store, key=akey, cache=False).put(attrs)
+

 def _init_array_metadata(
    store: StoreLike,
@@ -595,6 +623,7 @@

 def init_group(
    store: StoreLike,
    overwrite: bool = False,
    path: Path = None,
    chunk_store: Optional[StoreLike] = None,
+    attrs: Dict[str, Any] = {},
):
@@ -613,7 +642,8 @@ def init_group(
    chunk_store : Store, optional
        Separate storage for chunks. If not provided, `store` will be used
        for storage of both chunks and metadata.
+    attrs : JSON-serializable dict, optional
+        User attributes for the group. Defaults to {}.

    """

    # normalize path
@@ -633,6 +663,9 @@ def init_group(
    _init_group_metadata(store=store, overwrite=overwrite, path=path,
                         chunk_store=chunk_store)

+    # initialize attrs
+    _init_group_attrs(store, path, attrs)
+
    if store_version == 3:
        # TODO: Should initializing a v3 group also create a corresponding
        # empty folder under data/root/? I think probably not until there
diff --git a/zarr/tests/test_core.py b/zarr/tests/test_core.py
index e32026e662..1e9ddb3f22 100644
--- a/zarr/tests/test_core.py
+++ b/zarr/tests/test_core.py
@@ -68,7 +68,7 @@ def test_array_init(self):
        # normal initialization
        store = self.KVStoreClass(dict())
-        init_array(store, shape=100, chunks=10, dtype="<f8")
+        init_array(store, shape=100, chunks=10, dtype="<f8", attrs={})
diff --git a/zarr/util.py b/zarr/util.py
--- a/zarr/util.py
+++ b/zarr/util.py
-def flatten(arg: Iterable) -> Iterable:
+
+def flatten(arg: Iterable[Any]) -> Iterable[Any]:
    for element in arg:
        if isinstance(element, Iterable) and not isinstance(element, (str, bytes)):
            yield from flatten(element)
@@ -27,14 +29,13 @@ def flatten(arg: Iterable) -> Iterable:

 # codecs to use for object dtype convenience API
 object_codecs = {
-    str.__name__: 'vlen-utf8',
-    bytes.__name__: 'vlen-bytes',
-    'array': 'vlen-array',
+    str.__name__: "vlen-utf8",
+    bytes.__name__: "vlen-bytes",
+    "array": "vlen-array",
 }


 class NumberEncoder(json.JSONEncoder):
-
    def default(self, o):
        # See json.JSONEncoder.default docstring for explanation
        # This is necessary to encode numpy dtype
        if isinstance(o, numbers.Integral):
            return int(o)
@@ -47,20 +48,26 @@ def default(self, o):

 def json_dumps(o: Any) -> bytes:
    """Write JSON in a consistent, human-readable way."""
-    return json.dumps(o, indent=4, sort_keys=True, ensure_ascii=True,
-                      separators=(',', ': '), cls=NumberEncoder).encode('ascii')
+    return json.dumps(
+        o,
+        indent=4,
+        sort_keys=True,
+        ensure_ascii=True,
+        separators=(",", ": "),
+        cls=NumberEncoder,
+    ).encode("ascii")


 def json_loads(s: str) -> Dict[str, JSON]:
    """Read JSON in a consistent way."""
-    return json.loads(ensure_text(s, 'ascii'))
+    return json.loads(ensure_text(s, "ascii"))


-def normalize_shape(shape) -> Tuple[int]:
+def normalize_shape(shape: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Convenience function to normalize the `shape` argument."""

    if shape is None:
-        raise TypeError('shape is None')
+        raise TypeError("shape is None")

    # handle 1D convenience form
    if isinstance(shape, numbers.Integral):
        shape = (int(shape),)
@@ -73,9 +80,9 @@

 # code to guess chunk shape, adapted from h5py

-CHUNK_BASE = 256*1024  # Multiplier by which chunks are adjusted
-CHUNK_MIN = 128*1024  # Soft lower limit (128k)
-CHUNK_MAX = 64*1024*1024  # Hard upper limit
+CHUNK_BASE = 256 * 1024  # Multiplier by which chunks are adjusted
+CHUNK_MIN = 128 * 1024  # Soft lower limit (128k)
+CHUNK_MAX = 64 * 1024 * 1024  # Hard upper limit


 def guess_chunks(shape: Tuple[int, ...], typesize: int) -> Tuple[int, ...]:
@@ -89,12 +96,12 @@ def guess_chunks(shape: Tuple[int, ...], typesize: int) -> Tuple[int, ...]:
    ndims = len(shape)
    # require chunks to have non-zero length for all dimensions
-    chunks = np.maximum(np.array(shape, dtype='=f8'), 1)
+    chunks = np.maximum(np.array(shape, dtype="=f8"), 1)

    # Determine the optimal chunk size in bytes using a PyTables expression.
    # This is kept as a float.

 def guess_chunks(shape: Tuple[int, ...], typesize: int) -> Tuple[int, ...]:
@@ -89,12 +96,12 @@
     ndims = len(shape)

     # require chunks to have non-zero length for all dimensions
-    chunks = np.maximum(np.array(shape, dtype='=f8'), 1)
+    chunks = np.maximum(np.array(shape, dtype="=f8"), 1)

     # Determine the optimal chunk size in bytes using a PyTables expression.
     # This is kept as a float.
-    dset_size = np.product(chunks)*typesize
-    target_size = CHUNK_BASE * (2**np.log10(dset_size/(1024.*1024)))
+    dset_size = np.product(chunks) * typesize
+    target_size = CHUNK_BASE * (2 ** np.log10(dset_size / (1024.0 * 1024)))

     if target_size > CHUNK_MAX:
         target_size = CHUNK_MAX
@@ -108,11 +115,12 @@
         # 1b. We're within 50% of the target chunk size, AND
         # 2. The chunk is smaller than the maximum chunk size

-        chunk_bytes = np.product(chunks)*typesize
+        chunk_bytes = np.product(chunks) * typesize

-        if (chunk_bytes < target_size or
-                abs(chunk_bytes-target_size)/target_size < 0.5) and \
-                chunk_bytes < CHUNK_MAX:
+        if (
+            chunk_bytes < target_size
+            or abs(chunk_bytes - target_size) / target_size < 0.5
+        ) and chunk_bytes < CHUNK_MAX:
             break

         if np.product(chunks) == 1:
@@ -146,7 +154,7 @@

     # handle bad dimensionality
     if len(chunks) > len(shape):
-        raise ValueError('too many dimensions in chunks')
+        raise ValueError("too many dimensions in chunks")

     # handle underspecified chunks
     if len(chunks) < len(shape):
@@ -155,49 +163,59 @@

     # handle None or -1 in chunks
     if -1 in chunks or None in chunks:
-        chunks = tuple(s if c == -1 or c is None else int(c)
-                       for s, c in zip(shape, chunks))
+        chunks = tuple(
+            s if c == -1 or c is None else int(c) for s, c in zip(shape, chunks)
+        )

     return tuple(chunks)


-def normalize_dtype(dtype: Union[str, np.dtype], object_codec) -> Tuple[np.dtype, Any]:
+def normalize_dtype(
+    dtype: Union[str, np.dtype], object_codec: Any
+) -> Tuple[np.dtype, Any]:

     # convenience API for object arrays
     if inspect.isclass(dtype):
         dtype = dtype.__name__  # type: ignore
     if isinstance(dtype, str):
         # allow ':' to delimit class from codec arguments
-        tokens = dtype.split(':')
+        tokens = dtype.split(":")
         key = tokens[0]
         if key in object_codecs:
             dtype = np.dtype(object)
             if object_codec is None:
                 codec_id = object_codecs[key]
                 if len(tokens) > 1:
-                    args = tokens[1].split(',')
+                    args = tokens[1].split(",")
                 else:
                     args = []
                 try:
                     object_codec = codec_registry[codec_id](*args)
                 except KeyError:  # pragma: no cover
-                    raise ValueError('codec %r for object type %r is not '
-                                     'available; please provide an '
-                                     'object_codec manually' % (codec_id, key))
+                    raise ValueError(
+                        "codec %r for object type %r is not "
+                        "available; please provide an "
+                        "object_codec manually" % (codec_id, key)
+                    )
             return dtype, object_codec

     dtype = np.dtype(dtype)

     # don't allow generic datetime64 or timedelta64, require units to be specified
-    if dtype == np.dtype('M8') or dtype == np.dtype('m8'):
-        raise ValueError('datetime64 and timedelta64 dtypes with generic units '
-                         'are not supported, please specify units (e.g., "M8[ns]")')
+    if dtype == np.dtype("M8") or dtype == np.dtype("m8"):
+        raise ValueError(
+            "datetime64 and timedelta64 dtypes with generic units "
+            'are not supported, please specify units (e.g., "M8[ns]")'
+        )

     return dtype, object_codec


 # noinspection PyTypeChecker
-def is_total_slice(item, shape: Tuple[int]) -> bool:
+# TODO: correctly type ellipsis
+def is_total_slice(
+    item: Union[slice, Tuple[Any, ...]], shape: Tuple[int, ...]
+) -> bool:
     """Determine whether `item` specifies a complete slice of array with the given `shape`.
Used to optimize __setitem__ operations on the Chunk class.""" @@ -209,19 +227,23 @@ def is_total_slice(item, shape: Tuple[int]) -> bool: if item == slice(None): return True if isinstance(item, slice): - item = item, + item = (item,) if isinstance(item, tuple): return all( - (isinstance(s, slice) and - ((s == slice(None)) or - ((s.stop - s.start == l) and (s.step in [1, None])))) + ( + isinstance(s, slice) + and ( + (s == slice(None)) + or ((s.stop - s.start == l) and (s.step in [1, None])) + ) + ) for s, l in zip(item, shape) ) else: - raise TypeError('expected slice or tuple of slices, found %r' % item) + raise TypeError("expected slice or tuple of slices, found %r" % item) -def normalize_resize_args(old_shape, *args): +def normalize_resize_args(old_shape: Tuple[int, ...], *args): # normalize new shape argument if len(args) == 1: @@ -233,33 +255,32 @@ def normalize_resize_args(old_shape, *args): else: new_shape = tuple(new_shape) if len(new_shape) != len(old_shape): - raise ValueError('new shape must have same number of dimensions') + raise ValueError("new shape must have same number of dimensions") # handle None in new_shape - new_shape = tuple(s if n is None else int(n) - for s, n in zip(old_shape, new_shape)) + new_shape = tuple(s if n is None else int(n) for s, n in zip(old_shape, new_shape)) return new_shape -def human_readable_size(size) -> str: +def human_readable_size(size: int) -> str: if size < 2**10: - return '%s' % size + return "%s" % size elif size < 2**20: - return '%.1fK' % (size / float(2**10)) + return "%.1fK" % (size / float(2**10)) elif size < 2**30: - return '%.1fM' % (size / float(2**20)) + return "%.1fM" % (size / float(2**20)) elif size < 2**40: - return '%.1fG' % (size / float(2**30)) + return "%.1fG" % (size / float(2**30)) elif size < 2**50: - return '%.1fT' % (size / float(2**40)) + return "%.1fT" % (size / float(2**40)) else: - return '%.1fP' % (size / float(2**50)) + return "%.1fP" % (size / float(2**50)) def normalize_order(order: str) -> str: order = str(order).upper() - if order not in ['C', 'F']: + if order not in ["C", "F"]: raise ValueError("order must be either 'C' or 'F', found: %r" % order) return order @@ -269,10 +290,11 @@ def normalize_dimension_separator(sep: Optional[str]) -> Optional[str]: return sep else: raise ValueError( - "dimension_separator must be either '.' or '/', found: %r" % sep) + "dimension_separator must be either '.' 
or '/', found: %r" % sep + ) -def normalize_fill_value(fill_value, dtype: np.dtype): +def normalize_fill_value(fill_value: Any, dtype: np.dtype): if fill_value is None or dtype.hasobject: # no fill value @@ -282,17 +304,19 @@ def normalize_fill_value(fill_value, dtype: np.dtype): # structured arrays fill_value = np.zeros((), dtype=dtype)[()] - elif dtype.kind == 'U': + elif dtype.kind == "U": # special case unicode because of encoding issues on Windows if passed through numpy # https://github.com/alimanfoo/zarr/pull/172#issuecomment-343782713 if not isinstance(fill_value, str): - raise ValueError('fill_value {!r} is not valid for dtype {}; must be a ' - 'unicode string'.format(fill_value, dtype)) + raise ValueError( + "fill_value {!r} is not valid for dtype {}; must be a " + "unicode string".format(fill_value, dtype) + ) else: try: - if isinstance(fill_value, bytes) and dtype.kind == 'V': + if isinstance(fill_value, bytes) and dtype.kind == "V": # special case for numpy 1.14 compatibility fill_value = np.array(fill_value, dtype=dtype.str).view(dtype)[()] else: @@ -300,8 +324,10 @@ def normalize_fill_value(fill_value, dtype: np.dtype): except Exception as e: # re-raise with our own error message to be helpful - raise ValueError('fill_value {!r} is not valid for dtype {}; nested ' - 'exception: {}'.format(fill_value, dtype, e)) + raise ValueError( + "fill_value {!r} is not valid for dtype {}; nested " + "exception: {}".format(fill_value, dtype, e) + ) return fill_value @@ -310,7 +336,7 @@ def normalize_storage_path(path: Union[str, bytes, None]) -> str: # handle bytes if isinstance(path, bytes): - path = str(path, 'ascii') + path = str(path, "ascii") # ensure str if path is not None and not isinstance(path, str): @@ -319,21 +345,21 @@ def normalize_storage_path(path: Union[str, bytes, None]) -> str: if path: # convert backslash to forward slash - path = path.replace('\\', '/') + path = path.replace("\\", "/") # ensure no leading slash - while len(path) > 0 and path[0] == '/': + while len(path) > 0 and path[0] == "/": path = path[1:] # ensure no trailing slash - while len(path) > 0 and path[-1] == '/': + while len(path) > 0 and path[-1] == "/": path = path[:-1] # collapse any repeated slashes previous_char = None - collapsed = '' + collapsed = "" for char in path: - if char == '/' and previous_char == '/': + if char == "/" and previous_char == "/": pass else: collapsed += char @@ -341,49 +367,51 @@ def normalize_storage_path(path: Union[str, bytes, None]) -> str: path = collapsed # don't allow path segments with just '.' or '..' - segments = path.split('/') - if any(s in {'.', '..'} for s in segments): + segments = path.split("/") + if any(s in {".", ".."} for s in segments): raise ValueError("path containing '.' or '..' 
segment not allowed")
     else:
-        path = ''
+        path = ""

     return path


-def buffer_size(v) -> int:
+def buffer_size(v: Any) -> int:
     return ensure_ndarray_like(v).nbytes


 def info_text_report(items: Dict[Any, Any]) -> str:
     keys = [k for k, v in items]
     max_key_len = max(len(k) for k in keys)
-    report = ''
+    report = ""
     for k, v in items:
-        wrapper = TextWrapper(width=80,
-                              initial_indent=k.ljust(max_key_len) + ' : ',
-                              subsequent_indent=' '*max_key_len + ' : ')
+        wrapper = TextWrapper(
+            width=80,
+            initial_indent=k.ljust(max_key_len) + " : ",
+            subsequent_indent=" " * max_key_len + " : ",
+        )
         text = wrapper.fill(str(v))
-        report += text + '\n'
+        report += text + "\n"
     return report


-def info_html_report(items) -> str:
+def info_html_report(items: Dict[Any, Any]) -> str:
     report = '<table class="zarr-info">'
-    report += '<tbody>'
+    report += "<tbody>"
     for k, v in items:
-        report += '<tr>' \
-                  '<th style="text-align: left">%s</th>' \
-                  '<td style="text-align: left">%s</td>' \
-                  '</tr>' \
-                  % (k, v)
-    report += '</tbody>'
-    report += '</table>'
+        report += (
+            "<tr>"
+            '<th style="text-align: left">%s</th>'
+            '<td style="text-align: left">%s</td>'
+            "</tr>" % (k, v)
+        )
+    report += "</tbody>"
+    report += "</table>"
     return report


 class InfoReporter:
-
     def __init__(self, obj):
         self.obj = obj

@@ -397,24 +425,25 @@
     def _repr_html_(self):


 class TreeNode:
-
-    def __init__(self, obj, depth=0, level=None):
+    def __init__(self, obj, depth: int = 0, level=None):
         self.obj = obj
         self.depth = depth
         self.level = level

     def get_children(self):
-        if hasattr(self.obj, 'values'):
+        if hasattr(self.obj, "values"):
             if self.level is None or self.depth < self.level:
                 depth = self.depth + 1
-                return [TreeNode(o, depth=depth, level=self.level)
-                        for o in self.obj.values()]
+                return [
+                    TreeNode(o, depth=depth, level=self.level)
+                    for o in self.obj.values()
+                ]
         return []

     def get_text(self):
         name = self.obj.name.split("/")[-1] or "/"
-        if hasattr(self.obj, 'shape'):
-            name += ' {} {}'.format(self.obj.shape, self.obj.dtype)
+        if hasattr(self.obj, "shape"):
+            name += " {} {}".format(self.obj.shape, self.obj.dtype)
         return name

     def get_type(self):
@@ -422,7 +451,6 @@


 class TreeTraversal(Traversal):
-
     def get_children(self, node):
         return node.get_children()

@@ -433,8 +461,8 @@
     def get_text(self, node):
         return node.get_text()


-tree_group_icon = 'folder'
-tree_array_icon = 'table'
+tree_group_icon = "folder"
+tree_array_icon = "table"


 def tree_get_icon(stype: str) -> str:
@@ -481,37 +509,29 @@ def tree_widget(group, expand, level):


 class TreeViewer:
-
     def __init__(self, group, expand=False, level=None):
         self.group = group
         self.expand = expand
         self.level = level

-        self.text_kwargs = dict(
-            horiz_len=2,
-            label_space=1,
-            indent=1
-        )
+        self.text_kwargs = dict(horiz_len=2, label_space=1, indent=1)

         self.bytes_kwargs = dict(
-            UP_AND_RIGHT="+",
-            HORIZONTAL="-",
-            VERTICAL="|",
-            VERTICAL_AND_RIGHT="+"
+            UP_AND_RIGHT="+", HORIZONTAL="-", VERTICAL="|", VERTICAL_AND_RIGHT="+"
         )

         self.unicode_kwargs = dict(
             UP_AND_RIGHT="\u2514",
             HORIZONTAL="\u2500",
             VERTICAL="\u2502",
-            VERTICAL_AND_RIGHT="\u251C"
+            VERTICAL_AND_RIGHT="\u251C",
         )

     def __bytes__(self):
         drawer = LeftAligned(
             traverse=TreeTraversal(),
-            draw=BoxStyle(gfx=self.bytes_kwargs, **self.text_kwargs)
+            draw=BoxStyle(gfx=self.bytes_kwargs, **self.text_kwargs),
         )
         root = TreeNode(self.group, level=self.level)
         result = drawer(root)
@@ -525,7 +545,7 @@
     def __unicode__(self):
         drawer = LeftAligned(
             traverse=TreeTraversal(),
-            draw=BoxStyle(gfx=self.unicode_kwargs, **self.text_kwargs)
+            draw=BoxStyle(gfx=self.unicode_kwargs, **self.text_kwargs),
         )
         root = TreeNode(self.group, level=self.level)
         return drawer(root)

@@ -538,17 +558,24 @@
     def _repr_mimebundle_(self, **kwargs):
         return tree._repr_mimebundle_(**kwargs)


-def check_array_shape(param, array, shape):
-    if not hasattr(array, 'shape'):
-        raise TypeError('parameter {!r}: expected an array-like object, got {!r}'
-                        .format(param, type(array)))
+def check_array_shape(param, array: Any, shape: Tuple[int, ...]):
+    if not hasattr(array, "shape"):
+        raise TypeError(
+            "parameter {!r}: expected an array-like object, got {!r}".format(
+                param, type(array)
+            )
+        )
     if array.shape != shape:
-        raise ValueError('parameter {!r}: expected array with shape {!r}, got {!r}'
-                         .format(param, shape, array.shape))
+        raise ValueError(
+            "parameter {!r}: expected array with shape {!r}, got {!r}".format(
+                param, shape, array.shape
+            )
+        )


-def is_valid_python_name(name):
+def is_valid_python_name(name: str):
     from keyword import iskeyword
+
     return name.isidentifier() and not iskeyword(name)


@@ -566,7 +593,7 @@
     def __exit__(self, *args):
         pass

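The identifier check reformatted above simply combines ``str.isidentifier``
with the keyword table, e.g.::

    from zarr.util import is_valid_python_name

    assert is_valid_python_name("temperature")
    assert not is_valid_python_name("class")  # reserved keyword
    assert not is_valid_python_name("2fast")  # not an identifier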
 class PartialReadBuffer:
-    def __init__(self, store_key, chunk_store):
+    def __init__(self, store_key: str, chunk_store):
         self.chunk_store = chunk_store
         # is it fsstore or an actual fsspec map object
         assert hasattr(self.chunk_store, "map")
@@ -578,12 +605,12 @@
         self.start_points = None
         self.n_per_block = None
         self.start_points_max = None
-        self.read_blocks = set()
+        self.read_blocks: Set[int] = set()

         _key_path = self.map._key_to_str(store_key)
-        _key_path = _key_path.split('/')
+        _key_path = _key_path.split("/")
         _chunk_path = [self.chunk_store._normalize_key(_key_path[-1])]
-        _key_path = '/'.join(_key_path[:-1] + _chunk_path)
+        _key_path = "/".join(_key_path[:-1] + _chunk_path)
         self.key_path = _key_path

     def prepare_chunk(self):
@@ -593,7 +620,7 @@
         typesize, _shuffle, _memcpyd = cbuffer_metainfo(header)
         self.buff = mmap.mmap(-1, self.cbytes)
         self.buff[0:16] = header
-        self.nblocks = nbytes / blocksize
+        self.nblocks: float = nbytes / blocksize
         self.nblocks = (
             int(self.nblocks)
             if self.nblocks == int(self.nblocks)
@@ -612,7 +639,7 @@
         self.buff[16: (16 + (self.nblocks * 4))] = start_points_buffer
         self.n_per_block = blocksize / typesize

-    def read_part(self, start, nitems):
+    def read_part(self, start, nitems: int):
         assert self.buff is not None
         if self.nblocks == 1:
             return
@@ -639,12 +666,14 @@
         return self.chunk_store[self.store_key]


-def retry_call(callabl: Callable,
-               args=None,
-               kwargs=None,
-               exceptions: Tuple[Any, ...] = (),
-               retries: int = 10,
-               wait: float = 0.1) -> Any:
+def retry_call(
+    callabl: Callable,
+    args=None,
+    kwargs=None,
+    exceptions: Tuple[Any, ...] = (),
+    retries: int = 10,
+    wait: float = 0.1,
+) -> Any:
     """
     Make several attempts to invoke the callable. If one of the given exceptions
     is raised, wait the given period of time and retry up to the given number of
@@ -656,7 +685,7 @@
     if kwargs is None:
         kwargs = {}

-    for attempt in range(1, retries+1):
+    for attempt in range(1, retries + 1):
         try:
             return callabl(*args, **kwargs)
         except exceptions: