Skip to content

Commit b099fba

Browse files
committed
working _from_flat
1 parent 2fb9083 commit b099fba

File tree

2 files changed

+143
-95
lines changed

2 files changed

+143
-95
lines changed

src/zarr/core/group.py

Lines changed: 117 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from typing import (
1414
TYPE_CHECKING,
1515
Literal,
16-
Self,
1716
TypeVar,
1817
assert_never,
1918
cast,
@@ -437,35 +436,6 @@ class AsyncGroup:
437436

438437
# TODO: make this correct and work
439438
# TODO: ensure that this can be bound properly to subclass of AsyncGroup
440-
@classmethod
441-
async def from_flat(
442-
cls,
443-
store: StoreLike,
444-
*,
445-
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
446-
overwrite: bool = False,
447-
) -> Self:
448-
if not _is_rooted(nodes):
449-
msg = (
450-
"The input does not specify a root node. ",
451-
"This function can only create hierarchies that contain a root node, which is ",
452-
"defined as a group that is ancestral to all the other arrays and ",
453-
"groups in the hierarchy.",
454-
)
455-
raise ValueError(msg)
456-
457-
if overwrite:
458-
store_path = await make_store_path(store, mode="w")
459-
else:
460-
store_path = await make_store_path(store, mode="w-")
461-
462-
semaphore = asyncio.Semaphore(config.get("async.concurrency"))
463-
464-
nodes_created = {
465-
x.name: x
466-
async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore)
467-
}
468-
# TODO: make this work
469439

