Skip to content

Commit 8faf994

Browse files
committed
sketch out batch creation routine
1 parent 1a75957 commit 8faf994

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

src/zarr/core/group.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import warnings
88
from collections import defaultdict
9+
from collections.abc import AsyncIterator
910
from dataclasses import asdict, dataclass, field, fields, replace
1011
from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload
1112

@@ -1195,6 +1196,37 @@ async def require_array(
11951196

11961197
return ds
11971198

1199+
async def create_nodes(
1200+
self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
1201+
) -> tuple[tuple[str, AsyncGroup | AsyncArray]]:
1202+
"""
1203+
Create a set of arrays or groups rooted at this group.
1204+
"""
1205+
_nodes: (
1206+
dict[str, GroupMetadata | ArrayV3Metadata] | dict[str, GroupMetadata | ArrayV2Metadata]
1207+
)
1208+
match self.metadata.zarr_format:
1209+
case 2:
1210+
if not all(
1211+
isinstance(node, ArrayV2Metadata | GroupMetadata) for node in nodes.values()
1212+
):
1213+
raise ValueError("Only v2 arrays and groups are supported")
1214+
_nodes = cast(dict[str, ArrayV2Metadata | GroupMetadata], nodes)
1215+
return await create_nodes_v2(
1216+
store=self.store_path.store, path=self.path, nodes=_nodes
1217+
)
1218+
case 3:
1219+
if not all(
1220+
isinstance(node, ArrayV3Metadata | GroupMetadata) for node in nodes.values()
1221+
):
1222+
raise ValueError("Only v3 arrays and groups are supported")
1223+
_nodes = cast(dict[str, ArrayV3Metadata | GroupMetadata], nodes)
1224+
return await create_nodes_v3(
1225+
store=self.store_path.store, path=self.path, nodes=_nodes
1226+
)
1227+
case _:
1228+
raise ValueError(f"Unsupported zarr format: {self.metadata.zarr_format}")
1229+
11981230
async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
11991231
"""Update group attributes.
12001232
@@ -2627,3 +2659,71 @@ def array(
26272659
)
26282660
)
26292661
)
2662+
2663+
2664+
async def _save_metadata_return_node(
2665+
node: AsyncArray[Any] | AsyncGroup,
2666+
) -> AsyncArray[Any] | AsyncGroup:
2667+
if isinstance(node, AsyncArray):
2668+
await node._save_metadata(node.metadata, ensure_parents=False)
2669+
else:
2670+
await node._save_metadata(ensure_parents=False)
2671+
return node
2672+
2673+
2674+
async def create_nodes_v2(
2675+
*, store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata]
2676+
) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata]]]: ...
2677+
2678+
2679+
async def create_nodes(
2680+
*, store_path: StorePath, nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata]
2681+
) -> AsyncIterator[AsyncGroup | AsyncArray[Any]]:
2682+
"""
2683+
Create a collection of arrays and groups concurrently and atomically. To ensure atomicity,
2684+
no attempt is made to ensure that intermediate groups are created.
2685+
"""
2686+
create_tasks = []
2687+
for key, value in nodes.items():
2688+
new_store_path = store_path / key
2689+
node: AsyncArray[Any] | AsyncGroup
2690+
match value:
2691+
case ArrayV3Metadata() | ArrayV2Metadata():
2692+
node = AsyncArray(value, store_path=new_store_path)
2693+
case GroupMetadata():
2694+
node = AsyncGroup(value, store_path=new_store_path)
2695+
case _:
2696+
raise ValueError(f"Unexpected metadata type {type(value)}")
2697+
create_tasks.append(_save_metadata_return_node(node))
2698+
for coro in asyncio.as_completed(create_tasks):
2699+
yield await coro
2700+
2701+
2702+
T = TypeVar("T")
2703+
2704+
2705+
def _tuplize_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]:
2706+
"""
2707+
Given a dict of {string: T} pairs, where the keys are strings separated by some separator,
2708+
return the result of splitting each key with the separator.
2709+
2710+
Parameters
2711+
----------
2712+
data : dict[str, T]
2713+
A dict of {string:, T} pairs.
2714+
2715+
Returns
2716+
-------
2717+
dict[tuple[str,...], T]
2718+
The same values, but the keys have been split and converted to tuples.
2719+
2720+
Examples
2721+
--------
2722+
>>> _tuplize_tree({"a": 1}, separator='/')
2723+
{("a",): 1}
2724+
2725+
>>> _tuplize_tree({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/')
2726+
{("a", "b"): 1, ("a", "b", "c"): 2, ("c",): 3}
2727+
"""
2728+
2729+
return {tuple(k.split(separator)): v for k, v in data.items()}

0 commit comments

Comments
 (0)