|
13 | 13 | from typing import ( |
14 | 14 | TYPE_CHECKING, |
15 | 15 | Literal, |
16 | | - Self, |
17 | 16 | TypeVar, |
18 | 17 | assert_never, |
19 | 18 | cast, |
@@ -437,35 +436,6 @@ class AsyncGroup: |
437 | 436 |
|
438 | 437 | # TODO: make this correct and work |
439 | 438 | # TODO: ensure that this can be bound properly to subclass of AsyncGroup |
440 | | - @classmethod |
441 | | - async def from_flat( |
442 | | - cls, |
443 | | - store: StoreLike, |
444 | | - *, |
445 | | - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], |
446 | | - overwrite: bool = False, |
447 | | - ) -> Self: |
448 | | - if not _is_rooted(nodes): |
449 | | - msg = ( |
450 | | - "The input does not specify a root node. ", |
451 | | - "This function can only create hierarchies that contain a root node, which is ", |
452 | | - "defined as a group that is ancestral to all the other arrays and ", |
453 | | - "groups in the hierarchy.", |
454 | | - ) |
455 | | - raise ValueError(msg) |
456 | | - |
457 | | - if overwrite: |
458 | | - store_path = await make_store_path(store, mode="w") |
459 | | - else: |
460 | | - store_path = await make_store_path(store, mode="w-") |
461 | | - |
462 | | - semaphore = asyncio.Semaphore(config.get("async.concurrency")) |
463 | | - |
464 | | - nodes_created = { |
465 | | - x.name: x |
466 | | - async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore) |
467 | | - } |
468 | | - # TODO: make this work |
469 | 439 |
|
470 | 440 | @classmethod |
471 | 441 | async def from_store( |
@@ -1764,18 +1734,6 @@ async def move(self, source: str, dest: str) -> None: |
1764 | 1734 | class Group(SyncMixin): |
1765 | 1735 | _async_group: AsyncGroup |
1766 | 1736 |
|
1767 | | - @classmethod |
1768 | | - def from_flat( |
1769 | | - cls, |
1770 | | - store: StoreLike, |
1771 | | - *, |
1772 | | - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], |
1773 | | - overwrite: bool = False, |
1774 | | - ) -> Group: |
1775 | | - nodes = sync(AsyncGroup.from_flat(store, nodes=nodes, overwrite=overwrite)) |
1776 | | - # return the root node of the hierarchy |
1777 | | - return nodes[""] |
1778 | | - |
1779 | 1737 | @classmethod |
1780 | 1738 | def from_store( |
1781 | 1739 | cls, |
@@ -2889,23 +2847,6 @@ def array( |
2889 | 2847 | ) |
2890 | 2848 |
|
2891 | 2849 |
|
2892 | | -async def _save_metadata( |
2893 | | - node: AsyncArray[Any] | AsyncGroup, |
2894 | | - overwrite: bool, |
2895 | | -) -> AsyncArray[Any] | AsyncGroup: |
2896 | | - """ |
2897 | | - Save the metadata for an array or group, and return the array or group |
2898 | | - """ |
2899 | | - match node: |
2900 | | - case AsyncArray(): |
2901 | | - await node._save_metadata(node.metadata, ensure_parents=False) |
2902 | | - case AsyncGroup(): |
2903 | | - await node._save_metadata(ensure_parents=False) |
2904 | | - case _: |
2905 | | - raise ValueError(f"Unexpected node type {type(node)}") |
2906 | | - return node |
2907 | | - |
2908 | | - |
2909 | 2850 | async def create_hierarchy( |
2910 | 2851 | *, |
2911 | 2852 | store_path: StorePath, |
@@ -2962,22 +2903,17 @@ async def create_nodes( |
2962 | 2903 | ctx = semaphore |
2963 | 2904 |
|
2964 | 2905 | create_tasks: list[Coroutine[None, None, str]] = [] |
2965 | | - |
2966 | 2906 | for key, value in nodes.items(): |
2967 | | - create_tasks.extend( |
2968 | | - _prepare_save_metadata(store_path.store, f"{store_path.path}/{key}", value) |
2969 | | - ) |
2970 | | - if store_path.path == "": |
2971 | | - root = "/" |
2972 | | - else: |
2973 | | - root = store_path.path |
| 2907 | + write_key = str(PurePosixPath(store_path.path) / key) |
| 2908 | + create_tasks.extend(_persist_metadata(store_path.store, write_key, value)) |
| 2909 | + |
2974 | 2910 | created_keys = [] |
2975 | 2911 | async with ctx: |
2976 | 2912 | for coro in asyncio.as_completed(create_tasks): |
2977 | 2913 | created_key = await coro |
2978 | | - relative_path = PurePosixPath(created_key).relative_to(root) |
| 2914 | + relative_path = PurePosixPath(created_key).relative_to(store_path.path) |
2979 | 2915 | created_keys.append(str(relative_path)) |
2980 | | - # convert /foo/bar/baz/.zattrs to bar/baz |
| 2916 | + # drop the metadata document name: foo/bar/baz/.zattrs -> foo/bar/baz |
2981 | 2917 | node_name = str(relative_path.parent) |
2982 | 2918 | meta_out = nodes[node_name] |
2983 | 2919 |
|
@@ -3009,13 +2945,17 @@ async def create_nodes( |
3009 | 2945 | T = TypeVar("T") |
3010 | 2946 |
|
3011 | 2947 |
|
3012 | | -def _is_rooted(data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]) -> bool: |
| 2948 | +def _get_roots( |
| 2949 | + data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], |
| 2950 | +) -> tuple[str, ...]: |
3013 | 2951 | """ |
3014 | | - Check if the data describes a hierarchy that's rooted, which means there is a single node with |
3015 | | - the least number of components in its key |
| 2952 | + Return the keys of the root(s) of the hierarchy |
3016 | 2953 | """ |
3017 | | - # a dict |
3018 | | - return False |
| 2954 | + keys_split = sorted((key.split("/") for key in data), key=len) |
| 2955 | + groups: defaultdict[int, list[str]] = defaultdict(list) |
| 2956 | + for key_split in keys_split: |
| 2957 | + groups[len(key_split)].append("/".join(key_split)) |
| 2958 | + return tuple(groups[min(groups.keys())]) |
3019 | 2959 |
|
3020 | 2960 |
|
3021 | 2961 | def _parse_hierarchy_dict( |
@@ -3066,9 +3006,9 @@ def _parse_hierarchy_dict( |
3066 | 3006 | # Iterate over the intermediate path components |
3067 | 3007 | *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}") |
3068 | 3008 | for subpath in subpaths: |
3069 | | - # If a component is not already in the output dict, add it |
| 3009 | + # If a component is not already in the output dict, add an implicit group marker |
3070 | 3010 | if subpath not in out: |
3071 | | - out[subpath] = GroupMetadata(zarr_format=v.zarr_format) |
| 3011 | + out[subpath] = _ImplicitGroupMetadata(zarr_format=v.zarr_format) |
3072 | 3012 | else: |
3073 | 3013 | if not isinstance(out[subpath], GroupMetadata): |
3074 | 3014 | msg = ( |
@@ -3268,22 +3208,115 @@ def _build_node_v2( |
3268 | 3208 | raise ValueError(f"Unexpected metadata type: {type(metadata)}") |
3269 | 3209 |
|
3270 | 3210 |
|
3271 | | -async def _set_return_key(store: Store, key: str, value: Buffer) -> str: |
| 3211 | +async def _set_return_key(*, store: Store, key: str, value: Buffer, replace: bool) -> str: |
3272 | 3212 | """ |
3273 | | - Store.set, but the key and the value are returned. |
3274 | | - Useful when saving metadata via asyncio.as_completed, because |
3275 | | - we need to know which key was saved. |
| 3213 | + Either write a value to storage at the given key, or ensure that there is already a value in |
| 3214 | + storage at the given key. The key is returned in either case. |
| 3215 | + Useful when saving values via routines that return results in execution order, |
| 3216 | + like asyncio.as_completed, because in this case we need to know which key was saved in order |
| 3217 | + to yield the right object to the caller. |
| 3218 | +
|
| 3219 | + Parameters |
| 3220 | + ---------- |
| 3221 | + store : Store |
| 3222 | + The store to save the value to. |
| 3223 | + key : str |
| 3224 | + The key to save the value to. |
| 3225 | + value : Buffer |
| 3226 | + The value to save. |
| 3227 | + replace : bool |
| 3228 | + If True, then the value will be written even if a value associated with the key |
| 3229 | + already exists in storage. If False, an existing value will not be overwritten. |
3276 | 3230 | """ |
3277 | | - await store.set(key, value) |
| 3231 | + if replace: |
| 3232 | + await store.set(key, value) |
| 3233 | + else: |
| 3234 | + await store.set_if_not_exists(key, value) |
3278 | 3235 | return key |
3279 | 3236 |
|
3280 | 3237 |
|
3281 | | -def _prepare_save_metadata( |
| 3238 | +def _persist_metadata( |
3282 | 3239 | store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata |
3283 | 3240 | ) -> tuple[Coroutine[None, None, str], ...]: |
3284 | 3241 | """ |
3285 | | - Prepare to save a metadata document to storage. Returns a tuple of coroutines that must be awaited. |
| 3242 | + Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. |
| 3243 | + If ``metadata`` is an instance of ``_ImplicitGroupMetadata``, then _set_return_key will be invoked with |
| 3244 | + ``replace=False``, which defers to a pre-existing metadata document in storage if one exists. Otherwise, existing values will be overwritten. |
3286 | 3245 | """ |
3287 | 3246 |
|
3288 | 3247 | to_save = metadata.to_buffer_dict(default_buffer_prototype()) |
3289 | | - return tuple(_set_return_key(store, f"{path}/{key}", value) for key, value in to_save.items()) |
| 3248 | + if isinstance(metadata, _ImplicitGroupMetadata): |
| 3249 | + replace = False |
| 3250 | + else: |
| 3251 | + replace = True |
| 3252 | + # TODO: should this function be a generator that yields values instead of eagerly returning a tuple? |
| 3253 | + return tuple( |
| 3254 | + _set_return_key(store=store, key=f"{path}/{key}", value=value, replace=replace) |
| 3255 | + for key, value in to_save.items() |
| 3256 | + ) |
| 3257 | + |
| 3258 | + |
| 3259 | +class _ImplicitGroupMetadata(GroupMetadata): |
| 3260 | + """ |
| 3261 | + This class represents the metadata document of a group that should be created at some |
| 3262 | + location in storage if and only if there is not already a group at that location. |
| 3263 | +
|
| 3264 | + This class is used to fill group-shaped "holes" in a dict specification of a Zarr hierarchy. |
| 3265 | +
|
| 3266 | + When attempting to write this class to disk, the writer should first check if a Zarr group |
| 3267 | + already exists at the desired location. If such a group does exist, the writer should do nothing. |
| 3268 | + If not, the writer should write this metadata document to storage. |
| 3269 | +
|
| 3270 | + """ |
| 3271 | + |
| 3272 | + def __init__( |
| 3273 | + self, |
| 3274 | + attributes: dict[str, Any] | None = None, |
| 3275 | + zarr_format: ZarrFormat = 3, |
| 3276 | + consolidated_metadata: ConsolidatedMetadata | None = None, |
| 3277 | + ) -> None: |
| 3278 | + if attributes is not None: |
| 3279 | + raise ValueError("attributes must be None for implicit groups") |
| 3280 | + |
| 3281 | + if consolidated_metadata is not None: |
| 3282 | + raise ValueError("consolidated_metadata must be None for implicit groups") |
| 3283 | + |
| 3284 | + super().__init__(attributes, zarr_format, consolidated_metadata) |
| 3285 | + |
| 3286 | + |
| 3287 | +async def _from_flat( |
| 3288 | + store: StoreLike, |
| 3289 | + *, |
| 3290 | + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], |
| 3291 | + overwrite: bool = False, |
| 3292 | +) -> AsyncGroup: |
| 3293 | + """ |
| 3294 | + Create an ``AsyncGroup`` from a store + a dict of nodes. |
| 3295 | + """ |
| 3296 | + roots = _get_roots(nodes) |
| 3297 | + if len(roots) != 1: |
| 3298 | + msg = ( |
| 3299 | + "The input does not specify a root node. " |
| 3300 | + "This function can only create hierarchies that contain a root node, which is " |
| 3301 | + "defined as a group that is ancestral to all the other arrays and " |
| 3302 | + "groups in the hierarchy." |
| 3303 | + ) |
| 3304 | + raise ValueError(msg) |
| 3305 | + else: |
| 3306 | + root = roots[0] |
| 3307 | + |
| 3308 | + if overwrite: |
| 3309 | + store_path = await make_store_path(store, mode="w") |
| 3310 | + else: |
| 3311 | + store_path = await make_store_path(store, mode="w-") |
| 3312 | + |
| 3313 | + semaphore = asyncio.Semaphore(config.get("async.concurrency")) |
| 3314 | + |
| 3315 | + nodes_created = { |
| 3316 | + x.path: x |
| 3317 | + async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore) |
| 3318 | + } |
| 3319 | + root_group = nodes_created[root] |
| 3320 | + if not isinstance(root_group, AsyncGroup): |
| 3321 | + raise TypeError("Invalid root node returned from create_hierarchy.") |
| 3322 | + return root_group |
0 commit comments