| 
6 | 6 | import numpy as np  | 
7 | 7 | import pytest  | 
8 | 8 | 
 
  | 
 | 9 | +import zarr  | 
9 | 10 | from zarr import Array, AsyncArray, AsyncGroup, Group  | 
10 | 11 | from zarr.abc.store import Store  | 
11 | 12 | from zarr.core.buffer import default_buffer_prototype  | 
12 | 13 | from zarr.core.common import JSON, ZarrFormat  | 
13 | 14 | from zarr.core.group import GroupMetadata  | 
14 | 15 | from zarr.core.sync import sync  | 
15 | 16 | from zarr.errors import ContainsArrayError, ContainsGroupError  | 
16 |  | -from zarr.store import LocalStore, StorePath  | 
 | 17 | +from zarr.store import LocalStore, MemoryStore, StorePath  | 
17 | 18 | from zarr.store.common import make_store_path  | 
18 | 19 | 
 
  | 
19 | 20 | from .conftest import parse_store  | 
@@ -699,3 +700,154 @@ def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) ->  | 
699 | 700 |     actual = pickle.loads(p)  | 
700 | 701 | 
 
  | 
701 | 702 |     assert actual == expected  | 
 | 703 | + | 
 | 704 | + | 
 | 705 | +async def test_group_members_async(store: LocalStore | MemoryStore) -> None:  | 
 | 706 | +    group = AsyncGroup(  | 
 | 707 | +        GroupMetadata(),  | 
 | 708 | +        store_path=StorePath(store=store, path="root"),  | 
 | 709 | +    )  | 
 | 710 | +    a0 = await group.create_array("a0", shape=(1,))  | 
 | 711 | +    g0 = await group.create_group("g0")  | 
 | 712 | +    a1 = await g0.create_array("a1", shape=(1,))  | 
 | 713 | +    g1 = await g0.create_group("g1")  | 
 | 714 | +    a2 = await g1.create_array("a2", shape=(1,))  | 
 | 715 | +    g2 = await g1.create_group("g2")  | 
 | 716 | + | 
 | 717 | +    # immediate children  | 
 | 718 | +    children = sorted([x async for x in group.members()], key=lambda x: x[0])  | 
 | 719 | +    assert children == [  | 
 | 720 | +        ("a0", a0),  | 
 | 721 | +        ("g0", g0),  | 
 | 722 | +    ]  | 
 | 723 | + | 
 | 724 | +    nmembers = await group.nmembers()  | 
 | 725 | +    assert nmembers == 2  | 
 | 726 | + | 
 | 727 | +    # partial  | 
 | 728 | +    children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0])  | 
 | 729 | +    expected = [  | 
 | 730 | +        ("a0", a0),  | 
 | 731 | +        ("g0", g0),  | 
 | 732 | +        ("g0/a1", a1),  | 
 | 733 | +        ("g0/g1", g1),  | 
 | 734 | +    ]  | 
 | 735 | +    assert children == expected  | 
 | 736 | +    nmembers = await group.nmembers(max_depth=1)  | 
 | 737 | +    assert nmembers == 4  | 
 | 738 | + | 
 | 739 | +    # all children  | 
 | 740 | +    all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])  | 
 | 741 | +    expected = [  | 
 | 742 | +        ("a0", a0),  | 
 | 743 | +        ("g0", g0),  | 
 | 744 | +        ("g0/a1", a1),  | 
 | 745 | +        ("g0/g1", g1),  | 
 | 746 | +        ("g0/g1/a2", a2),  | 
 | 747 | +        ("g0/g1/g2", g2),  | 
 | 748 | +    ]  | 
 | 749 | +    assert all_children == expected  | 
 | 750 | + | 
 | 751 | +    nmembers = await group.nmembers(max_depth=None)  | 
 | 752 | +    assert nmembers == 6  | 
 | 753 | + | 
 | 754 | +    with pytest.raises(ValueError, match="max_depth"):  | 
 | 755 | +        [x async for x in group.members(max_depth=-1)]  | 
 | 756 | + | 
 | 757 | + | 
 | 758 | +async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:  | 
 | 759 | +    root = await AsyncGroup.create(store=store, zarr_format=zarr_format)  | 
 | 760 | + | 
 | 761 | +    # create foo group  | 
 | 762 | +    _ = await root.create_group("foo", attributes={"foo": 100})  | 
 | 763 | + | 
 | 764 | +    # test that we can get the group using require_group  | 
 | 765 | +    foo_group = await root.require_group("foo")  | 
 | 766 | +    assert foo_group.attrs == {"foo": 100}  | 
 | 767 | + | 
 | 768 | +    # test that we can get the group using require_group and overwrite=True  | 
 | 769 | +    foo_group = await root.require_group("foo", overwrite=True)  | 
 | 770 | + | 
 | 771 | +    _ = await foo_group.create_array(  | 
 | 772 | +        "bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}  | 
 | 773 | +    )  | 
 | 774 | + | 
 | 775 | +    # test that overwriting a group w/ children fails  | 
 | 776 | +    # TODO: figure out why ensure_no_existing_node is not catching the foo.bar array  | 
 | 777 | +    #  | 
 | 778 | +    # with pytest.raises(ContainsArrayError):  | 
 | 779 | +    #     await root.require_group("foo", overwrite=True)  | 
 | 780 | + | 
 | 781 | +    # test that requiring a group where an array is fails  | 
 | 782 | +    with pytest.raises(TypeError):  | 
 | 783 | +        await foo_group.require_group("bar")  | 
 | 784 | + | 
 | 785 | + | 
 | 786 | +async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:  | 
 | 787 | +    root = await AsyncGroup.create(store=store, zarr_format=zarr_format)  | 
 | 788 | +    # create foo group  | 
 | 789 | +    _ = await root.create_group("foo", attributes={"foo": 100})  | 
 | 790 | +    # create bar group  | 
 | 791 | +    _ = await root.create_group("bar", attributes={"bar": 200})  | 
 | 792 | + | 
 | 793 | +    foo_group, bar_group = await root.require_groups("foo", "bar")  | 
 | 794 | +    assert foo_group.attrs == {"foo": 100}  | 
 | 795 | +    assert bar_group.attrs == {"bar": 200}  | 
 | 796 | + | 
 | 797 | +    # get a mix of existing and new groups  | 
 | 798 | +    foo_group, spam_group = await root.require_groups("foo", "spam")  | 
 | 799 | +    assert foo_group.attrs == {"foo": 100}  | 
 | 800 | +    assert spam_group.attrs == {}  | 
 | 801 | + | 
 | 802 | +    # no names  | 
 | 803 | +    no_group = await root.require_groups()  | 
 | 804 | +    assert no_group == ()  | 
 | 805 | + | 
 | 806 | + | 
 | 807 | +async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:  | 
 | 808 | +    root = await AsyncGroup.create(store=store, zarr_format=zarr_format)  | 
 | 809 | +    with pytest.warns(DeprecationWarning):  | 
 | 810 | +        foo = await root.create_dataset("foo", shape=(10,), dtype="uint8")  | 
 | 811 | +    assert foo.shape == (10,)  | 
 | 812 | + | 
 | 813 | +    with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):  | 
 | 814 | +        await root.create_dataset("foo", shape=(100,), dtype="int8")  | 
 | 815 | + | 
 | 816 | +    _ = await root.create_group("bar")  | 
 | 817 | +    with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning):  | 
 | 818 | +        await root.create_dataset("bar", shape=(100,), dtype="int8")  | 
 | 819 | + | 
 | 820 | + | 
 | 821 | +async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:  | 
 | 822 | +    root = await AsyncGroup.create(store=store, zarr_format=zarr_format)  | 
 | 823 | +    foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101})  | 
 | 824 | +    assert foo1.attrs == {"foo": 101}  | 
 | 825 | +    foo2 = await root.require_array("foo", shape=(10,), dtype="i8")  | 
 | 826 | +    assert foo2.attrs == {"foo": 101}  | 
 | 827 | + | 
 | 828 | +    # exact = False  | 
 | 829 | +    _ = await root.require_array("foo", shape=10, dtype="f8")  | 
 | 830 | + | 
 | 831 | +    # errors w/ exact True  | 
 | 832 | +    with pytest.raises(TypeError, match="Incompatible dtype"):  | 
 | 833 | +        await root.require_array("foo", shape=(10,), dtype="f8", exact=True)  | 
 | 834 | + | 
 | 835 | +    with pytest.raises(TypeError, match="Incompatible shape"):  | 
 | 836 | +        await root.require_array("foo", shape=(100, 100), dtype="i8")  | 
 | 837 | + | 
 | 838 | +    with pytest.raises(TypeError, match="Incompatible dtype"):  | 
 | 839 | +        await root.require_array("foo", shape=(10,), dtype="f4")  | 
 | 840 | + | 
 | 841 | +    _ = await root.create_group("bar")  | 
 | 842 | +    with pytest.raises(TypeError, match="Incompatible object"):  | 
 | 843 | +        await root.require_array("bar", shape=(10,), dtype="int8")  | 
 | 844 | + | 
 | 845 | + | 
 | 846 | +async def test_open_mutable_mapping():  | 
 | 847 | +    group = await zarr.api.asynchronous.open_group(store={}, mode="w")  | 
 | 848 | +    assert isinstance(group.store_path.store, MemoryStore)  | 
 | 849 | + | 
 | 850 | + | 
 | 851 | +def test_open_mutable_mapping_sync():  | 
 | 852 | +    group = zarr.open_group(store={}, mode="w")  | 
 | 853 | +    assert isinstance(group.store_path.store, MemoryStore)  | 
0 commit comments