Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,20 +424,43 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
def __repr__(self) -> str:
return f"<AsyncGroup {self.store_path}>"

async def nmembers(self, recursive: bool = False) -> int:
    """
    Count the number of members in this group.

    Parameters
    ----------
    recursive : bool, default False
        Whether to recursively count arrays and groups in child groups of
        this Group. By default, just immediate child array and group members
        are counted.

    Returns
    -------
    count : int
    """
    # TODO: consider using aioitertools.builtins.sum for this
    # return await aioitertools.builtins.sum((1 async for _ in self.members()), start=0)
    n = 0
    # Drain the (async) member generator; we only need the count, not the
    # (name, node) pairs themselves.
    async for _ in self.members(recursive=recursive):
        n += 1
    return n

async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
async def members(
self, recursive: bool = False
) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
"""
Returns an AsyncGenerator over the arrays and groups contained in this group.
This method requires that `store_path.store` supports directory listing.

The results are not guaranteed to be ordered.

Parameters
----------
recursive : bool, default False
Whether to recursively include arrays and groups in child groups of
this Group. By default, just immediate child array and group members
are included.
"""
if not self.store_path.store.supports_listing:
msg = (
Expand All @@ -456,7 +479,19 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], N
if key in _skip_keys:
continue
try:
yield (key, await self.getitem(key))
obj = await self.getitem(key)
yield (key, obj)

if (
recursive
and hasattr(obj.metadata, "node_type")
and obj.metadata.node_type == "group"
):
# the assert is just for mypy to know that `obj.metadata.node_type`
# implies an AsyncGroup, not an AsyncArray
assert isinstance(obj, AsyncGroup)
async for child_key, val in obj.members(recursive=recursive):
yield "/".join([key, child_key]), val
except KeyError:
# keyerror is raised when `key` names an object (in the object storage sense),
# as opposed to a prefix, in the store under the prefix associated with this group
Expand Down Expand Up @@ -628,17 +663,15 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group:
self._sync(self._async_group.update_attributes(new_attributes))
return self

def nmembers(self, recursive: bool = False) -> int:
    """
    Count the number of members in this group.

    Parameters
    ----------
    recursive : bool, default False
        Whether to recursively count arrays and groups in child groups of
        this Group. By default, just immediate child array and group members
        are counted.

    Returns
    -------
    count : int
    """
    # Delegate to the async implementation and block until it completes.
    return self._sync(self._async_group.nmembers(recursive=recursive))

def members(self, recursive: bool = False) -> tuple[tuple[str, Array | Group], ...]:
    """
    Return the sub-arrays and sub-groups of this group as a tuple of
    (name, array | group) pairs.

    Parameters
    ----------
    recursive : bool, default False
        Whether to recursively include members of child groups. Nested
        members are keyed by their "/"-joined relative path.
    """
    # Drain the async generator synchronously, then wrap each async node
    # in its synchronous (Array / Group) counterpart.
    _members = self._sync_iter(self._async_group.members(recursive=recursive))

    result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members))
    return result
