Skip to content

Commit 0095ee6

Browse files
Feature/recursive members (#2118)
* Fixed MemoryStore.list_dir Ensures that nested children are listed properly. * fixup s3 * recursive Group.members This PR adds a recursive=True flag to Group.members, for recursively listing the members of some hierarchy. This is useful for Consolidated Metadata, which needs to recursively inspect children. IMO, it's useful (and simple) enough to include in the public API. * trigger ci * fixed datetime serialization * fixup * fixed invalid escape sequence * fixup * max_depth * max_depth=None
1 parent 61683be commit 0095ee6

File tree

6 files changed

+191
-22
lines changed

6 files changed

+191
-22
lines changed

src/zarr/core/group.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -424,21 +424,59 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
424424
def __repr__(self) -> str:
425425
return f"<AsyncGroup {self.store_path}>"
426426

427-
async def nmembers(self) -> int:
427+
async def nmembers(
428+
self,
429+
max_depth: int | None = 0,
430+
) -> int:
431+
"""
432+
Count the number of members in this group.
433+
434+
Parameters
435+
----------
436+
max_depth : int or None, default 0
437+
The maximum number of levels of the hierarchy to include. By
438+
default, (``max_depth=0``) only immediate children are included. Set
439+
``max_depth=None`` to include all nodes, and some positive integer
440+
to consider children within that many levels of the root Group.
441+
442+
Returns
443+
-------
444+
count : int
445+
"""
428446
# TODO: consider using aioitertools.builtins.sum for this
429447
# return await aioitertools.builtins.sum((1 async for _ in self.members()), start=0)
430448
n = 0
431-
async for _ in self.members():
449+
async for _ in self.members(max_depth=max_depth):
432450
n += 1
433451
return n
434452

435-
async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
453+
async def members(
454+
self,
455+
max_depth: int | None = 0,
456+
) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
436457
"""
437458
Returns an AsyncGenerator over the arrays and groups contained in this group.
438459
This method requires that `store_path.store` supports directory listing.
439460
440461
The results are not guaranteed to be ordered.
462+
463+
Parameters
464+
----------
465+
max_depth : int or None, default 0
466+
The maximum number of levels of the hierarchy to include. By
467+
default, (``max_depth=0``) only immediate children are included. Set
468+
``max_depth=None`` to include all nodes, and some positive integer
469+
to consider children within that many levels of the root Group.
470+
441471
"""
472+
if max_depth is not None and max_depth < 0:
473+
raise ValueError(f"max_depth must be None or >= 0. Got '{max_depth}' instead")
474+
async for item in self._members(max_depth=max_depth, current_depth=0):
475+
yield item
476+
477+
async def _members(
478+
self, max_depth: int | None, current_depth: int
479+
) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
442480
if not self.store_path.store.supports_listing:
443481
msg = (
444482
f"The store associated with this group ({type(self.store_path.store)}) "
@@ -456,7 +494,21 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], N
456494
if key in _skip_keys:
457495
continue
458496
try:
459-
yield (key, await self.getitem(key))
497+
obj = await self.getitem(key)
498+
yield (key, obj)
499+
500+
if (
501+
((max_depth is None) or (current_depth < max_depth))
502+
and hasattr(obj.metadata, "node_type")
503+
and obj.metadata.node_type == "group"
504+
):
505+
# the assert is just for mypy to know that `obj.metadata.node_type`
506+
# implies an AsyncGroup, not an AsyncArray
507+
assert isinstance(obj, AsyncGroup)
508+
async for child_key, val in obj._members(
509+
max_depth=max_depth, current_depth=current_depth + 1
510+
):
511+
yield "/".join([key, child_key]), val
460512
except KeyError:
461513
# keyerror is raised when `key` names an object (in the object storage sense),
462514
# as opposed to a prefix, in the store under the prefix associated with this group
@@ -628,17 +680,15 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group:
628680
self._sync(self._async_group.update_attributes(new_attributes))
629681
return self
630682

631-
@property
632-
def nmembers(self) -> int:
633-
return self._sync(self._async_group.nmembers())
683+
def nmembers(self, max_depth: int | None = 0) -> int:
684+
return self._sync(self._async_group.nmembers(max_depth=max_depth))
634685

635-
@property
636-
def members(self) -> tuple[tuple[str, Array | Group], ...]:
686+
def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], ...]:
637687
"""
638688
Return the sub-arrays and sub-groups of this group as a tuple of (name, array | group)
639689
pairs
640690
"""
641-
_members = self._sync_iter(self._async_group.members())
691+
_members = self._sync_iter(self._async_group.members(max_depth=max_depth))
642692

