|
| 1 | +""" |
| 2 | +Copyright © 2023 Howard Hughes Medical Institute |
| 3 | +
|
| 4 | +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: |
| 5 | +
|
| 6 | + Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. |
| 7 | + Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. |
| 8 | + Neither the name of HHMI nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. |
| 9 | +
|
| 10 | +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 11 | +""" |
| 12 | + |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +from typing import Any |
| 16 | + |
| 17 | +from typing_extensions import Self |
| 18 | + |
| 19 | +from zarr.array import Array |
| 20 | +from zarr.group import Group, GroupMetadata |
| 21 | +from zarr.metadata import ArrayV3Metadata |
| 22 | +from zarr.store.core import StorePath |
| 23 | + |
| 24 | + |
class ArrayModel(ArrayV3Metadata):
    """
    A model of a Zarr v3 array.

    The model is the array's metadata; it can be built from a stored ``Array`` and
    turned back into an ``Array`` bound to a ``StorePath``.
    """

    @classmethod
    def from_stored(cls: type[Self], node: Array):
        """
        Build a model from the metadata of `node`.

        The metadata of `node` is converted to a dict and handed to `cls.from_dict`.
        """
        metadata_dict = node.metadata.to_dict()
        return cls.from_dict(metadata_dict)

    def to_stored(self, store_path: StorePath) -> Array:
        """
        Create an ``Array`` at `store_path` from the dict form of this model.
        """
        serialized = self.to_dict()
        return Array.from_dict(store_path=store_path, data=serialized)
| 36 | + |
| 37 | + |
class GroupModel(GroupMetadata):
    """
    A model of a Zarr v3 group.

    On top of the fields inherited from ``GroupMetadata``, a ``GroupModel`` has a ``members``
    mapping describing the arrays and groups contained in the group.
    """

    # Child nodes keyed by name. `from_stored` sets this to None when the hierarchy was
    # not parsed below this group (depth == 0).
    members: dict[str, GroupModel | ArrayModel] | None

    @classmethod
    def from_dict(cls: type[Self], data: dict[str, Any]) -> Self:
        """
        Create a model by passing the entries of `data` to the constructor as keyword arguments.
        """
        return cls(**data)

    @classmethod
    def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Self:
        """
        Create a GroupModel from a Group. This function is recursive. The depth of recursion is
        controlled by the `depth` argument, which is either None (no depth limit) or a finite
        natural number specifying how deep into the hierarchy to parse.

        Parameters
        ----------
        node : Group
            The stored group to model.
        depth : int | None
            How many levels of members to parse. With `depth=0` the members of `node` are not
            inspected and the returned model has `members=None`.

        Returns
        -------
        A model with the metadata of `node` and, unless `depth` is 0, a `members` dict holding
        an ArrayModel for each array member and a GroupModel for each group member.
        """
        if depth == 0:
            return cls(**node.metadata.to_dict(), members=None)

        # None means "no limit" and is propagated unchanged; otherwise one level is consumed.
        new_depth = None if depth is None else depth - 1

        # This dict must exist before the loop: a bare annotation does not create the variable.
        members: dict[str, GroupModel | ArrayModel] = {}
        for name, member in node.members():
            if isinstance(member, Array):
                members[name] = ArrayModel.from_stored(member)
            else:
                members[name] = GroupModel.from_stored(member, depth=new_depth)

        return cls(**node.metadata.to_dict(), members=members)

    def to_stored(self, store_path: StorePath, *, exists_ok: bool = False) -> Group:
        """
        Serialize this GroupModel to storage.

        A group is created at `store_path` with the attributes of this model, then every member
        is stored at `store_path / name`. Nothing below the group is written when `members` is
        None.

        Parameters
        ----------
        store_path : StorePath
            The location at which the group is created.
        exists_ok : bool
            Forwarded to `Group.create` for this group and to `to_stored` of member groups.

        Returns
        -------
        The Group created at `store_path`.
        """
        result = Group.create(store_path, attributes=self.attributes, exists_ok=exists_ok)
        if self.members is not None:
            for name, member in self.members.items():
                substore = store_path / name
                if isinstance(member, ArrayModel):
                    # ArrayModel.to_stored has no `exists_ok` parameter; passing it would
                    # raise a TypeError.
                    member.to_stored(substore)
                else:
                    member.to_stored(substore, exists_ok=exists_ok)
        return result