Expand Down
22 changes: 15 additions & 7 deletions src/zarr/core/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,21 @@ def _json_convert(o: Any) -> Any:
if isinstance(o, np.dtype):
return str(o)
if np.isscalar(o):
# convert numpy scalar to python type, and pass
# python types through
out = getattr(o, "item", lambda: o)()
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
out: Any
if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"):
# https://github.com/zarr-developers/zarr-python/issues/2119
# `.item()` on a datetime type might or might not return an
# integer, depending on the value.
# Explicitly cast to an int first, and then grab .item()
out = o.view("i8").item()
else:
# convert numpy scalar to python type, and pass
# python types through
out = getattr(o, "item", lambda: o)()
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
return out
if isinstance(o, Enum):
return o.name
Expand Down
14 changes: 11 additions & 3 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
for key in keys_unique:
yield key
else:
for key in self._store_dict:
if key.startswith(prefix + "/") and key != prefix:
yield key.removeprefix(prefix + "/").split("/")[0]
# Our dictionary doesn't contain directory markers, but we want to include
# a pseudo directory when there's a nested item and we're listing an
# intermediate level.
n = prefix.count("/") + 2
keys_unique = {
"/".join(k.split("/", n)[:n])
for k in self._store_dict
if k.startswith(prefix + "/")
}
for key in keys_unique:
yield key.removeprefix(prefix + "/").split("/")[0]
2 changes: 1 addition & 1 deletion src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
except FileNotFoundError:
return
for onefile in (a.replace(prefix + "/", "") for a in allfiles):
yield onefile
yield onefile.removeprefix(self.path).removeprefix("/")

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
for onefile in await self._fs._ls(prefix, detail=False):
Expand Down
7 changes: 7 additions & 0 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ async def test_list_dir(self, store: S) -> None:
assert [k async for k in store.list_dir("foo")] == []
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
await store.set("foo/c/1", Buffer.from_bytes(b"\x01"))
await store.set("foo/c/d/1", Buffer.from_bytes(b"\x01"))
await store.set("foo/c/d/2", Buffer.from_bytes(b"\x01"))
await store.set("foo/c/d/3", Buffer.from_bytes(b"\x01"))

keys_expected = ["foo"]
keys_observed = [k async for k in store.list_dir("")]
assert set(keys_observed) == set(keys_expected), keys_observed

keys_expected = ["zarr.json", "c"]
keys_observed = [k async for k in store.list_dir("foo")]
Expand Down
10 changes: 9 additions & 1 deletion src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any

