Skip to content

Commit 94a60ae

Browse files
committed
start recursive ser/deserialization tests, fix issues in the Array API
1 parent 5aa0c17 commit 94a60ae

File tree

7 files changed

+103
-39
lines changed

7 files changed

+103
-39
lines changed

src/zarr/abc/metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@dataclass(frozen=True)
1515
class Metadata:
16-
def to_dict(self) -> JSON:
16+
def to_dict(self) -> dict[str, JSON]:
1717
"""
1818
Recursively serialize this model to a dictionary.
1919
This method inspects the fields of self and calls `x.to_dict()` for any fields that
@@ -37,7 +37,7 @@ def to_dict(self) -> JSON:
3737
return out_dict
3838

3939
@classmethod
40-
def from_dict(cls, data: dict[str, JSON]) -> Self:
40+
def from_dict(cls: type[Self], data: dict[str, JSON]) -> Self:
4141
"""
4242
Create an instance of the model from a dictionary
4343
"""

src/zarr/array.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -273,13 +273,13 @@ async def _create_v2(
273273
return array
274274

275275
@classmethod
276-
def from_dict(
277-
cls,
278-
store_path: StorePath,
279-
data: dict[str, JSON],
276+
async def from_dict(
277+
cls, store_path: StorePath, data: dict[str, JSON], order: Literal["C", "F"] | None = None
280278
) -> AsyncArray:
281-
metadata = parse_array_metadata(data)
282-
async_array = cls(metadata=metadata, store_path=store_path)
279+
data_parsed = parse_array_metadata(data)
280+
async_array = cls(metadata=data_parsed, store_path=store_path, order=order)
281+
# weird that this method doesn't use the metadata attribute
282+
await async_array._save_metadata(async_array.metadata)
283283
return async_array
284284

285285
@classmethod
@@ -535,11 +535,9 @@ def create(
535535

536536
@classmethod
537537
def from_dict(
538-
cls,
539-
store_path: StorePath,
540-
data: dict[str, JSON],
538+
cls, store_path: StorePath, data: dict[str, JSON], order: Literal["C", "F"] | None = None
541539
) -> Array:
542-
async_array = AsyncArray.from_dict(store_path=store_path, data=data)
540+
async_array = sync(AsyncArray.from_dict(store_path=store_path, data=data))
543541
return cls(async_array)
544542

545543
@classmethod

src/zarr/codecs/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def from_dict(cls, data: Iterable[JSON | Codec], *, batch_size: int | None = Non
8484
out.append(get_codec_class(name_parsed).from_dict(c)) # type: ignore[arg-type]
8585
return cls.from_list(out, batch_size=batch_size)
8686

87-
def to_dict(self) -> JSON:
87+
def to_dict(self) -> list[JSON]:
8888
return [c.to_dict() for c in self]
8989

9090
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:

src/zarr/group.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ async def getitem(
219219
if zarr_json["node_type"] == "group":
220220
return type(self).from_dict(store_path, zarr_json)
221221
elif zarr_json["node_type"] == "array":
222-
return AsyncArray.from_dict(store_path, zarr_json)
222+
return sync(AsyncArray.from_dict(store_path, zarr_json))
223223
else:
224224
raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
225225
elif self.metadata.zarr_format == 2:
@@ -242,7 +242,7 @@ async def getitem(
242242
if zarray is not None:
243243
# TODO: update this once the V2 array support is part of the primary array class
244244
zarr_json = {**zarray, "attributes": zattrs}
245-
return AsyncArray.from_dict(store_path, zarray)
245+
return sync(AsyncArray.from_dict(store_path, zarray))
246246
else:
247247
zgroup = (
248248
json.loads(zgroup_bytes.to_bytes())

src/zarr/hierarchy.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from __future__ import annotations
1414

15-
from typing import Any
15+
from dataclasses import dataclass, field
1616

1717
from typing_extensions import Self
1818

@@ -28,23 +28,29 @@ class ArrayModel(ArrayV3Metadata):
2828
"""
2929

3030
@classmethod
31-
def from_stored(cls: type[Self], node: Array):
31+
def from_stored(cls: type[Self], node: Array) -> Self:
32+
"""
33+
Create an array model from a stored array.
34+
"""
3235
return cls.from_dict(node.metadata.to_dict())
3336