| 88 | + |
| 89 | + |
def to_flat(
    node: ArrayModel | GroupModel, root_path: str = ""
) -> dict[str, ArrayModel | GroupModel]:
    """
    Flatten a hierarchy of models into a dict that maps paths to models.

    The node itself is stored under `root_path`; each member is stored under
    `root_path + "/" + name`, recursively. Groups in the result are copies whose `members`
    is None, because the hierarchy is expressed by the keys instead.

    Parameters
    ----------
    node : ArrayModel | GroupModel
        The root of the hierarchy to flatten.
    root_path : str
        The key under which `node` is stored. Defaults to the empty string.

    Returns
    -------
    A dict of path -> model, ordered by increasing key length.
    """
    result: dict[str, ArrayModel | GroupModel] = {}
    model_copy: ArrayModel | GroupModel
    node_dict = node.to_dict()
    if isinstance(node, ArrayModel):
        model_copy = ArrayModel(**node_dict)
    else:
        # The flat representation of a group does not carry its members.
        node_dict.pop("members", None)
        # The dict must be unpacked into keyword arguments, exactly as for ArrayModel above.
        model_copy = GroupModel(**node_dict, members=None)
        if node.members is not None:
            for name, value in node.members.items():
                result.update(to_flat(value, "/".join([root_path, name])))

    result[root_path] = model_copy
    # sort by increasing key length
    return dict(sorted(result.items(), key=lambda item: len(item[0])))
| 109 | + |
| 110 | + |
def from_flat(data: dict[str, ArrayModel | GroupModel]) -> ArrayModel | GroupModel:
    """
    Build a model hierarchy from a flat dict of path -> model.

    Parameters
    ----------
    data : dict[str, ArrayModel | GroupModel]
        Models keyed by path. Keys may not end with "/".

    Returns
    -------
    The single ArrayModel when `data` consists of exactly one entry keyed by "" that is an
    ArrayModel; otherwise the GroupModel produced by `from_flat_group`.

    Raises
    ------
    ValueError
        If any key ends with "/".
    """
    # minimal check that the keys are valid
    invalid_keys = [key for key in data if key.endswith("/")]
    if invalid_keys:
        msg = f'Invalid keys {invalid_keys} found in data. Keys may not end with the "/" character'
        raise ValueError(msg)

    # A lone array at the root is the only case that is not a group hierarchy.
    if tuple(data.keys()) == ("",) and isinstance(data[""], ArrayModel):
        return data[""]
    return from_flat_group(data)
| 125 | + |
| 126 | + |
def from_flat_group(data: dict[str, ArrayModel | GroupModel]) -> GroupModel:
    """
    Assemble a GroupModel from a flat dict of path -> model.

    The entry keyed by "" is the root group; when it is absent, a default group with empty
    attributes is used. Every other key must start with "/". Arrays directly below the root
    become members as-is. Everything else is bucketed by the name of its first path component,
    re-keyed relative to that component, and turned into a member group by recursion.

    Raises
    ------
    ValueError
        If the root entry is an ArrayModel, or if a key does not start with "/".
    """
    root_key = ""
    delimiter = "/"

    # Work on a shallow copy so that removing the root does not touch the caller's dict.
    remaining = data.copy()

    try:
        root = remaining.pop(root_key)
    except KeyError:
        # No explicit root: fall back to a default group.
        root = GroupModel(attributes={}, members=None)
    else:
        if isinstance(root, ArrayModel):
            raise ValueError("Got an ArrayModel as the root node. This is invalid.")

    # Arrays that are direct members of the returned group.
    direct_arrays: dict[str, ArrayModel] = {}
    # For each direct member group: the flat dict (relative to that group) of the group itself
    # and everything below it.
    flat_children: dict[str, dict[str, ArrayModel | GroupModel]] = {}

    for path, model in remaining.items():
        head, *tail = path.split(delimiter)
        if head != root_key:
            raise ValueError(f'Invalid path: {path} does not start with "{root_key}{delimiter}".')

        child_name, *descendants = tail
        if not descendants and isinstance(model, ArrayModel):
            direct_arrays[child_name] = model
            continue

        # A direct member group ends up under "" (its own root); deeper nodes keep the rest of
        # their path, prefixed with the delimiter.
        relative_path = delimiter.join([root_key, *descendants])
        flat_children.setdefault(child_name, {})[relative_path] = model

    # Recurse into each bucket to build the member groups.
    nested_groups: dict[str, GroupModel] = {
        child_name: from_flat_group(flat) for child_name, flat in flat_children.items()
    }

    return GroupModel(members={**nested_groups, **direct_arrays}, attributes=root.attributes)
0 commit comments