Skip to content

Commit 3bf098a

Browse files
committed
add ArrayModel.from_array, await unawaited awaitables, add flattening tests
1 parent c3cf284 commit 3bf098a

File tree

4 files changed

+227
-12
lines changed

4 files changed

+227
-12
lines changed

src/zarr/group.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ async def open(
183183
assert zarr_json_bytes is not None
184184
group_metadata = json.loads(zarr_json_bytes.to_bytes())
185185

186-
return cls.from_dict(store_path, group_metadata)
186+
return await cls.from_dict(store_path, group_metadata)
187187

188188
@classmethod
189-
def from_dict(
189+
async def from_dict(
190190
cls,
191191
store_path: StorePath,
192192
data: dict[str, Any],
@@ -217,9 +217,9 @@ async def getitem(
217217
else:
218218
zarr_json = json.loads(zarr_json_bytes.to_bytes())
219219
if zarr_json["node_type"] == "group":
220-
return type(self).from_dict(store_path, zarr_json)
220+
return await type(self).from_dict(store_path, zarr_json)
221221
elif zarr_json["node_type"] == "array":
222-
return sync(AsyncArray.from_dict(store_path, zarr_json))
222+
return await AsyncArray.from_dict(store_path, zarr_json)
223223
else:
224224
raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
225225
elif self.metadata.zarr_format == 2:
@@ -250,7 +250,7 @@ async def getitem(
250250
else {"zarr_format": self.metadata.zarr_format}
251251
)
252252
zarr_json = {**zgroup, "attributes": zattrs}
253-
return type(self).from_dict(store_path, zarr_json)
253+
return await type(self).from_dict(store_path, zarr_json)
254254
else:
255255
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
256256

src/zarr/hierarchy.py

Lines changed: 174 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,111 @@
1313
from __future__ import annotations
1414

1515
from dataclasses import dataclass, field
16+
from typing import Any, Literal
1617

18+
import numpy as np
1719
from typing_extensions import Self
1820

21+
from zarr.abc.codec import CodecPipeline
1922
from zarr.array import Array
23+
from zarr.buffer import NDBuffer
24+
from zarr.chunk_grids import ChunkGrid, RegularChunkGrid
25+
from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding
26+
from zarr.codecs.bytes import BytesCodec
2027
from zarr.group import Group, GroupMetadata
2128
from zarr.metadata import ArrayV3Metadata
2229
from zarr.store.core import StorePath
30+
from zarr.v2.util import guess_chunks
31+
32+
33+
def auto_data_type(data: Any) -> Any:
34+
if hasattr(data, "dtype"):
35+
if hasattr(data, "data_type"):
36+
msg = (
37+
f"Could not infer the data_type attribute from {data}, because "
38+
"it has both `dtype` and `data_type` attributes. "
39+
"This method requires input with one, or the other, of these attributes."
40+
)
41+
raise ValueError(msg)
42+
return data.dtype
43+
elif hasattr(data, "data_type") and not hasattr(data, "dtype"):
44+
return data.data_type
45+
else:
46+
msg = (
47+
f"Could not infer the data_type attribute from {data}. "
48+
"Expected either an object with a `dtype` attribute, "
49+
"or an object with a `data_type` attribute."
50+
)
51+
raise ValueError(msg)
52+
53+
54+
def auto_attributes(data: Any) -> Any:
55+
"""
56+
Guess attributes from:
57+
input with an `attrs` attribute, or
58+
input with an `attributes` attribute,
59+
or anything (returning {})
60+
"""
61+
if hasattr(data, "attrs"):
62+
return data.attrs
63+
if hasattr(data, "attributes"):
64+
return data.attributes
65+
return {}
66+
67+
68+
def auto_chunk_key_encoding(data: Any) -> Any:
69+
if hasattr(data, "chunk_key_encoding"):
70+
return data.chunk_key_encoding
71+
return DefaultChunkKeyEncoding()
72+
73+
74+
def auto_fill_value(data: Any) -> Any:
75+
"""
76+
Guess fill value from an input with a `fill_value` attribute, returning 0 otherwise.
77+
"""
78+
if hasattr(data, "fill_value"):
79+
return data.fill_value
80+
return 0
81+
82+
83+
def auto_codecs(data: Any) -> Any:
84+
"""
85+
Guess compressor from an input with a `compressor` attribute, returning `None` otherwise.
86+
"""
87+
if hasattr(data, "codecs"):
88+
return data.codecs
89+
return (BytesCodec(),)
90+
91+
92+
def auto_dimension_names(data: Any) -> Any:
93+
"""
94+
If the input has a `dimension_names` attribute, return it, otherwise
95+
return None.
96+
"""
97+
98+
if hasattr(data, "dimension_names"):
99+
return data.dimension_names
100+
return None
101+
102+
103+
def auto_chunk_grid(data: Any) -> Any:
104+
"""
105+
Guess a chunk grid from:
106+
input with a `chunk_grid` attribute,
107+
input with a `chunksize` attribute, or
108+
input with a `chunks` attribute, or,
109+
input with `shape` and `dtype` attributes
110+
"""
111+
if hasattr(data, "chunk_grid"):
112+
# more a statement of intent than anything else
113+
return data.chunk_grid
114+
if hasattr(data, "chunksize"):
115+
chunks = data.chunksize
116+
elif hasattr(data, "chunks"):
117+
chunks = data.chunks
118+
else:
119+
chunks = guess_chunks(data.shape, np.dtype(data.dtype).itemsize)
120+
return RegularChunkGrid(chunk_shape=chunks)
23121

24122

25123
class ArrayModel(ArrayV3Metadata):
@@ -43,6 +141,70 @@ def to_stored(self, store_path: StorePath, exists_ok: bool = False) -> Array:
43141

44142
return Array.from_dict(store_path=store_path, data=self.to_dict())
45143

144+
@classmethod
145+
def from_array(
146+
cls: type[Self],
147+
data: NDBuffer,
148+
*,
149+
chunk_grid: ChunkGrid | Literal["auto"] = "auto",
150+
chunk_key_encoding: ChunkKeyEncoding | Literal["auto"] = "auto",
151+
fill_value: Any | Literal["auto"] = "auto",
152+
codecs: CodecPipeline | Literal["auto"] = "auto",
153+
attributes: dict[str, Any] | Literal["auto"] = "auto",
154+
dimension_names: tuple[str, ...] | Literal["auto"] = "auto",
155+
) -> Self:
156+
"""
157+
Create an ArrayModel from an array-like object, e.g. a numpy array.
158+
159+
The returned ArrayModel will use the shape and dtype attributes of the input.
160+
The remaining ArrayModel attributes are exposed by this method as keyword arguments,
161+
which can either be the string "auto", which instructs this method to infer or guess
162+
a value, or a concrete value to use.
163+
"""
164+
shape_out = data.shape
165+
data_type_out = auto_data_type(data)
166+
167+
if chunk_grid == "auto":
168+
chunk_grid_out = auto_chunk_grid(data)
169+
else:
170+
chunk_grid_out = chunk_grid
171+
172+
if chunk_key_encoding == "auto":
173+
chunk_key_encoding_out = auto_chunk_key_encoding(data)
174+
else:
175+
chunk_key_encoding_out = chunk_key_encoding
176+
177+
if fill_value == "auto":
178+
fill_value_out = auto_fill_value(data)
179+
else:
180+
fill_value_out = fill_value
181+
182+
if codecs == "auto":
183+
codecs_out = auto_codecs(data)
184+
else:
185+
codecs_out = codecs
186+
187+
if attributes == "auto":
188+
attributes_out = auto_attributes(data)
189+
else:
190+
attributes_out = attributes
191+
192+
if dimension_names == "auto":
193+
dimension_names_out = auto_dimension_names(data)
194+
else:
195+
dimension_names_out = dimension_names
196+
197+
return cls(
198+
shape=shape_out,
199+
data_type=data_type_out,
200+
chunk_grid=chunk_grid_out,
201+
chunk_key_encoding=chunk_key_encoding_out,
202+
fill_value=fill_value_out,
203+
codecs=codecs_out,
204+
attributes=attributes_out,
205+
dimension_names=dimension_names_out,
206+
)
207+
46208

47209
@dataclass(frozen=True)
48210
class GroupModel(GroupMetadata):
@@ -104,11 +266,20 @@ def to_flat(
104266
"""
105267
result = {}
106268
model_copy: ArrayModel | GroupModel
107-
node_dict = node.to_dict()
108269
if isinstance(node, ArrayModel):
109-
model_copy = ArrayModel(**node_dict)
270+
# we can remove this if we add a model_copy method
271+
model_copy = ArrayModel(
272+
shape=node.shape,
273+
data_type=node.data_type,
274+
chunk_grid=node.chunk_grid,
275+
chunk_key_encoding=node.chunk_key_encoding,
276+
fill_value=node.fill_value,
277+
codecs=node.codecs,
278+
attributes=node.attributes,
279+
dimension_names=node.dimension_names,
280+
)
110281
else:
111-
model_copy = GroupModel(**node_dict)
282+
model_copy = GroupModel(attributes=node.attributes, members=None)
112283
if node.members is not None:
113284
for name, value in node.members.items():
114285
result.update(to_flat(value, "/".join([root_path, name])))

src/zarr/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class ArrayV3Metadata(ArrayMetadata):
162162
chunk_key_encoding: ChunkKeyEncoding
163163
fill_value: Any
164164
codecs: CodecPipeline
165-
attributes: dict[str, Any] = field(default_factory=dict)
165+
attributes: dict[str, JSON] = field(default_factory=dict)
166166
dimension_names: tuple[str, ...] | None = None
167167
zarr_format: Literal[3] = field(default=3, init=False)
168168
node_type: Literal["array"] = field(default="array", init=False)

tests/v3/test_hierarchy.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3+
import numpy as np
34
import pytest
45

56
from zarr.array import Array
67
from zarr.chunk_grids import RegularChunkGrid
78
from zarr.chunk_key_encodings import DefaultChunkKeyEncoding
89
from zarr.group import GroupMetadata
9-
from zarr.hierarchy import ArrayModel, GroupModel
10+
from zarr.hierarchy import ArrayModel, GroupModel, from_flat, to_flat
1011
from zarr.metadata import ArrayV3Metadata
1112
from zarr.store.core import StorePath
1213
from zarr.store.memory import MemoryStore
@@ -36,7 +37,7 @@ def test_array_model_to_stored(memory_store: MemoryStore) -> None:
3637
attributes={"foo": 10},
3738
)
3839

39-
array = model.to_stored(memory_store)
40+
array = model.to_stored(store_path=StorePath(store=memory_store))
4041
assert array.metadata.to_dict() == model.to_dict()
4142

4243

@@ -50,7 +51,7 @@ def test_array_model_from_stored(memory_store: MemoryStore) -> None:
5051
attributes={"foo": 10},
5152
)
5253

53-
array = Array.from_dict(memory_store, array_meta.to_dict())
54+
array = Array.from_dict(StorePath(memory_store), array_meta.to_dict())
5455
array_model = ArrayModel.from_stored(array)
5556
assert array_model.to_dict() == array_meta.to_dict()
5657

@@ -105,3 +106,46 @@ def test_groupmodel_to_stored(
105106
assert model_rt.members == model.members
106107
else:
107108
assert model_rt.members == {}
109+
110+
111+
@pytest.mark.parametrize(
112+
("data, expected"),
113+
[
114+
(
115+
ArrayModel.from_array(np.arange(10)),
116+
{"": ArrayModel.from_array(np.arange(10))},
117+
),
118+
(
119+
GroupModel(
120+
attributes={"foo": 10},
121+
members={"a": ArrayModel.from_array(np.arange(5), attributes={"foo": 100})},
122+
),
123+
{
124+
"": GroupModel(attributes={"foo": 10}, members=None),
125+
"/a": ArrayModel.from_array(np.arange(5), attributes={"foo": 100}),
126+
},
127+
),
128+
(
129+
GroupModel(
130+
attributes={},
131+
members={
132+
"a": GroupModel(
133+
attributes={"foo": 10},
134+
members={"a": ArrayModel.from_array(np.arange(5), attributes={"foo": 100})},
135+
),
136+
"b": ArrayModel.from_array(np.arange(2), attributes={"foo": 3}),
137+
},
138+
),
139+
{
140+
"": GroupModel(attributes={}, members=None),
141+
"/a": GroupModel(attributes={"foo": 10}, members=None),
142+
"/a/a": ArrayModel.from_array(np.arange(5), attributes={"foo": 100}),
143+
"/b": ArrayModel.from_array(np.arange(2), attributes={"foo": 3}),
144+
},
145+
),
146+
],
147+
)
148+
def test_flatten_unflatten(data, expected) -> None:
149+
flattened = to_flat(data)
150+
assert flattened == expected
151+
assert from_flat(flattened) == data

0 commit comments

Comments
 (0)