34-
def to_stored(self, store_path: StorePath) -> Array:
37+
def to_stored(self, store_path: StorePath, exists_ok: bool = False) -> Array:
38+
"""
39+
Create a stored version of this array.
40+
"""
41+
# exists_ok kwarg is unhandled until we wire it up to the
42+
# array creation routines
43+
3544
return Array.from_dict(store_path=store_path, data=self.to_dict())
3645

3746

47+
@dataclass(frozen=True)
3848
class GroupModel(GroupMetadata):
3949
"""
4050
A model of a Zarr v3 group.
4151
"""
4252

43-
members: dict[str, GroupModel | ArrayModel] | None
44-
45-
@classmethod
46-
def from_dict(cls: type[Self], data: dict[str, Any]):
47-
return cls(**data)
53+
members: dict[str, GroupModel | ArrayModel] | None = field(default_factory=dict)
4854

4955
@classmethod
5056
def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Self:
@@ -53,7 +59,7 @@ def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Se
5359
controlled by the `depth` argument, which is either None (no depth limit) or a finite natural number
5460
specifying how deep into the hierarchy to parse.
5561
"""
56-
members: dict[str, GroupModel | ArrayModel]
62+
members: dict[str, GroupModel | ArrayModel] = {}
5763

5864
if depth is None:
5965
new_depth = depth
@@ -64,16 +70,18 @@ def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Se
6470
return cls(**node.metadata.to_dict(), members=None)
6571

6672
else:
67-
for name, member in node.members():
73+
for name, member in node.members:
74+
item_out: ArrayModel | GroupModel
6875
if isinstance(member, Array):
6976
item_out = ArrayModel.from_stored(member)
7077
else:
7178
item_out = GroupModel.from_stored(member, depth=new_depth)
7279

7380
members[name] = item_out
7481

75-
return cls(**node.metadata.to_dict(), members=members)
82+
return cls(attributes=node.metadata.attributes, members=members)
7683

84+
# todo: make this async
7785
def to_stored(self, store_path: StorePath, *, exists_ok: bool = False) -> Group:
7886
"""
7987
Serialize this GroupModel to storage.
@@ -90,15 +98,18 @@ def to_stored(self, store_path: StorePath, *, exists_ok: bool = False) -> Group:
9098
def to_flat(
9199
node: ArrayModel | GroupModel, root_path: str = ""
92100
) -> dict[str, ArrayModel | GroupModel]:
101+
"""
102+
Generate a dict representation of an ArrayModel or GroupModel, where the hierarchy structure
103+
is represented by the keys of the dict.
104+
"""
93105
result = {}
94106
model_copy: ArrayModel | GroupModel
95107
node_dict = node.to_dict()
96108
if isinstance(node, ArrayModel):
97109
model_copy = ArrayModel(**node_dict)
98110
else:
99-
members = node_dict.pop("members")
100-
model_copy = GroupModel(node_dict)
101-
if members is not None:
111+
model_copy = GroupModel(**node_dict)
112+
if node.members is not None:
102113
for name, value in node.members.items():
103114
result.update(to_flat(value, "/".join([root_path, name])))
104115

@@ -109,6 +120,9 @@ def to_flat(
109120

110121

111122
def from_flat(data: dict[str, ArrayModel | GroupModel]) -> ArrayModel | GroupModel:
123+
"""
124+
Create a GroupModel or ArrayModel from a dict representation.
125+
"""
112126
# minimal check that the keys are valid
113127
invalid_keys = []
114128
for key in data.keys():
@@ -125,6 +139,10 @@ def from_flat(data: dict[str, ArrayModel | GroupModel]) -> ArrayModel | GroupMod
125139

126140

127141
def from_flat_group(data: dict[str, ArrayModel | GroupModel]) -> GroupModel:
142+
"""
143+
Create a GroupModel from a hierarchy represented as a dict with string keys and ArrayModel
144+
or GroupModel values.
145+
"""
128146
root_name = ""
129147
sep = "/"
130148
# arrays that will be members of the returned GroupModel

src/zarr/metadata.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -267,18 +267,19 @@ def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]:
267267
}
268268

