diff --git a/src/zarr/abc/metadata.py b/src/zarr/abc/metadata.py index 36edf69534..884426ca77 100644 --- a/src/zarr/abc/metadata.py +++ b/src/zarr/abc/metadata.py @@ -13,7 +13,7 @@ @dataclass(frozen=True) class Metadata: - def to_dict(self) -> JSON: + def to_dict(self) -> dict[str, JSON]: """ Recursively serialize this model to a dictionary. This method inspects the fields of self and calls `x.to_dict()` for any fields that @@ -37,7 +37,7 @@ def to_dict(self) -> JSON: return out_dict @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: + def from_dict(cls: type[Self], data: dict[str, JSON]) -> Self: """ Create an instance of the model from a dictionary """ diff --git a/src/zarr/array.py b/src/zarr/array.py index 7da39c285e..cc51b31e67 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -273,13 +273,13 @@ async def _create_v2( return array @classmethod - def from_dict( - cls, - store_path: StorePath, - data: dict[str, JSON], + async def from_dict( + cls, store_path: StorePath, data: dict[str, JSON], order: Literal["C", "F"] | None = None ) -> AsyncArray: - metadata = parse_array_metadata(data) - async_array = cls(metadata=metadata, store_path=store_path) + data_parsed = parse_array_metadata(data) + async_array = cls(metadata=data_parsed, store_path=store_path, order=order) + # weird that this method doesn't use the metadata attribute + await async_array._save_metadata(async_array.metadata) return async_array @classmethod @@ -535,11 +535,9 @@ def create( @classmethod def from_dict( - cls, - store_path: StorePath, - data: dict[str, JSON], + cls, store_path: StorePath, data: dict[str, JSON], order: Literal["C", "F"] | None = None ) -> Array: - async_array = AsyncArray.from_dict(store_path=store_path, data=data) + async_array = sync(AsyncArray.from_dict(store_path=store_path, data=data, order=order)) return cls(async_array) @classmethod diff --git a/src/zarr/chunk_key_encodings.py b/src/zarr/chunk_key_encodings.py index ed6c181764..32da2c29e7 100644 --- 
a/src/zarr/chunk_key_encodings.py +++ b/src/zarr/chunk_key_encodings.py @@ -26,7 +26,7 @@ def parse_separator(data: JSON) -> SeparatorLiteral: @dataclass(frozen=True) class ChunkKeyEncoding(Metadata): name: str - separator: SeparatorLiteral = "." + separator: SeparatorLiteral = "/" def __init__(self, *, separator: SeparatorLiteral) -> None: separator_parsed = parse_separator(separator) diff --git a/src/zarr/codecs/pipeline.py b/src/zarr/codecs/pipeline.py index 893cbc8b4b..d9ccd508fa 100644 --- a/src/zarr/codecs/pipeline.py +++ b/src/zarr/codecs/pipeline.py @@ -84,7 +84,7 @@ def from_dict(cls, data: Iterable[JSON | Codec], *, batch_size: int | None = Non out.append(get_codec_class(name_parsed).from_dict(c)) # type: ignore[arg-type] return cls.from_list(out, batch_size=batch_size) - def to_dict(self) -> JSON: + def to_dict(self) -> list[JSON]: return [c.to_dict() for c in self] def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: diff --git a/src/zarr/group.py b/src/zarr/group.py index 4ff2176fd9..8d7be0ba56 100644 --- a/src/zarr/group.py +++ b/src/zarr/group.py @@ -32,6 +32,8 @@ from collections.abc import AsyncGenerator, Iterable from typing import Any, Literal + from typing_extensions import Self + logger = logging.getLogger("zarr.group") @@ -97,7 +99,7 @@ def __init__(self, attributes: dict[str, Any] | None = None, zarr_format: ZarrFo object.__setattr__(self, "zarr_format", zarr_format_parsed) @classmethod - def from_dict(cls, data: dict[str, Any]) -> GroupMetadata: + def from_dict(cls, data: dict[str, Any]) -> Self: assert data.pop("node_type", None) in ("group", None) return cls(**data) @@ -181,10 +183,10 @@ async def open( assert zarr_json_bytes is not None group_metadata = json.loads(zarr_json_bytes.to_bytes()) - return cls.from_dict(store_path, group_metadata) + return await cls.from_dict(store_path, group_metadata) @classmethod - def from_dict( + async def from_dict( cls, store_path: StorePath, data: dict[str, Any], @@ -215,9 +217,9 @@ async 
def getitem( else: zarr_json = json.loads(zarr_json_bytes.to_bytes()) if zarr_json["node_type"] == "group": - return type(self).from_dict(store_path, zarr_json) + return await type(self).from_dict(store_path, zarr_json) elif zarr_json["node_type"] == "array": - return AsyncArray.from_dict(store_path, zarr_json) + return await AsyncArray.from_dict(store_path, zarr_json) else: raise ValueError(f"unexpected node_type: {zarr_json['node_type']}") elif self.metadata.zarr_format == 2: @@ -240,7 +242,7 @@ async def getitem( if zarray is not None: # TODO: update this once the V2 array support is part of the primary array class zarr_json = {**zarray, "attributes": zattrs} - return AsyncArray.from_dict(store_path, zarray) + return await AsyncArray.from_dict(store_path, zarray) else: zgroup = ( json.loads(zgroup_bytes.to_bytes()) @@ -248,7 +250,7 @@ async def getitem( else {"zarr_format": self.metadata.zarr_format} ) zarr_json = {**zgroup, "attributes": zattrs} - return type(self).from_dict(store_path, zarr_json) + return await type(self).from_dict(store_path, zarr_json) else: raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") diff --git a/src/zarr/hierarchy.py b/src/zarr/hierarchy.py new file mode 100644 index 0000000000..38f3039d04 --- /dev/null +++ b/src/zarr/hierarchy.py @@ -0,0 +1,366 @@ +""" +Copyright © 2023 Howard Hughes Medical Institute + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
+ Neither the name of HHMI nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +import numpy as np +from typing_extensions import Self + +from zarr.abc.codec import CodecPipeline +from zarr.array import Array +from zarr.buffer import NDBuffer +from zarr.chunk_grids import ChunkGrid, RegularChunkGrid +from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding +from zarr.codecs.bytes import BytesCodec +from zarr.group import Group, GroupMetadata +from zarr.metadata import ArrayV3Metadata +from zarr.store.core import StorePath +from zarr.v2.util import guess_chunks + + +def auto_data_type(data: Any) -> Any: + if hasattr(data, "dtype"): + if hasattr(data, "data_type"): + msg = ( + f"Could not infer the data_type attribute from {data}, because " + "it has both `dtype` and `data_type` attributes. " + "This method requires input with one, or the other, of these attributes." 
+ ) + raise ValueError(msg) + return data.dtype + elif hasattr(data, "data_type") and not hasattr(data, "dtype"): + return data.data_type + else: + msg = ( + f"Could not infer the data_type attribute from {data}. " + "Expected either an object with a `dtype` attribute, " + "or an object with a `data_type` attribute." + ) + raise ValueError(msg) + + +def auto_attributes(data: Any) -> Any: + """ + Guess attributes from: + input with an `attrs` attribute, or + input with an `attributes` attribute, + or anything (returning {}) + """ + if hasattr(data, "attrs"): + return data.attrs + if hasattr(data, "attributes"): + return data.attributes + return {} + + +def auto_chunk_key_encoding(data: Any) -> Any: + if hasattr(data, "chunk_key_encoding"): + return data.chunk_key_encoding + return DefaultChunkKeyEncoding() + + +def auto_fill_value(data: Any) -> Any: + """ + Guess fill value from an input with a `fill_value` attribute, returning 0 otherwise. + """ + if hasattr(data, "fill_value"): + return data.fill_value + return 0 + + +def auto_codecs(data: Any) -> Any: + """ + Guess compressor from an input with a `compressor` attribute, returning `None` otherwise. + """ + if hasattr(data, "codecs"): + return data.codecs + return (BytesCodec(),) + + +def auto_dimension_names(data: Any) -> Any: + """ + If the input has a `dimension_names` attribute, return it, otherwise + return None. 
+ """ + + if hasattr(data, "dimension_names"): + return data.dimension_names + return None + + +def auto_chunk_grid(data: Any) -> Any: + """ + Guess a chunk grid from: + input with a `chunk_grid` attribute, + input with a `chunksize` attribute, or + input with a `chunks` attribute, or, + input with `shape` and `dtype` attributes + """ + if hasattr(data, "chunk_grid"): + # more a statement of intent than anything else + return data.chunk_grid + if hasattr(data, "chunksize"): + chunks = data.chunksize + elif hasattr(data, "chunks"): + chunks = data.chunks + else: + chunks = guess_chunks(data.shape, np.dtype(data.dtype).itemsize) + return RegularChunkGrid(chunk_shape=chunks) + + +class ArrayModel(ArrayV3Metadata): + """ + A model of a Zarr v3 array. + """ + + @classmethod + def from_stored(cls: type[Self], node: Array) -> Self: + """ + Create an array model from a stored array. + """ + return cls.from_dict(node.metadata.to_dict()) + + def to_stored(self, store_path: StorePath, exists_ok: bool = False) -> Array: + """ + Create a stored version of this array. + """ + # exists_ok kwarg is unhandled until we wire it up to the + # array creation routines + + return Array.from_dict(store_path=store_path, data=self.to_dict()) + + @classmethod + def from_array( + cls: type[Self], + data: NDBuffer, + *, + chunk_grid: ChunkGrid | Literal["auto"] = "auto", + chunk_key_encoding: ChunkKeyEncoding | Literal["auto"] = "auto", + fill_value: Any | Literal["auto"] = "auto", + codecs: CodecPipeline | Literal["auto"] = "auto", + attributes: dict[str, Any] | Literal["auto"] = "auto", + dimension_names: tuple[str, ...] | Literal["auto"] = "auto", + ) -> Self: + """ + Create an ArrayModel from an array-like object, e.g. a numpy array. + + The returned ArrayModel will use the shape and dtype attributes of the input. 
+ The remaining ArrayModel attributes are exposed by this method as keyword arguments, + which can either be the string "auto", which instructs this method to infer or guess + a value, or a concrete value to use. + """ + shape_out = data.shape + data_type_out = auto_data_type(data) + + if chunk_grid == "auto": + chunk_grid_out = auto_chunk_grid(data) + else: + chunk_grid_out = chunk_grid + + if chunk_key_encoding == "auto": + chunk_key_encoding_out = auto_chunk_key_encoding(data) + else: + chunk_key_encoding_out = chunk_key_encoding + + if fill_value == "auto": + fill_value_out = auto_fill_value(data) + else: + fill_value_out = fill_value + + if codecs == "auto": + codecs_out = auto_codecs(data) + else: + codecs_out = codecs + + if attributes == "auto": + attributes_out = auto_attributes(data) + else: + attributes_out = attributes + + if dimension_names == "auto": + dimension_names_out = auto_dimension_names(data) + else: + dimension_names_out = dimension_names + + return cls( + shape=shape_out, + data_type=data_type_out, + chunk_grid=chunk_grid_out, + chunk_key_encoding=chunk_key_encoding_out, + fill_value=fill_value_out, + codecs=codecs_out, + attributes=attributes_out, + dimension_names=dimension_names_out, + ) + + +@dataclass(frozen=True) +class GroupModel(GroupMetadata): + """ + A model of a Zarr v3 group. + """ + + members: dict[str, GroupModel | ArrayModel] | None = field(default_factory=dict) + + @classmethod + def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Self: + """ + Create a GroupModel from a Group. This function is recursive. The depth of recursion is + controlled by the `depth` argument, which is either None (no depth limit) or a finite natural number + specifying how deep into the hierarchy to parse. 
+ """ + members: dict[str, GroupModel | ArrayModel] = {} + + if depth is None: + new_depth = depth + else: + new_depth = depth - 1 + + if depth == 0: + return cls(**node.metadata.to_dict(), members=None) + + else: + for name, member in node.members: + item_out: ArrayModel | GroupModel + if isinstance(member, Array): + item_out = ArrayModel.from_stored(member) + else: + item_out = GroupModel.from_stored(member, depth=new_depth) + + members[name] = item_out + + return cls(attributes=node.metadata.attributes, members=members) + + # todo: make this async + def to_stored(self, store_path: StorePath, *, exists_ok: bool = False) -> Group: + """ + Serialize this GroupModel to storage. + """ + + result = Group.create(store_path, attributes=self.attributes, exists_ok=exists_ok) + if self.members is not None: + for name, member in self.members.items(): + substore = store_path / name + member.to_stored(substore, exists_ok=exists_ok) + return result + + +def to_flat( + node: ArrayModel | GroupModel, root_path: str = "" +) -> dict[str, ArrayModel | GroupModel]: + """ + Generate a dict representation of an ArrayModel or GroupModel, where the hierarchy structure + is represented by the keys of the dict. 
+ """ + result = {} + model_copy: ArrayModel | GroupModel + if isinstance(node, ArrayModel): + # we can remove this if we add a model_copy method + model_copy = ArrayModel( + shape=node.shape, + data_type=node.data_type, + chunk_grid=node.chunk_grid, + chunk_key_encoding=node.chunk_key_encoding, + fill_value=node.fill_value, + codecs=node.codecs, + attributes=node.attributes, + dimension_names=node.dimension_names, + ) + else: + model_copy = GroupModel(attributes=node.attributes, members=None) + if node.members is not None: + for name, value in node.members.items(): + result.update(to_flat(value, "/".join([root_path, name]))) + + result[root_path] = model_copy + # sort by increasing key length + result_sorted_keys = dict(sorted(result.items(), key=lambda v: len(v[0]))) + return result_sorted_keys + + +def from_flat(data: dict[str, ArrayModel | GroupModel]) -> ArrayModel | GroupModel: + """ + Create a GroupModel or ArrayModel from a dict representation. + """ + # minimal check that the keys are valid + invalid_keys = [] + for key in data.keys(): + if key.endswith("/"): + invalid_keys.append(key) + if len(invalid_keys) > 0: + msg = f'Invalid keys {invalid_keys} found in data. Keys may not end with the "/"" character' + raise ValueError(msg) + + if tuple(data.keys()) == ("",) and isinstance(tuple(data.values())[0], ArrayModel): + return tuple(data.values())[0] + else: + return from_flat_group(data) + + +def from_flat_group(data: dict[str, ArrayModel | GroupModel]) -> GroupModel: + """ + Create a GroupModel from a hierarchy represented as a dict with string keys and ArrayModel + or GroupModel values. + """ + root_name = "" + sep = "/" + # arrays that will be members of the returned GroupModel + member_arrays: dict[str, ArrayModel] = {} + # groups, and their members, that will be members of the returned GroupModel. + # this dict is populated by recursively applying `from_flat_group` function. 
+ member_groups: dict[str, GroupModel] = {} + # this dict collects the arrayspecs and groupspecs that belong to one of the members of the + # groupspecs we are constructing. They will later be aggregated in a recursive step that + # populates member_groups + submember_by_parent_name: dict[str, dict[str, ArrayModel | GroupModel]] = {} + # copy the input to ensure that mutations are contained inside this function + data_copy = data.copy() + # Get the root node + try: + # The root node is a GroupModel with the key "" + root_node = data_copy.pop(root_name) + if isinstance(root_node, ArrayModel): + raise ValueError("Got an ArrayModel as the root node. This is invalid.") + except KeyError: + # If a root node was not found, create a default one + root_node = GroupModel(attributes={}, members=None) + + # partition the tree (sans root node) into 2 categories: (arrays, groups + their members). + for key, value in data_copy.items(): + key_parts = key.split(sep) + if key_parts[0] != root_name: + raise ValueError(f'Invalid path: {key} does not start with "{root_name}{sep}".') + + subparent_name = key_parts[1] + if len(key_parts) == 2: + # this is an array or group that belongs to the group we are ultimately returning + if isinstance(value, ArrayModel): + member_arrays[subparent_name] = value + else: + if subparent_name not in submember_by_parent_name: + submember_by_parent_name[subparent_name] = {} + submember_by_parent_name[subparent_name][root_name] = value + else: + # these are groups or arrays that belong to one of the member groups + # not great that we repeat this conditional dict initialization + if subparent_name not in submember_by_parent_name: + submember_by_parent_name[subparent_name] = {} + submember_by_parent_name[subparent_name][sep.join([root_name, *key_parts[2:]])] = value + + # recurse + for subparent_name, submemb in submember_by_parent_name.items(): + member_groups[subparent_name] = from_flat_group(submemb) + + return GroupModel(members={**member_groups, 
**member_arrays}, attributes=root_node.attributes) diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index 58cc276c29..063962e5a9 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -16,6 +16,7 @@ from zarr.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator from zarr.codecs._v2 import V2Compressor, V2Filters +from zarr.codecs.bytes import BytesCodec if TYPE_CHECKING: from typing import Literal @@ -161,7 +162,7 @@ class ArrayV3Metadata(ArrayMetadata): chunk_key_encoding: ChunkKeyEncoding fill_value: Any codecs: CodecPipeline - attributes: dict[str, Any] = field(default_factory=dict) + attributes: dict[str, JSON] = field(default_factory=dict) dimension_names: tuple[str, ...] | None = None zarr_format: Literal[3] = field(default=3, init=False) node_type: Literal["array"] = field(default="array", init=False) @@ -174,9 +175,9 @@ def __init__( chunk_grid: dict[str, JSON] | ChunkGrid, chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding, fill_value: Any, - codecs: Iterable[Codec | JSON], - attributes: None | dict[str, JSON], - dimension_names: None | Iterable[str], + codecs: Iterable[Codec | JSON] = (BytesCodec(),), + attributes: None | dict[str, JSON] = None, + dimension_names: None | Iterable[str] = None, ) -> None: """ Because the class is a frozen dataclass, we set attributes using object.__setattr__ @@ -266,18 +267,19 @@ def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, JSON]) -> ArrayV3Metadata: + def from_dict(cls: type[Self], data: dict[str, JSON]) -> Self: + data_copy = data.copy() # check that the zarr_format attribute is correct - _ = parse_zarr_format_v3(data.pop("zarr_format")) + _ = parse_zarr_format_v3(data_copy.pop("zarr_format")) # check that the node_type attribute is correct - _ = parse_node_type_array(data.pop("node_type")) + _ = parse_node_type_array(data_copy.pop("node_type")) - 
data["dimension_names"] = data.pop("dimension_names", None) + data_copy["dimension_names"] = data_copy.pop("dimension_names", None) # TODO: Remove the ignores and use a TypedDict to type `data` - return cls(**data) # type: ignore[arg-type] + return cls(**data_copy) # type: ignore[arg-type] - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> dict[str, JSON]: out_dict = super().to_dict() if not isinstance(out_dict, dict): @@ -390,11 +392,12 @@ def _json_convert( @classmethod def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: + data_copy = data.copy() # check that the zarr_format attribute is correct - _ = parse_zarr_format_v2(data.pop("zarr_format")) - return cls(**data) + _ = parse_zarr_format_v2(data_copy.pop("zarr_format")) + return cls(**data_copy) - def to_dict(self) -> JSON: + def to_dict(self) -> dict[str, JSON]: zarray_dict = super().to_dict() assert isinstance(zarray_dict, dict) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 36b82f413c..4e29346eeb 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -219,13 +219,13 @@ async def test_asyncgroup_open_wrong_format( {"zarr_format": 2, "attributes": {"foo": 100}}, ), ) -def test_asyncgroup_from_dict(store: MemoryStore | LocalStore, data: dict[str, Any]) -> None: +async def test_asyncgroup_from_dict(store: MemoryStore | LocalStore, data: dict[str, Any]) -> None: """ Test that we can create an AsyncGroup from a dict """ path = "test" store_path = StorePath(store=store, path=path) - group = AsyncGroup.from_dict(store_path, data=data) + group = await AsyncGroup.from_dict(store_path, data=data) assert group.metadata.zarr_format == data["zarr_format"] assert group.metadata.attributes == data["attributes"] diff --git a/tests/v3/test_hierarchy.py b/tests/v3/test_hierarchy.py new file mode 100644 index 0000000000..a13d8e67ab --- /dev/null +++ b/tests/v3/test_hierarchy.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from 
zarr.array import Array +from zarr.chunk_grids import RegularChunkGrid +from zarr.chunk_key_encodings import DefaultChunkKeyEncoding +from zarr.group import GroupMetadata +from zarr.hierarchy import ArrayModel, GroupModel, from_flat, to_flat +from zarr.metadata import ArrayV3Metadata +from zarr.store.core import StorePath +from zarr.store.memory import MemoryStore + + +def test_array_model_from_dict() -> None: + array_meta = ArrayV3Metadata( + shape=(10,), + data_type="uint8", + chunk_grid=RegularChunkGrid(chunk_shape=(10,)), + chunk_key_encoding=DefaultChunkKeyEncoding(), + fill_value=0, + attributes={"foo": 10}, + ) + + model = ArrayModel.from_dict(array_meta.to_dict()) + assert model.to_dict() == array_meta.to_dict() + + +def test_array_model_to_stored(memory_store: MemoryStore) -> None: + model = ArrayModel( + shape=(10,), + data_type="uint8", + chunk_grid=RegularChunkGrid(chunk_shape=(10,)), + chunk_key_encoding=DefaultChunkKeyEncoding(), + fill_value=0, + attributes={"foo": 10}, + ) + + array = model.to_stored(store_path=StorePath(store=memory_store)) + assert array.metadata.to_dict() == model.to_dict() + + +def test_array_model_from_stored(memory_store: MemoryStore) -> None: + array_meta = ArrayV3Metadata( + shape=(10,), + data_type="uint8", + chunk_grid=RegularChunkGrid(chunk_shape=(10,)), + chunk_key_encoding=DefaultChunkKeyEncoding(), + fill_value=0, + attributes={"foo": 10}, + ) + + array = Array.from_dict(StorePath(memory_store), array_meta.to_dict()) + array_model = ArrayModel.from_stored(array) + assert array_model.to_dict() == array_meta.to_dict() + + +def test_groupmodel_from_dict() -> None: + group_meta = GroupMetadata(attributes={"foo": "bar"}) + model = GroupModel.from_dict({**group_meta.to_dict(), "members": None}) + assert model.to_dict() == {**group_meta.to_dict(), "members": None} + + +@pytest.mark.parametrize("attributes", ({}, {"foo": 100})) +@pytest.mark.parametrize( + "members", + ( + None, + {}, + { + "foo": ArrayModel( + shape=(100,), + 
data_type="uint8", + chunk_grid=RegularChunkGrid(chunk_shape=(10,)), + chunk_key_encoding=DefaultChunkKeyEncoding(), + fill_value=0, + attributes={"foo": 10}, + ), + "bar": GroupModel( + attributes={"name": "bar"}, + members={ + "subarray": ArrayModel( + shape=(100,), + data_type="uint8", + chunk_grid=RegularChunkGrid(chunk_shape=(10,)), + chunk_key_encoding=DefaultChunkKeyEncoding(), + fill_value=0, + attributes={"foo": 10}, + ) + }, + ), + }, + ), +) +def test_groupmodel_to_stored( + memory_store: MemoryStore, + attributes: dict[str, int], + members: None | dict[str, ArrayModel | GroupModel], +): + model = GroupModel(attributes=attributes, members=members) + group = model.to_stored(StorePath(memory_store, path="test")) + model_rt = GroupModel.from_stored(group) + assert model_rt.attributes == model.attributes + if members is not None: + assert model_rt.members == model.members + else: + assert model_rt.members == {} + + +@pytest.mark.parametrize( + ("data, expected"), + [ + ( + ArrayModel.from_array(np.arange(10)), + {"": ArrayModel.from_array(np.arange(10))}, + ), + ( + GroupModel( + attributes={"foo": 10}, + members={"a": ArrayModel.from_array(np.arange(5), attributes={"foo": 100})}, + ), + { + "": GroupModel(attributes={"foo": 10}, members=None), + "/a": ArrayModel.from_array(np.arange(5), attributes={"foo": 100}), + }, + ), + ( + GroupModel( + attributes={}, + members={ + "a": GroupModel( + attributes={"foo": 10}, + members={"a": ArrayModel.from_array(np.arange(5), attributes={"foo": 100})}, + ), + "b": ArrayModel.from_array(np.arange(2), attributes={"foo": 3}), + }, + ), + { + "": GroupModel(attributes={}, members=None), + "/a": GroupModel(attributes={"foo": 10}, members=None), + "/a/a": ArrayModel.from_array(np.arange(5), attributes={"foo": 100}), + "/b": ArrayModel.from_array(np.arange(2), attributes={"foo": 3}), + }, + ), + ], +) +def test_flatten_unflatten(data, expected) -> None: + flattened = to_flat(data) + assert flattened == expected + assert 
from_flat(flattened) == data