import hypothesis.extra.numpy as npst
Expand Down Expand Up @@ -101,7 +102,14 @@ def arrays(
root = Group.create(store)
fill_value_args: tuple[Any, ...] = tuple()
if nparray.dtype.kind == "M":
fill_value_args = ("ns",)
m = re.search(r"\[(.+)\]", nparray.dtype.str)
if not m:
raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.")

fill_value_args = (
# e.g. ns, D
m.groups()[0],
)

a = root.create_array(
array_path,
Expand Down
52 changes: 48 additions & 4 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat)
members_expected["subgroup"] = group.create_group("subgroup")
# make a sub-sub-subgroup, to ensure that the children calculation doesn't go
# too deep in the hierarchy
_ = members_expected["subgroup"].create_group("subsubgroup") # type: ignore
subsubgroup = members_expected["subgroup"].create_group("subsubgroup") # type: ignore

members_expected["subarray"] = group.create_array(
"subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True
Expand All @@ -101,7 +101,13 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat)
# this creates a directory with a random key in it
# this should not show up as a member
sync(store.set(f"{path}/extra_directory/extra_object-2", Buffer.from_bytes(b"000000")))
members_observed = group.members
members_observed = group.members()
# members are not guaranteed to be ordered, so sort before comparing
assert sorted(dict(members_observed)) == sorted(members_expected)

# recursive=True
members_observed = group.members(recursive=True)
members_expected["subgroup/subsubgroup"] = subsubgroup
# members are not guaranteed to be ordered, so sort before comparing
assert sorted(dict(members_observed)) == sorted(members_expected)

Expand Down Expand Up @@ -349,7 +355,8 @@ def test_group_create_array(
if method == "create_array":
array = group.create_array(name="array", shape=shape, dtype=dtype, data=data)
elif method == "array":
array = group.array(name="array", shape=shape, dtype=dtype, data=data)
with pytest.warns(DeprecationWarning):
array = group.array(name="array", shape=shape, dtype=dtype, data=data)
else:
raise AssertionError

Expand All @@ -358,7 +365,7 @@ def test_group_create_array(
with pytest.raises(ContainsArrayError):
group.create_array(name="array", shape=shape, dtype=dtype, data=data)
elif method == "array":
with pytest.raises(ContainsArrayError):
with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
group.array(name="array", shape=shape, dtype=dtype, data=data)
assert array.shape == shape
assert array.dtype == np.dtype(dtype)
Expand Down Expand Up @@ -653,3 +660,40 @@ async def test_asyncgroup_update_attributes(

agroup_new_attributes = await agroup.update_attributes(attributes_new)
assert agroup_new_attributes.attrs == attributes_new


async def test_group_members_async(store: LocalStore | MemoryStore):
    # Build a three-level hierarchy (root -> g0 -> g1 -> g2, with one array
    # at each level) so recursive traversal has nested members to find.
    group = AsyncGroup(
        GroupMetadata(),
        store_path=StorePath(store=store, path="root"),
    )
    a0 = await group.create_array("a0", (1,))
    g0 = await group.create_group("g0")
    a1 = await g0.create_array("a1", (1,))
    g1 = await g0.create_group("g1")
    a2 = await g1.create_array("a2", (1,))
    g2 = await g1.create_group("g2")

    # Non-recursive listing yields only the immediate children of the root.
    immediate = sorted([pair async for pair in group.members()], key=lambda pair: pair[0])
    assert immediate == [
        ("a0", a0),
        ("g0", g0),
    ]

    assert await group.nmembers() == 2

    # Recursive listing yields every descendant, keyed by "/"-joined path.
    everything = sorted(
        [pair async for pair in group.members(recursive=True)], key=lambda pair: pair[0]
    )
    assert everything == [
        ("a0", a0),
        ("g0", g0),
        ("g0/a1", a1),
        ("g0/g1", g1),
        ("g0/g1/a2", a2),
        ("g0/g1/g2", g2),
    ]

    assert await group.nmembers(recursive=True) == 6
23 changes: 23 additions & 0 deletions tests/v3/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING, Literal

from zarr.abc.codec import Codec
from zarr.codecs.bytes import BytesCodec
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding

if TYPE_CHECKING:
Expand Down Expand Up @@ -230,3 +232,24 @@ def test_metadata_to_dict(
observed.pop("chunk_key_encoding")
expected.pop("chunk_key_encoding")
assert observed == expected


@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
@pytest.mark.parametrize("precision", ["ns", "D"])
async def test_datetime_metadata(fill_value: int, precision: str):
    # Regression test: serializing a datetime64 fill value must round-trip
    # through JSON as the underlying integer, for any precision/value.
    # https://github.com/zarr-developers/zarr-python/issues/2119
    metadata = ArrayV3Metadata.from_dict(
        {
            "zarr_format": 3,
            "node_type": "array",
            "shape": (1,),
            "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
            "data_type": f"<M8[{precision}]",
            "chunk_key_encoding": {"name": "default", "separator": "."},
            "codecs": (),
            "fill_value": np.datetime64(fill_value, precision),
        }
    )
    # Serialization must not raise a TypeError.
    buffers = metadata.to_buffer_dict(default_buffer_prototype())

    payload = json.loads(buffers["zarr.json"].to_bytes())
    assert payload["fill_value"] == fill_value
10 changes: 10 additions & 0 deletions tests/v3/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ def test_roundtrip(data):


@given(data=st.data())
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
# Uncomment the next line to reproduce the original failure.
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/ndR2z7nkDZEDADWpBL4=')
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_basic_indexing(data):
zarray = data.draw(arrays())
nparray = zarray[:]
Expand All @@ -32,6 +37,11 @@ def test_basic_indexing(data):


@given(data=st.data())
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
# Uncomment the next line to reproduce the original failure.
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/eLmF7qr/C5EDADZUBRM=')
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dcherian - would you mind looking into this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a minefield 😬 The deepest I got was hypothesis.extra.numpy:ArrayStrategy.set_element. Something about the sequence of operations and the data passed into the ndarray made it "weird", such that array == np.complex128(0.0) raised an InvalidComparison warning. I'm not sure how to interpret the bytes at #2118 (comment).

def test_vindex(data):
zarray = data.draw(arrays())
nparray = zarray[:]
Expand Down