Skip to content

Commit 089feef

Browse files
committed
sketch out from_flat for groups
1 parent 04f7922 commit 089feef

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

src/zarr/core/group.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from functools import partial
1111
from itertools import accumulate
1212
from pathlib import PurePosixPath
13-
from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload
13+
from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload
1414

1515
import numpy as np
1616
import numpy.typing as npt
@@ -426,6 +426,27 @@ class AsyncGroup:
426426
metadata: GroupMetadata
427427
store_path: StorePath
428428

429+
# TODO: make this correct and work
430+
# TODO: ensure that this can be bound properly to subclass of AsyncGroup
431+
@classmethod
432+
async def from_flat(
433+
cls,
434+
store: StoreLike,
435+
*,
436+
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
437+
overwrite: bool = False) -> Self:
438+
439+
if overwrite:
440+
store_path = await make_store_path(store, mode='w')
441+
else:
442+
store_path = await make_store_path(store, mode='w-')
443+
semaphore = asyncio.Semaphore(config.get("async.concurrency"))
444+
445+
nodes_created = {x.name: x async for x in create_hierarchy(
446+
store_path=store_path, nodes=nodes, semaphore=semaphore
447+
)}
448+
return nodes_created['']
449+
429450
@classmethod
430451
async def from_store(
431452
cls,
@@ -1269,15 +1290,6 @@ async def require_array(
12691290

12701291
return ds
12711292

1272-
async def _create_nodes(
1273-
self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
1274-
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
1275-
"""
1276-
Create a set of arrays or groups rooted at this group.
1277-
"""
1278-
async for node in create_hierarchy(store_path=self.store_path, nodes=nodes):
1279-
yield node
1280-
12811293
async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
12821294
"""Update group attributes.
12831295
@@ -1731,6 +1743,17 @@ async def move(self, source: str, dest: str) -> None:
17311743
@dataclass(frozen=True)
17321744
class Group(SyncMixin):
17331745
_async_group: AsyncGroup
1746+
1747+
@classmethod
1748+
def from_flat(
1749+
cls,
1750+
store: StoreLike,
1751+
*,
1752+
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
1753+
overwrite: bool = False) -> Group:
1754+
nodes = sync(AsyncGroup.from_flat(store, nodes=nodes, overwrite=overwrite))
1755+
# return the root node of the hierarchy
1756+
return nodes['']
17341757

17351758
@classmethod
17361759
def from_store(

0 commit comments

Comments
 (0)