|
6 | 6 | import logging |
7 | 7 | import warnings |
8 | 8 | from collections import defaultdict |
| 9 | +from collections.abc import AsyncIterator |
9 | 10 | from dataclasses import asdict, dataclass, field, fields, replace |
10 | 11 | from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload |
11 | 12 |
|
@@ -1195,6 +1196,37 @@ async def require_array( |
1195 | 1196 |
|
1196 | 1197 | return ds |
1197 | 1198 |
|
| 1199 | + async def create_nodes( |
| 1200 | + self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] |
| 1201 | + ) -> tuple[tuple[str, AsyncGroup | AsyncArray]]: |
| 1202 | + """ |
| 1203 | + Create a set of arrays or groups rooted at this group. |
| 1204 | + """ |
| 1205 | + _nodes: ( |
| 1206 | + dict[str, GroupMetadata | ArrayV3Metadata] | dict[str, GroupMetadata | ArrayV2Metadata] |
| 1207 | + ) |
| 1208 | + match self.metadata.zarr_format: |
| 1209 | + case 2: |
| 1210 | + if not all( |
| 1211 | + isinstance(node, ArrayV2Metadata | GroupMetadata) for node in nodes.values() |
| 1212 | + ): |
| 1213 | + raise ValueError("Only v2 arrays and groups are supported") |
| 1214 | + _nodes = cast(dict[str, ArrayV2Metadata | GroupMetadata], nodes) |
| 1215 | + return await create_nodes_v2( |
| 1216 | + store=self.store_path.store, path=self.path, nodes=_nodes |
| 1217 | + ) |
| 1218 | + case 3: |
| 1219 | + if not all( |
| 1220 | + isinstance(node, ArrayV3Metadata | GroupMetadata) for node in nodes.values() |
| 1221 | + ): |
| 1222 | + raise ValueError("Only v3 arrays and groups are supported") |
| 1223 | + _nodes = cast(dict[str, ArrayV3Metadata | GroupMetadata], nodes) |
| 1224 | + return await create_nodes_v3( |
| 1225 | + store=self.store_path.store, path=self.path, nodes=_nodes |
| 1226 | + ) |
| 1227 | + case _: |
| 1228 | + raise ValueError(f"Unsupported zarr format: {self.metadata.zarr_format}") |
| 1229 | + |
1198 | 1230 | async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup: |
1199 | 1231 | """Update group attributes. |
1200 | 1232 |
|
@@ -2627,3 +2659,71 @@ def array( |
2627 | 2659 | ) |
2628 | 2660 | ) |
2629 | 2661 | ) |
| 2662 | + |
| 2663 | + |
| 2664 | +async def _save_metadata_return_node( |
| 2665 | + node: AsyncArray[Any] | AsyncGroup, |
| 2666 | +) -> AsyncArray[Any] | AsyncGroup: |
| 2667 | + if isinstance(node, AsyncArray): |
| 2668 | + await node._save_metadata(node.metadata, ensure_parents=False) |
| 2669 | + else: |
| 2670 | + await node._save_metadata(ensure_parents=False) |
| 2671 | + return node |
| 2672 | + |
| 2673 | + |
| 2674 | +async def create_nodes_v2( |
| 2675 | + *, store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata] |
| 2676 | +) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata]]]: ... |
| 2677 | + |
| 2678 | + |
| 2679 | +async def create_nodes( |
| 2680 | + *, store_path: StorePath, nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] |
| 2681 | +) -> AsyncIterator[AsyncGroup | AsyncArray[Any]]: |
| 2682 | + """ |
| 2683 | + Create a collection of arrays and groups concurrently and atomically. To ensure atomicity, |
| 2684 | + no attempt is made to ensure that intermediate groups are created. |
| 2685 | + """ |
| 2686 | + create_tasks = [] |
| 2687 | + for key, value in nodes.items(): |
| 2688 | + new_store_path = store_path / key |
| 2689 | + node: AsyncArray[Any] | AsyncGroup |
| 2690 | + match value: |
| 2691 | + case ArrayV3Metadata() | ArrayV2Metadata(): |
| 2692 | + node = AsyncArray(value, store_path=new_store_path) |
| 2693 | + case GroupMetadata(): |
| 2694 | + node = AsyncGroup(value, store_path=new_store_path) |
| 2695 | + case _: |
| 2696 | + raise ValueError(f"Unexpected metadata type {type(value)}") |
| 2697 | + create_tasks.append(_save_metadata_return_node(node)) |
| 2698 | + for coro in asyncio.as_completed(create_tasks): |
| 2699 | + yield await coro |
| 2700 | + |
| 2701 | + |
| 2702 | +T = TypeVar("T") |
| 2703 | + |
| 2704 | + |
| 2705 | +def _tuplize_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]: |
| 2706 | + """ |
| 2707 | + Given a dict of {string: T} pairs, where the keys are strings separated by some separator, |
| 2708 | + return the result of splitting each key with the separator. |
| 2709 | +
|
| 2710 | + Parameters |
| 2711 | + ---------- |
| 2712 | + data : dict[str, T] |
| 2713 | + A dict of {string:, T} pairs. |
| 2714 | +
|
| 2715 | + Returns |
| 2716 | + ------- |
| 2717 | + dict[tuple[str,...], T] |
| 2718 | + The same values, but the keys have been split and converted to tuples. |
| 2719 | +
|
| 2720 | + Examples |
| 2721 | + -------- |
| 2722 | + >>> _tuplize_tree({"a": 1}, separator='/') |
| 2723 | + {("a",): 1} |
| 2724 | +
|
| 2725 | + >>> _tuplize_tree({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/') |
| 2726 | + {("a", "b"): 1, ("a", "b", "c"): 2, ("c",): 3} |
| 2727 | + """ |
| 2728 | + |
| 2729 | + return {tuple(k.split(separator)): v for k, v in data.items()} |
0 commit comments