Skip to content

Commit 661678f

Browse files
committed
use store + path instead of StorePath for hierarchy api
1 parent ed0d52a commit 661678f

File tree

2 files changed

+84
-69
lines changed

2 files changed

+84
-69
lines changed

src/zarr/core/group.py

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,9 @@ async def getitem(
694694
if self.metadata.consolidated_metadata is not None:
695695
return self._getitem_consolidated(store_path, key, prefix=self.name)
696696
elif self.metadata.zarr_format == 3:
697-
return await _read_node_v3(store_path=store_path)
697+
return await _read_node_v3(store=self.store, path=store_path.path)
698698
elif self.metadata.zarr_format == 2:
699-
return await _read_node_v2(store_path=store_path)
699+
return await _read_node_v2(store=self.store, path=store_path.path)
700700
else:
701701
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
702702

@@ -1423,7 +1423,8 @@ async def create_hierarchy(
14231423
"""
14241424
semaphore = asyncio.Semaphore(config.get("async.concurrency"))
14251425
async for node in create_hierarchy(
1426-
store_path=self.store_path,
1426+
store=self.store,
1427+
path=self.path,
14271428
nodes=nodes,
14281429
semaphore=semaphore,
14291430
overwrite=overwrite,
@@ -2838,7 +2839,8 @@ def array(
28382839

28392840
async def create_hierarchy(
28402841
*,
2841-
store_path: StorePath,
2842+
store: Store,
2843+
path: str,
28422844
nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata],
28432845
semaphore: asyncio.Semaphore | None = None,
28442846
overwrite: bool = False,
@@ -2887,7 +2889,7 @@ async def create_hierarchy(
28872889
# we allow creating empty hierarchies -- it's a no-op
28882890
if len(nodes_parsed) > 0:
28892891
if overwrite:
2890-
await store_path.delete_dir()
2892+
await store.delete_dir(path)
28912893
else:
28922894
# attempt to fetch all of the metadata described in hierarchy
28932895
# first figure out which zarr format we are dealing with
@@ -2901,8 +2903,8 @@ async def create_hierarchy(
29012903
zarr_format = sample.zarr_format
29022904
# TODO: this type hint is so long
29032905
func: (
2904-
Callable[[StorePath], Coroutine[Any, Any, GroupMetadata | ArrayV3Metadata]]
2905-
| Callable[[StorePath], Coroutine[Any, Any, GroupMetadata | ArrayV2Metadata]]
2906+
Callable[[Store, str], Coroutine[Any, Any, GroupMetadata | ArrayV3Metadata]]
2907+
| Callable[[Store, str], Coroutine[Any, Any, GroupMetadata | ArrayV2Metadata]]
29062908
)
29072909
if zarr_format == 3:
29082910
func = _read_metadata_v3
@@ -2911,7 +2913,7 @@ async def create_hierarchy(
29112913
else: # pragma: no cover
29122914
raise ValueError(f"Invalid zarr_format: {zarr_format}") # pragma: no cover
29132915

2914-
coros = (func(store_path=store_path / key) for key in nodes_parsed)
2916+
coros = (func(store=store, path=_join_paths([path, key])) for key in nodes_parsed)
29152917
extant_node_query = dict(
29162918
zip(
29172919
nodes_parsed.keys(),
@@ -2942,24 +2944,25 @@ async def create_hierarchy(
29422944

29432945
if isinstance(value, GroupMetadata):
29442946
if key not in implicit_group_names:
2945-
raise ContainsGroupError(store_path.store, key)
2947+
raise ContainsGroupError(store, key)
29462948
else:
29472949
# as there is already a group with this name, we should not create a new one
29482950
redundant_implicit_groups.append(key)
29492951
elif isinstance(value, ArrayV2Metadata | ArrayV3Metadata):
2950-
raise ContainsArrayError(store_path.store, key)
2952+
raise ContainsArrayError(store, key)
29512953

29522954
nodes_parsed = {
29532955
k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups
29542956
}
29552957

2956-
async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore):
2958+
async for node in create_nodes(store=store, path=path, nodes=nodes_parsed, semaphore=semaphore):
29572959
yield node
29582960

29592961

29602962
async def create_nodes(
29612963
*,
2962-
store_path: StorePath,
2964+
store: Store,
2965+
path: str,
29632966
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
29642967
semaphore: asyncio.Semaphore | None = None,
29652968
) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]:
@@ -2977,8 +2980,8 @@ async def create_nodes(
29772980
create_tasks: list[Coroutine[None, None, str]] = []
29782981
for key, value in nodes.items():
29792982
# transform the key, which is relative to a store_path.path, to a key in the store
2980-
write_prefix = _join_paths([store_path.path, key])
2981-
create_tasks.extend(_persist_metadata(store_path.store, write_prefix, value))
2983+
write_prefix = _join_paths([path, key])
2984+
create_tasks.extend(_persist_metadata(store, write_prefix, value))
29822985

29832986
created_object_keys = []
29842987
async with ctx:
@@ -2993,7 +2996,7 @@ async def create_nodes(
29932996
# the relative path of the object we just created -- we need this to track which metadata documents
29942997
# were written so that we can yield a complete v2 Array / Group class after both .zattrs
29952998
# and the metadata JSON was created.
2996-
object_path_relative = created_key.removeprefix(store_path.path).lstrip("/")
2999+
object_path_relative = created_key.removeprefix(path).lstrip("/")
29973000
created_object_keys.append(object_path_relative)
29983001

29993002
# get the node name from the object key
@@ -3008,7 +3011,7 @@ async def create_nodes(
30083011

30093012
if meta_out.zarr_format == 3:
30103013
# yes, it is silly that we relativize, then de-relativize this same path
3011-
node_store_path = store_path / node_name
3014+
node_store_path = StorePath(store=store, path=path) / node_name
30123015
if isinstance(meta_out, GroupMetadata):
30133016
yield AsyncGroup(metadata=meta_out, store_path=node_store_path)
30143017
else:
@@ -3027,7 +3030,7 @@ async def create_nodes(
30273030
meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys
30283031

30293032
if meta_done and attrs_done:
3030-
node_store_path = store_path / node_name
3033+
node_store_path = StorePath(store=store, path=path) / node_name
30313034
if isinstance(meta_out, GroupMetadata):
30323035
yield AsyncGroup(metadata=meta_out, store_path=node_store_path)
30333036
else:
@@ -3275,20 +3278,22 @@ async def _iter_members_deep(
32753278
yield key, node
32763279

32773280

3278-
async def _read_metadata_v3(store_path: StorePath) -> ArrayV3Metadata | GroupMetadata:
3281+
async def _read_metadata_v3(store: Store, path: str) -> ArrayV3Metadata | GroupMetadata:
32793282
"""
32803283
Given a store_path, return ArrayV3Metadata or GroupMetadata defined by the metadata
32813284
document stored at store_path.path / zarr.json. If no such document is found, raise a KeyError.
32823285
"""
3283-
zarr_json_bytes = await (store_path / ZARR_JSON).get()
3286+
zarr_json_bytes = await store.get(
3287+
_join_paths([path, ZARR_JSON]), prototype=default_buffer_prototype()
3288+
)
32843289
if zarr_json_bytes is None:
3285-
raise KeyError(store_path.path)
3290+
raise KeyError(path)
32863291
else:
32873292
zarr_json = json.loads(zarr_json_bytes.to_bytes())
32883293
return _build_metadata_v3(zarr_json)
32893294

32903295

3291-
async def _read_metadata_v2(store_path: StorePath) -> ArrayV2Metadata | GroupMetadata:
3296+
async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupMetadata:
32923297
"""
32933298
Given a store_path, return ArrayV2Metadata or GroupMetadata defined by the metadata
32943299
document stored at store_path.path / (.zgroup | .zarray). If no such document is found,
@@ -3297,9 +3302,9 @@ async def _read_metadata_v2(store_path: StorePath) -> ArrayV2Metadata | GroupMet
32973302
# TODO: consider first fetching array metadata, and only fetching group metadata when we don't
32983303
# find an array
32993304
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3300-
(store_path / ZARRAY_JSON).get(),
3301-
(store_path / ZGROUP_JSON).get(),
3302-
(store_path / ZATTRS_JSON).get(),
3305+
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3306+
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3307+
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
33033308
)
33043309

33053310
if zattrs_bytes is None:
@@ -3315,7 +3320,7 @@ async def _read_metadata_v2(store_path: StorePath) -> ArrayV2Metadata | GroupMet
33153320
else:
33163321
if zgroup_bytes is None:
33173322
# neither .zarray or .zgroup were found results in KeyError
3318-
raise KeyError(store_path.path)
3323+
raise KeyError(path)
33193324
else:
33203325
zmeta = json.loads(zgroup_bytes.to_bytes())
33213326

@@ -3353,11 +3358,15 @@ def _build_metadata_v2(
33533358

33543359

33553360
def _build_node_v3(
3356-
metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath
3361+
*,
3362+
store: Store,
3363+
path: str,
3364+
metadata: ArrayV3Metadata | GroupMetadata,
33573365
) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
33583366
"""
33593367
Take a metadata object and return a node (AsyncArray or AsyncGroup).
33603368
"""
3369+
store_path = StorePath(store=store, path=path)
33613370
match metadata:
33623371
case ArrayV3Metadata():
33633372
return AsyncArray(metadata, store_path=store_path)
@@ -3368,12 +3377,12 @@ def _build_node_v3(
33683377

33693378

33703379
def _build_node_v2(
3371-
metadata: ArrayV2Metadata | GroupMetadata, store_path: StorePath
3380+
*, store: Store, path: str, metadata: ArrayV2Metadata | GroupMetadata
33723381
) -> AsyncArray[ArrayV2Metadata] | AsyncGroup:
33733382
"""
33743383
Take a metadata object and return a node (AsyncArray or AsyncGroup).
33753384
"""
3376-
3385+
store_path = StorePath(store=store, path=path)
33773386
match metadata:
33783387
case ArrayV2Metadata():
33793388
return AsyncArray(metadata, store_path=store_path)
@@ -3383,33 +3392,33 @@ def _build_node_v2(
33833392
raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover
33843393

33853394

3386-
async def _read_node_v2(store_path: StorePath) -> AsyncArray[ArrayV2Metadata] | AsyncGroup:
3395+
async def _read_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] | AsyncGroup:
33873396
"""
33883397
Read a Zarr v2 AsyncArray or AsyncGroup from a location defined by a StorePath.
33893398
"""
3390-
metadata = await _read_metadata_v2(store_path=store_path)
3391-
return _build_node_v2(metadata=metadata, store_path=store_path)
3399+
metadata = await _read_metadata_v2(store=store, path=path)
3400+
return _build_node_v2(store=store, path=path, metadata=metadata)
33923401

33933402

3394-
async def _read_node_v3(store_path: StorePath) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
3403+
async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
33953404
"""
33963405
Read a Zarr v3 AsyncArray or AsyncGroup from a location defined by a StorePath.
33973406
"""
3398-
metadata = await _read_metadata_v3(store_path=store_path)
3399-
return _build_node_v3(metadata=metadata, store_path=store_path)
3407+
metadata = await _read_metadata_v3(store=store, path=path)
3408+
return _build_node_v3(store=store, path=path, metadata=metadata)
34003409

34013410

34023411
async def _read_node(
3403-
store_path: StorePath, zarr_format: ZarrFormat
3412+
store: Store, path: str, zarr_format: ZarrFormat
34043413
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup:
34053414
"""
34063415
Read an AsyncArray or AsyncGroup from a location defined by a StorePath.
34073416
"""
34083417
match zarr_format:
34093418
case 2:
3410-
return await _read_node_v2(store_path=store_path)
3419+
return await _read_node_v2(store=store, path=path)
34113420
case 3:
3412-
return await _read_node_v3(store_path=store_path)
3421+
return await _read_node_v3(store=store, path=path)
34133422
case _: # pragma: no cover
34143423
raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover
34153424

@@ -3449,8 +3458,9 @@ def _persist_metadata(
34493458

34503459

34513460
async def _create_rooted_hierarchy(
3452-
store_path: StorePath,
34533461
*,
3462+
store: Store,
3463+
path: str,
34543464
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
34553465
overwrite: bool = False,
34563466
) -> AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
@@ -3476,17 +3486,18 @@ async def _create_rooted_hierarchy(
34763486
nodes_created = {
34773487
x.path: x
34783488
async for x in create_hierarchy(
3479-
store_path=store_path, nodes=nodes, semaphore=semaphore, overwrite=overwrite
3489+
store=store, path=path, nodes=nodes, semaphore=semaphore, overwrite=overwrite
34803490
)
34813491
}
34823492
# the names of the created nodes will be relative to the store_path instance
3483-
root_relative_to_store_path = _join_paths([store_path.path, root])
3493+
root_relative_to_store_path = _join_paths([path, root])
34843494
return nodes_created[root_relative_to_store_path]
34853495

34863496

34873497
def _create_rooted_hierarchy_sync(
3488-
store_path: StorePath,
34893498
*,
3499+
store: Store,
3500+
path: str,
34903501
nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
34913502
overwrite: bool = False,
34923503
) -> Group | Array:
@@ -3495,7 +3506,7 @@ def _create_rooted_hierarchy_sync(
34953506
``_create_rooted_hierarchy`` and waits for the result.
34963507
"""
34973508
async_node = sync(
3498-
_create_rooted_hierarchy(store_path=store_path, nodes=nodes, overwrite=overwrite)
3509+
_create_rooted_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite)
34993510
)
35003511
if isinstance(async_node, AsyncGroup):
35013512
return Group(async_node)

0 commit comments

Comments
 (0)