Skip to content

Commit 5aa0c17

Browse files
committed
initial hierarchy API
1 parent b1f4c50 commit 5aa0c17

File tree

5 files changed

+246
-5
lines changed

5 files changed

+246
-5
lines changed

src/zarr/chunk_key_encodings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def parse_separator(data: JSON) -> SeparatorLiteral:
2626
@dataclass(frozen=True)
2727
class ChunkKeyEncoding(Metadata):
2828
name: str
29-
separator: SeparatorLiteral = "."
29+
separator: SeparatorLiteral = "/"
3030

3131
def __init__(self, *, separator: SeparatorLiteral) -> None:
3232
separator_parsed = parse_separator(separator)

src/zarr/group.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from collections.abc import AsyncGenerator, Iterable
3333
from typing import Any, Literal
3434

35+
from typing_extensions import Self
36+
3537
logger = logging.getLogger("zarr.group")
3638

3739

@@ -97,7 +99,7 @@ def __init__(self, attributes: dict[str, Any] | None = None, zarr_format: ZarrFo
9799
object.__setattr__(self, "zarr_format", zarr_format_parsed)
98100

99101
@classmethod
100-
def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
102+
def from_dict(cls, data: dict[str, Any]) -> Self:
101103
assert data.pop("node_type", None) in ("group", None)
102104
return cls(**data)
103105

src/zarr/hierarchy.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""
2+
Copyright © 2023 Howard Hughes Medical Institute
3+
4+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5+
6+
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7+
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8+
Neither the name of HHMI nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9+
10+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
from typing import Any
16+
17+
from typing_extensions import Self
18+
19+
from zarr.array import Array
20+
from zarr.group import Group, GroupMetadata
21+
from zarr.metadata import ArrayV3Metadata
22+
from zarr.store.core import StorePath
23+
24+
25+
class ArrayModel(ArrayV3Metadata):
26+
"""
27+
A model of a Zarr v3 array.
28+
"""
29+
30+
@classmethod
31+
def from_stored(cls: type[Self], node: Array):
32+
return cls.from_dict(node.metadata.to_dict())
33+
34+
def to_stored(self, store_path: StorePath) -> Array:
35+
return Array.from_dict(store_path=store_path, data=self.to_dict())
36+
37+
38+
class GroupModel(GroupMetadata):
39+
"""
40+
A model of a Zarr v3 group.
41+
"""
42+
43+
members: dict[str, GroupModel | ArrayModel] | None
44+
45+
@classmethod
46+
def from_dict(cls: type[Self], data: dict[str, Any]):
47+
return cls(**data)
48+
49+
@classmethod
50+
def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Self:
51+
"""
52+
Create a GroupModel from a Group. This function is recursive. The depth of recursion is
53+
controlled by the `depth` argument, which is either None (no depth limit) or a finite natural number
54+
specifying how deep into the hierarchy to parse.
55+
"""
56+
members: dict[str, GroupModel | ArrayModel]
57+
58+
if depth is None:
59+
new_depth = depth
60+
else:
61+
new_depth = depth - 1
62+
63+
if depth == 0:
64+
return cls(**node.metadata.to_dict(), members=None)
65+
66+
else:
67+
for name, member in node.members():
68+
if isinstance(member, Array):
69+
item_out = ArrayModel.from_stored(member)
70+
else:
71+
item_out = GroupModel.from_stored(member, depth=new_depth)
72+
73+
members[name] = item_out
74+
75+
return cls(**node.metadata.to_dict(), members=members)
76+
77+
def to_stored(self, store_path: StorePath, *, exists_ok: bool = False) -> Group:
78+
"""
79+
Serialize this GroupModel to storage.
80+
"""
81+
82+
result = Group.create(store_path, attributes=self.attributes, exists_ok=exists_ok)
83+
if self.members is not None:
84+
for name, member in self.members.items():
85+
substore = store_path / name
86+
member.to_stored(substore, exists_ok=exists_ok)
87+
return result
88+
89+
90+
def to_flat(
91+
node: ArrayModel | GroupModel, root_path: str = ""
92+
) -> dict[str, ArrayModel | GroupModel]:
93+
result = {}
94+
model_copy: ArrayModel | GroupModel
95+
node_dict = node.to_dict()
96+
if isinstance(node, ArrayModel):
97+
model_copy = ArrayModel(**node_dict)
98+
else:
99+
members = node_dict.pop("members")
100+
model_copy = GroupModel(node_dict)
101+
if members is not None:
102+
for name, value in node.members.items():
103+
result.update(to_flat(value, "/".join([root_path, name])))
104+
105+
result[root_path] = model_copy
106+
# sort by increasing key length
107+
result_sorted_keys = dict(sorted(result.items(), key=lambda v: len(v[0])))
108+
return result_sorted_keys
109+
110+
111+
def from_flat(data: dict[str, ArrayModel | GroupModel]) -> ArrayModel | GroupModel:
112+
# minimal check that the keys are valid
113+
invalid_keys = []
114+
for key in data.keys():
115+
if key.endswith("/"):
116+
invalid_keys.append(key)
117+
if len(invalid_keys) > 0:
118+
msg = f'Invalid keys {invalid_keys} found in data. Keys may not end with the "/"" character'
119+
raise ValueError(msg)
120+
121+
if tuple(data.keys()) == ("",) and isinstance(tuple(data.values())[0], ArrayModel):
122+
return tuple(data.values())[0]
123+
else:
124+
return from_flat_group(data)
125+
126+
127+
def from_flat_group(data: dict[str, ArrayModel | GroupModel]) -> GroupModel:
128+
root_name = ""
129+
sep = "/"
130+
# arrays that will be members of the returned GroupModel
131+
member_arrays: dict[str, ArrayModel] = {}
132+
# groups, and their members, that will be members of the returned GroupModel.
133+
# this dict is populated by recursively applying `from_flat_group` function.
134+
member_groups: dict[str, GroupModel] = {}
135+
# this dict collects the arrayspecs and groupspecs that belong to one of the members of the
136+
# groupspecs we are constructing. They will later be aggregated in a recursive step that
137+
# populates member_groups
138+
submember_by_parent_name: dict[str, dict[str, ArrayModel | GroupModel]] = {}
139+
# copy the input to ensure that mutations are contained inside this function
140+
data_copy = data.copy()
141+
# Get the root node
142+
try:
143+
# The root node is a GroupModel with the key ""
144+
root_node = data_copy.pop(root_name)
145+
if isinstance(root_node, ArrayModel):
146+
raise ValueError("Got an ArrayModel as the root node. This is invalid.")
147+
except KeyError:
148+
# If a root node was not found, create a default one
149+
root_node = GroupModel(attributes={}, members=None)
150+
151+
# partition the tree (sans root node) into 2 categories: (arrays, groups + their members).
152+
for key, value in data_copy.items():
153+
key_parts = key.split(sep)
154+
if key_parts[0] != root_name:
155+
raise ValueError(f'Invalid path: {key} does not start with "{root_name}{sep}".')
156+
157+
subparent_name = key_parts[1]
158+
if len(key_parts) == 2:
159+
# this is an array or group that belongs to the group we are ultimately returning
160+
if isinstance(value, ArrayModel):
161+
member_arrays[subparent_name] = value
162+
else:
163+
if subparent_name not in submember_by_parent_name:
164+
submember_by_parent_name[subparent_name] = {}
165+
submember_by_parent_name[subparent_name][root_name] = value
166+
else:
167+
# these are groups or arrays that belong to one of the member groups
168+
# not great that we repeat this conditional dict initialization
169+
if subparent_name not in submember_by_parent_name:
170+
submember_by_parent_name[subparent_name] = {}
171+
submember_by_parent_name[subparent_name][sep.join([root_name, *key_parts[2:]])] = value
172+
173+
# recurse
174+
for subparent_name, submemb in submember_by_parent_name.items():
175+
member_groups[subparent_name] = from_flat_group(submemb)
176+
177+
return GroupModel(members={**member_groups, **member_arrays}, attributes=root_node.attributes)

src/zarr/metadata.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from zarr.chunk_grids import ChunkGrid, RegularChunkGrid
1717
from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator
1818
from zarr.codecs._v2 import V2Compressor, V2Filters
19+
from zarr.codecs.bytes import BytesCodec
1920

2021
if TYPE_CHECKING:
2122
from typing import Literal
@@ -174,9 +175,9 @@ def __init__(
174175
chunk_grid: dict[str, JSON] | ChunkGrid,
175176
chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding,
176177
fill_value: Any,
177-
codecs: Iterable[Codec | JSON],
178-
attributes: None | dict[str, JSON],
179-
dimension_names: None | Iterable[str],
178+
codecs: Iterable[Codec | JSON] = (BytesCodec(),),
179+
attributes: None | dict[str, JSON] = None,
180+
dimension_names: None | Iterable[str] = None,
180181
) -> None:
181182
"""
182183
Because the class is a frozen dataclass, we set attributes using object.__setattr__

tests/v3/test_hierarchy.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
from zarr.array import Array
4+
from zarr.chunk_grids import RegularChunkGrid
5+
from zarr.chunk_key_encodings import DefaultChunkKeyEncoding
6+
from zarr.group import GroupMetadata
7+
from zarr.hierarchy import ArrayModel, GroupModel
8+
from zarr.metadata import ArrayV3Metadata
9+
from zarr.store.memory import MemoryStore
10+
11+
12+
def test_array_model_from_dict() -> None:
13+
array_meta = ArrayV3Metadata(
14+
shape=(10,),
15+
data_type="uint8",
16+
chunk_grid=RegularChunkGrid(chunk_shape=(10,)),
17+
chunk_key_encoding=DefaultChunkKeyEncoding(),
18+
fill_value=0,
19+
attributes={"foo": 10},
20+
)
21+
22+
model = ArrayModel.from_dict(array_meta.to_dict())
23+
assert model.to_dict() == array_meta.to_dict()
24+
25+
26+
def test_array_model_to_stored(memory_store: MemoryStore) -> None:
27+
model = ArrayModel(
28+
shape=(10,),
29+
data_type="uint8",
30+
chunk_grid=RegularChunkGrid(chunk_shape=(10,)),
31+
chunk_key_encoding=DefaultChunkKeyEncoding(),
32+
fill_value=0,
33+
attributes={"foo": 10},
34+
)
35+
36+
array = model.to_stored(memory_store)
37+
assert array.metadata.to_dict() == model.to_dict()
38+
39+
40+
def test_array_model_from_stored(memory_store: MemoryStore) -> None:
41+
array_meta = ArrayV3Metadata(
42+
shape=(10,),
43+
data_type="uint8",
44+
chunk_grid=RegularChunkGrid(chunk_shape=(10,)),
45+
chunk_key_encoding=DefaultChunkKeyEncoding(),
46+
fill_value=0,
47+
attributes={"foo": 10},
48+
)
49+
50+
array = Array.from_dict(memory_store, array_meta.to_dict())
51+
array_model = ArrayModel.from_stored(array)
52+
assert array_model.to_dict() == array_meta.to_dict()
53+
54+
55+
def test_groupmodel_from_dict() -> None:
56+
group_meta = GroupMetadata(attributes={"foo": "bar"})
57+
model = GroupModel.from_dict({**group_meta.to_dict(), "members": None})
58+
assert model.to_dict() == {**group_meta.to_dict(), "members": None}
59+
60+
61+
def test_groupmodel_to_stored(): ...

0 commit comments

Comments
 (0)