269269
@classmethod
270-
def from_dict(cls, data: dict[str, JSON]) -> ArrayV3Metadata:
270+
def from_dict(cls: type[Self], data: dict[str, JSON]) -> Self:
271+
data_copy = data.copy()
271272
# check that the zarr_format attribute is correct
272-
_ = parse_zarr_format_v3(data.pop("zarr_format"))
273+
_ = parse_zarr_format_v3(data_copy.pop("zarr_format"))
273274
# check that the node_type attribute is correct
274-
_ = parse_node_type_array(data.pop("node_type"))
275+
_ = parse_node_type_array(data_copy.pop("node_type"))
275276

276-
data["dimension_names"] = data.pop("dimension_names", None)
277+
data_copy["dimension_names"] = data_copy.pop("dimension_names", None)
277278

278279
# TODO: Remove the ignores and use a TypedDict to type `data`
279-
return cls(**data) # type: ignore[arg-type]
280+
return cls(**data_copy) # type: ignore[arg-type]
280281

281-
def to_dict(self) -> dict[str, Any]:
282+
def to_dict(self) -> dict[str, JSON]:
282283
out_dict = super().to_dict()
283284

284285
if not isinstance(out_dict, dict):
@@ -391,11 +392,12 @@ def _json_convert(
391392

392393
@classmethod
393394
def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
395+
data_copy = data.copy()
394396
# check that the zarr_format attribute is correct
395-
_ = parse_zarr_format_v2(data.pop("zarr_format"))
396-
return cls(**data)
397+
_ = parse_zarr_format_v2(data_copy.pop("zarr_format"))
398+
return cls(**data_copy)
397399

398-
def to_dict(self) -> JSON:
400+
def to_dict(self) -> dict[str, JSON]:
399401
zarray_dict = super().to_dict()
400402

401403
assert isinstance(zarray_dict, dict)

tests/v3/test_hierarchy.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import pytest
4+
35
from zarr.array import Array
46
from zarr.chunk_grids import RegularChunkGrid
57
from zarr.chunk_key_encodings import DefaultChunkKeyEncoding
68
from zarr.group import GroupMetadata
79
from zarr.hierarchy import ArrayModel, GroupModel
810
from zarr.metadata import ArrayV3Metadata
11+
from zarr.store.core import StorePath
912
from zarr.store.memory import MemoryStore
1013

1114

@@ -58,4 +61,47 @@ def test_groupmodel_from_dict() -> None:
5861
assert model.to_dict() == {**group_meta.to_dict(), "members": None}
5962

6063

61-
def test_groupmodel_to_stored(): ...
64+
@pytest.mark.parametrize("attributes", ({}, {"foo": 100}))
65+
@pytest.mark.parametrize(
66+
"members",
67+
(
68+
None,
69+
{},
70+
{
71+
"foo": ArrayModel(
72+
shape=(100,),
73+
data_type="uint8",
74+
chunk_grid=RegularChunkGrid(chunk_shape=(10,)),
75+
chunk_key_encoding=DefaultChunkKeyEncoding(),
76+
fill_value=0,
77+
attributes={"foo": 10},
78+
),
79+
"bar": GroupModel(
80+
attributes={"name": "bar"},
81+
members={
82+
"subarray": ArrayModel(
83+
shape=(100,),
84+
data_type="uint8",
85+
chunk_grid=RegularChunkGrid(chunk_shape=(10,)),
86+
chunk_key_encoding=DefaultChunkKeyEncoding(),
87+
fill_value=0,
88+
attributes={"foo": 10},
89+
)
90+
},
91+
),
92+
},
93+
),
94+
)
95+
def test_groupmodel_to_stored(
96+
memory_store: MemoryStore,
97+
attributes: dict[str, int],
98+
members: None | dict[str, ArrayModel | GroupModel],
99+
):
100+
model = GroupModel(attributes=attributes, members=members)
101+
group = model.to_stored(StorePath(memory_store, path="test"))
102+
model_rt = GroupModel.from_stored(group)
103+
assert model_rt.attributes == model.attributes
104+
if members is not None:
105+
assert model_rt.members == model.members
106+
else:
107+
assert model_rt.members == {}

0 commit comments

Comments
 (0)