643693
result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members))
644694
return result

src/zarr/core/metadata.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,21 @@ def _json_convert(o: Any) -> Any:
256256
if isinstance(o, np.dtype):
257257
return str(o)
258258
if np.isscalar(o):
259-
# convert numpy scalar to python type, and pass
260-
# python types through
261-
out = getattr(o, "item", lambda: o)()
262-
if isinstance(out, complex):
263-
# python complex types are not JSON serializable, so we use the
264-
# serialization defined in the zarr v3 spec
265-
return [out.real, out.imag]
259+
out: Any
260+
if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"):
261+
# https://github.com/zarr-developers/zarr-python/issues/2119
262+
# `.item()` on a datetime type might or might not return an
263+
# integer, depending on the value.
264+
# Explicitly cast to an int first, and then grab .item()
265+
out = o.view("i8").item()
266+
else:
267+
# convert numpy scalar to python type, and pass
268+
# python types through
269+
out = getattr(o, "item", lambda: o)()
270+
if isinstance(out, complex):
271+
# python complex types are not JSON serializable, so we use the
272+
# serialization defined in the zarr v3 spec
273+
return [out.real, out.imag]
266274
return out
267275
if isinstance(o, Enum):
268276
return o.name

src/zarr/testing/strategies.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Any
23

34
import hypothesis.extra.numpy as npst
@@ -101,7 +102,14 @@ def arrays(
101102
root = Group.create(store)
102103
fill_value_args: tuple[Any, ...] = tuple()
103104
if nparray.dtype.kind == "M":
104-
fill_value_args = ("ns",)
105+
m = re.search(r"\[(.+)\]", nparray.dtype.str)
106+
if not m:
107+
raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.")
108+
109+
fill_value_args = (
110+
# e.g. ns, D
111+
m.groups()[0],
112+
)
105113

106114
a = root.create_array(
107115
array_path,

tests/v3/test_group.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat)
8888
members_expected["subgroup"] = group.create_group("subgroup")
8989
# make a sub-sub-subgroup, to ensure that the children calculation doesn't go
9090
# too deep in the hierarchy
91-
_ = members_expected["subgroup"].create_group("subsubgroup") # type: ignore
91+
subsubgroup = members_expected["subgroup"].create_group("subsubgroup") # type: ignore
92+
subsubsubgroup = subsubgroup.create_group("subsubsubgroup") # type: ignore
9293

9394
members_expected["subarray"] = group.create_array(
9495
"subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True
@@ -101,10 +102,25 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat)
101102
# this creates a directory with a random key in it
102103
# this should not show up as a member
103104
sync(store.set(f"{path}/extra_directory/extra_object-2", Buffer.from_bytes(b"000000")))
104-
members_observed = group.members
105+
members_observed = group.members()
105106
# members are not guaranteed to be ordered, so sort before comparing
106107
assert sorted(dict(members_observed)) == sorted(members_expected)
107108

109+
# partial
110+
members_observed = group.members(max_depth=1)
111+
members_expected["subgroup/subsubgroup"] = subsubgroup
112+
# members are not guaranteed to be ordered, so sort before comparing
113+
assert sorted(dict(members_observed)) == sorted(members_expected)
114+
115+
# total
116+
members_observed = group.members(max_depth=None)
117+
members_expected["subgroup/subsubgroup/subsubsubgroup"] = subsubsubgroup
118+
# members are not guaranteed to be ordered, so sort before comparing
119+
assert sorted(dict(members_observed)) == sorted(members_expected)
120+
121+
with pytest.raises(ValueError, match="max_depth"):
122+
members_observed = group.members(max_depth=-1)
123+
108124

109125
def test_group(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None:
110126
"""
@@ -349,7 +365,8 @@ def test_group_create_array(
349365
if method == "create_array":
350366
array = group.create_array(name="array", shape=shape, dtype=dtype, data=data)
351367
elif method == "array":
352-
array = group.array(name="array", shape=shape, dtype=dtype, data=data)
368+
with pytest.warns(DeprecationWarning):
369+
array = group.array(name="array", shape=shape, dtype=dtype, data=data)
353370
else:
354371
raise AssertionError
355372

@@ -358,7 +375,7 @@ def test_group_create_array(
358375
with pytest.raises(ContainsArrayError):
359376
group.create_array(name="array", shape=shape, dtype=dtype, data=data)
360377
elif method == "array":
361-
with pytest.raises(ContainsArrayError):
378+
with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
362379
group.array(name="array", shape=shape, dtype=dtype, data=data)
363380
assert array.shape == shape
364381
assert array.dtype == np.dtype(dtype)
@@ -653,3 +670,56 @@ async def test_asyncgroup_update_attributes(
653670

654671
agroup_new_attributes = await agroup.update_attributes(attributes_new)
655672
assert agroup_new_attributes.attrs == attributes_new
673+
674+
675+
async def test_group_members_async(store: LocalStore | MemoryStore):
676+
group = AsyncGroup(
677+
GroupMetadata(),
678+
store_path=StorePath(store=store, path="root"),
679+
)
680+
a0 = await group.create_array("a0", (1,))
681+
g0 = await group.create_group("g0")
682+
a1 = await g0.create_array("a1", (1,))
683+
g1 = await g0.create_group("g1")
684+
a2 = await g1.create_array("a2", (1,))
685+
g2 = await g1.create_group("g2")
686+
687+
# immediate children
688+
children = sorted([x async for x in group.members()], key=lambda x: x[0])
689+
assert children == [
690+
("a0", a0),
691+
("g0", g0),
692+
]
693+
694+
nmembers = await group.nmembers()
695+
assert nmembers == 2
696+
697+
# partial
698+
children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0])
699+
expected = [
700+
("a0", a0),
701+
("g0", g0),
702+
("g0/a1", a1),
703+
("g0/g1", g1),
704+
]
705+
assert children == expected
706+
nmembers = await group.nmembers(max_depth=1)
707+
assert nmembers == 4
708+
709+
# all children
710+
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
711+
expected = [
712+
("a0", a0),
713+
("g0", g0),
714+
("g0/a1", a1),
715+
("g0/g1", g1),
716+
("g0/g1/a2", a2),
717+
("g0/g1/g2", g2),
718+
]
719+
assert all_children == expected
720+
721+
nmembers = await group.nmembers(max_depth=None)
722+
assert nmembers == 6
723+
724+
with pytest.raises(ValueError, match="max_depth"):
725+
[x async for x in group.members(max_depth=-1)]

tests/v3/test_metadata/test_v3.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3+
import json
34
import re
45
from typing import TYPE_CHECKING, Literal
56

67
from zarr.abc.codec import Codec
78
from zarr.codecs.bytes import BytesCodec
9+
from zarr.core.buffer import default_buffer_prototype
810
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
911

1012
if TYPE_CHECKING:
@@ -230,3 +232,24 @@ def test_metadata_to_dict(
230232
observed.pop("chunk_key_encoding")
231233
expected.pop("chunk_key_encoding")
232234
assert observed == expected
235+
236+
237+
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
238+
@pytest.mark.parametrize("precision", ["ns", "D"])
239+
async def test_datetime_metadata(fill_value: int, precision: str):
240+
metadata_dict = {
241+
"zarr_format": 3,
242+
"node_type": "array",
243+
"shape": (1,),
244+
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
245+
"data_type": f"<M8[{precision}]",
246+
"chunk_key_encoding": {"name": "default", "separator": "."},
247+
"codecs": (),
248+
"fill_value": np.datetime64(fill_value, precision),
249+
}
250+
metadata = ArrayV3Metadata.from_dict(metadata_dict)
251+
# ensure there isn't a TypeError here.
252+
d = metadata.to_buffer_dict(default_buffer_prototype())
253+
254+
result = json.loads(d["zarr.json"].to_bytes())
255+
assert result["fill_value"] == fill_value

tests/v3/test_properties.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ def test_roundtrip(data):
1818

1919

2020
@given(data=st.data())
21+
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
22+
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
23+
# Uncomment the next line to reproduce the original failure.
24+
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/ndR2z7nkDZEDADWpBL4=')
25+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
2126
def test_basic_indexing(data):
2227
zarray = data.draw(arrays())
2328
nparray = zarray[:]
@@ -32,6 +37,11 @@ def test_basic_indexing(data):
3237

3338

3439
@given(data=st.data())
40+
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
41+
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
42+
# Uncomment the next line to reproduce the original failure.
43+
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/eLmF7qr/C5EDADZUBRM=')
44+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
3545
def test_vindex(data):
3646
zarray = data.draw(arrays())
3747
nparray = zarray[:]

0 commit comments

Comments
 (0)