470440
@classmethod
471441
async def from_store(
@@ -1764,18 +1734,6 @@ async def move(self, source: str, dest: str) -> None:
17641734
class Group(SyncMixin):
17651735
_async_group: AsyncGroup
17661736

1767-
@classmethod
1768-
def from_flat(
1769-
cls,
1770-
store: StoreLike,
1771-
*,
1772-
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
1773-
overwrite: bool = False,
1774-
) -> Group:
1775-
nodes = sync(AsyncGroup.from_flat(store, nodes=nodes, overwrite=overwrite))
1776-
# return the root node of the hierarchy
1777-
return nodes[""]
1778-
17791737
@classmethod
17801738
def from_store(
17811739
cls,
@@ -2889,23 +2847,6 @@ def array(
28892847
)
28902848

28912849

2892-
async def _save_metadata(
2893-
node: AsyncArray[Any] | AsyncGroup,
2894-
overwrite: bool,
2895-
) -> AsyncArray[Any] | AsyncGroup:
2896-
"""
2897-
Save the metadata for an array or group, and return the array or group
2898-
"""
2899-
match node:
2900-
case AsyncArray():
2901-
await node._save_metadata(node.metadata, ensure_parents=False)
2902-
case AsyncGroup():
2903-
await node._save_metadata(ensure_parents=False)
2904-
case _:
2905-
raise ValueError(f"Unexpected node type {type(node)}")
2906-
return node
2907-
2908-
29092850
async def create_hierarchy(
29102851
*,
29112852
store_path: StorePath,
@@ -2962,22 +2903,17 @@ async def create_nodes(
29622903
ctx = semaphore
29632904

29642905
create_tasks: list[Coroutine[None, None, str]] = []
2965-
29662906
for key, value in nodes.items():
2967-
create_tasks.extend(
2968-
_prepare_save_metadata(store_path.store, f"{store_path.path}/{key}", value)
2969-
)
2970-
if store_path.path == "":
2971-
root = "/"
2972-
else:
2973-
root = store_path.path
2907+
write_key = str(PurePosixPath(store_path.path) / key)
2908+
create_tasks.extend(_persist_metadata(store_path.store, write_key, value))
2909+
29742910
created_keys = []
29752911
async with ctx:
29762912
for coro in asyncio.as_completed(create_tasks):
29772913
created_key = await coro
2978-
relative_path = PurePosixPath(created_key).relative_to(root)
2914+
relative_path = PurePosixPath(created_key).relative_to(store_path.path)
29792915
created_keys.append(str(relative_path))
2980-
# convert /foo/bar/baz/.zattrs to bar/baz
2916+
# convert foo/bar/baz/.zattrs to bar/baz
29812917
node_name = str(relative_path.parent)
29822918
meta_out = nodes[node_name]
29832919

@@ -3009,13 +2945,17 @@ async def create_nodes(
30092945
T = TypeVar("T")
30102946

30112947

3012-
def _is_rooted(data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]) -> bool:
2948+
def _get_roots(
2949+
data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
2950+
) -> tuple[str, ...]:
30132951
"""
3014-
Check if the data describes a hierarchy that's rooted, which means there is a single node with
3015-
the least number of components in its key
2952+
Return the keys of the root(s) of the hierarchy
30162953
"""
3017-
# a dict
3018-
return False
2954+
keys_split = sorted((key.split("/") for key in data), key=len)
2955+
groups: defaultdict[int, list[str]] = defaultdict(list)
2956+
for key_split in keys_split:
2957+
groups[len(key_split)].append("/".join(key_split))
2958+
return tuple(groups[min(groups.keys())])
30192959

30202960

30212961
def _parse_hierarchy_dict(
@@ -3066,9 +3006,9 @@ def _parse_hierarchy_dict(
30663006
# Iterate over the intermediate path components
30673007
*subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}")
30683008
for subpath in subpaths:
3069-
# If a component is not already in the output dict, add it
3009+
# If a component is not already in the output dict, add an implicit group marker
30703010
if subpath not in out:
3071-
out[subpath] = GroupMetadata(zarr_format=v.zarr_format)
3011+
out[subpath] = _ImplicitGroupMetadata(zarr_format=v.zarr_format)
30723012
else:
30733013
if not isinstance(out[subpath], GroupMetadata):
30743014
msg = (
@@ -3268,22 +3208,115 @@ def _build_node_v2(
32683208
raise ValueError(f"Unexpected metadata type: {type(metadata)}")
32693209

32703210

3271-
async def _set_return_key(store: Store, key: str, value: Buffer) -> str:
3211+
async def _set_return_key(*, store: Store, key: str, value: Buffer, replace: bool) -> str:
32723212
"""
3273-
Store.set, but the key and the value are returned.
3274-
Useful when saving metadata via asyncio.as_completed, because
3275-
we need to know which key was saved.
3213+
Either write a value to storage at the given key, or ensure that there is already a value in
3214+
storage at the given key. The key is returned in either case.
3215+
Useful when saving values via routines that return results in execution order,
3216+
like asyncio.as_completed, because in this case we need to know which key was saved in order
3217+
to yield the right object to the caller.
3218+
3219+
Parameters
3220+
----------
3221+
store : Store
3222+
The store to save the value to.
3223+
key : str
3224+
The key to save the value to.
3225+
value : Buffer
3226+
The value to save.
3227+
replace : bool
3228+
If True, then the value will be written even if a value associated with the key
3229+
already exists in storage. If False, an existing value will not be overwritten.
32763230
"""
3277-
await store.set(key, value)
3231+
if replace:
3232+
await store.set(key, value)
3233+
else:
3234+
await store.set_if_not_exists(key, value)
32783235
return key
32793236

32803237

3281-
def _prepare_save_metadata(
3238+
def _persist_metadata(
32823239
store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata
32833240
) -> tuple[Coroutine[None, None, str], ...]:
32843241
"""
3285-
Prepare to save a metadata document to storage. Returns a tuple of coroutines that must be awaited.
3242+
Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited.
3243+
If ``metadata`` is an instance of ``_ImplicitGroupMetadata``, then _set_return_key will be invoked with
3244+
``replace=False``, which defers to a pre-existing metadata document in storage if one exists. Otherwise, existing values will be overwritten.
32863245
"""
32873246

32883247
to_save = metadata.to_buffer_dict(default_buffer_prototype())
3289-
return tuple(_set_return_key(store, f"{path}/{key}", value) for key, value in to_save.items())
3248+
if isinstance(metadata, _ImplicitGroupMetadata):
3249+
replace = False
3250+
else:
3251+
replace = True
3252+
# TODO: should this function be a generator that yields values instead of eagerly returning a tuple?
3253+
return tuple(
3254+
_set_return_key(store=store, key=f"{path}/{key}", value=value, replace=replace)
3255+
for key, value in to_save.items()
3256+
)
3257+
3258+
3259+
class _ImplicitGroupMetadata(GroupMetadata):
3260+
"""
3261+
This class represents the metadata document of a group that should be created at some
3262+
location in storage if and only if there is not already a group at that location.
3263+
3264+
This class is used to fill group-shaped "holes" in a dict specification of a Zarr hierarchy.
3265+
3266+
When attempting to write this class to disk, the writer should first check if a Zarr group
3267+
already exists at the desired location. If such a group does exist, the writer should do nothing.
3268+
If not, the writer should write this metadata document to storage.
3269+
3270+
"""
3271+
3272+
def __init__(
3273+
self,
3274+
attributes: dict[str, Any] | None = None,
3275+
zarr_format: ZarrFormat = 3,
3276+
consolidated_metadata: ConsolidatedMetadata | None = None,
3277+
) -> None:
3278+
if attributes is not None:
3279+
raise ValueError("attributes must be None for implicit groups")
3280+
3281+
if consolidated_metadata is not None:
3282+
raise ValueError("consolidated_metadata must be None for implicit groups")
3283+
3284+
super().__init__(attributes, zarr_format, consolidated_metadata)
3285+
3286+
3287+
async def _from_flat(
3288+
store: StoreLike,
3289+
*,
3290+
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
3291+
overwrite: bool = False,
3292+
) -> AsyncGroup:
3293+
"""
3294+
Create an ``AsyncGroup`` from a store + a dict of nodes.
3295+
"""
3296+
roots = _get_roots(nodes)
3297+
if len(roots) != 1:
3298+
msg = (
3299+
"The input does not specify a root node. "
3300+
"This function can only create hierarchies that contain a root node, which is "
3301+
"defined as a group that is ancestral to all the other arrays and "
3302+
"groups in the hierarchy."
3303+
)
3304+
raise ValueError(msg)
3305+
else:
3306+
root = roots[0]
3307+
3308+
if overwrite:
3309+
store_path = await make_store_path(store, mode="w")
3310+
else:
3311+
store_path = await make_store_path(store, mode="w-")
3312+
3313+
semaphore = asyncio.Semaphore(config.get("async.concurrency"))
3314+
3315+
nodes_created = {
3316+
x.path: x
3317+
async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore)
3318+
}
3319+
root_group = nodes_created[root]
3320+
if not isinstance(root_group, AsyncGroup):
3321+
raise TypeError("Invalid root node returned from create_hierarchy.")
3322+
return root_group

tests/test_group.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
from zarr.abc.store import Store
2222
from zarr.core._info import GroupInfo
2323
from zarr.core.buffer import default_buffer_prototype
24-
from zarr.core.group import ConsolidatedMetadata, GroupMetadata, create_hierarchy, create_nodes
24+
from zarr.core.group import (
25+
ConsolidatedMetadata,
26+
GroupMetadata,
27+
_from_flat,
28+
create_hierarchy,
29+
create_nodes,
30+
)
2531
from zarr.core.sync import _collect_aiterator, sync
2632
from zarr.errors import ContainsArrayError, ContainsGroupError
2733
from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore
@@ -1578,20 +1584,29 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store):
15781584

15791585

15801586
@pytest.mark.parametrize("store", ["memory"], indirect=True)
1581-
async def test_group_from_flat(store: Store, zarr_format):
1587+
@pytest.mark.parametrize("zarr_format", [2, 3])
1588+
@pytest.mark.parametrize("root_key", ["", "a", "a/b"])
1589+
async def test_group_from_flat(store: Store, zarr_format, root_key: str):
15821590
"""
15831591
Test that the _from_flat function creates a zarr group in one shot.
15841592
"""
1585-
hierarchy_spec = {
1586-
"a": GroupMetadata(zarr_format=zarr_format),
1587-
"a/b": GroupMetadata(zarr_format=zarr_format),
1588-
"a/b/c": GroupMetadata(zarr_format=zarr_format),
1593+
root_key = "a"
1594+
root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})}
1595+
members_expected_meta = {
1596+
f"{root_key}/b": GroupMetadata(
1597+
zarr_format=zarr_format, attributes={"path": f"{root_key}/b"}
1598+
),
1599+
f"{root_key}/b/c": GroupMetadata(
1600+
zarr_format=zarr_format, attributes={"path": f"{root_key}/b/c"}
1601+
),
15891602
}
1590-
g = await AsyncGroup.from_flat(store, nodes=hierarchy_spec)
1591-
assert g.members() == [
1592-
("b", GroupMetadata(zarr_format=zarr_format)),
1593-
("b/c", GroupMetadata(zarr_format=zarr_format)),
1594-
]
1603+
g = await _from_flat(store, nodes=root_meta | members_expected_meta)
1604+
members = await _collect_aiterator(g.members(max_depth=None))
1605+
members_observed_meta = {k: v.metadata for k, v in members}
1606+
members_expected_meta_relative = {
1607+
str(PurePosixPath(k).relative_to(root_key)): v for k, v in members_expected_meta.items()
1608+
}
1609+
assert members_observed_meta == members_expected_meta_relative
15951610

15961611

15971612
@pytest.mark.parametrize("store", ["memory"], indirect=True)

0 commit comments

Comments
 (0)