|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import pickle |
3 | 4 | from typing import TYPE_CHECKING, Any, Literal, cast |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import pytest |
7 | 8 |
|
8 | | -import zarr.api.asynchronous |
9 | 9 | from zarr import Array, AsyncArray, AsyncGroup, Group |
10 | 10 | from zarr.abc.store import Store |
11 | | -from zarr.api.synchronous import open_group |
12 | 11 | from zarr.core.buffer import default_buffer_prototype |
13 | 12 | from zarr.core.common import JSON, ZarrFormat |
14 | 13 | from zarr.core.group import GroupMetadata |
15 | 14 | from zarr.core.sync import sync |
16 | 15 | from zarr.errors import ContainsArrayError, ContainsGroupError |
17 | | -from zarr.store import LocalStore, MemoryStore, StorePath |
| 16 | +from zarr.store import LocalStore, StorePath |
18 | 17 | from zarr.store.common import make_store_path |
19 | 18 |
|
20 | 19 | from .conftest import parse_store |
@@ -681,152 +680,22 @@ async def test_asyncgroup_update_attributes(store: Store, zarr_format: ZarrForma |
681 | 680 | assert agroup_new_attributes.attrs == attributes_new |
682 | 681 |
|
683 | 682 |
|
684 | | -async def test_group_members_async(store: LocalStore | MemoryStore) -> None: |
685 | | - group = AsyncGroup( |
686 | | - GroupMetadata(), |
687 | | - store_path=StorePath(store=store, path="root"), |
688 | | - ) |
689 | | - a0 = await group.create_array("a0", shape=(1,)) |
690 | | - g0 = await group.create_group("g0") |
691 | | - a1 = await g0.create_array("a1", shape=(1,)) |
692 | | - g1 = await g0.create_group("g1") |
693 | | - a2 = await g1.create_array("a2", shape=(1,)) |
694 | | - g2 = await g1.create_group("g2") |
695 | | - |
696 | | - # immediate children |
697 | | - children = sorted([x async for x in group.members()], key=lambda x: x[0]) |
698 | | - assert children == [ |
699 | | - ("a0", a0), |
700 | | - ("g0", g0), |
701 | | - ] |
702 | | - |
703 | | - nmembers = await group.nmembers() |
704 | | - assert nmembers == 2 |
705 | | - |
706 | | - # partial |
707 | | - children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0]) |
708 | | - expected = [ |
709 | | - ("a0", a0), |
710 | | - ("g0", g0), |
711 | | - ("g0/a1", a1), |
712 | | - ("g0/g1", g1), |
713 | | - ] |
714 | | - assert children == expected |
715 | | - nmembers = await group.nmembers(max_depth=1) |
716 | | - assert nmembers == 4 |
717 | | - |
718 | | - # all children |
719 | | - all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0]) |
720 | | - expected = [ |
721 | | - ("a0", a0), |
722 | | - ("g0", g0), |
723 | | - ("g0/a1", a1), |
724 | | - ("g0/g1", g1), |
725 | | - ("g0/g1/a2", a2), |
726 | | - ("g0/g1/g2", g2), |
727 | | - ] |
728 | | - assert all_children == expected |
729 | | - |
730 | | - nmembers = await group.nmembers(max_depth=None) |
731 | | - assert nmembers == 6 |
732 | | - |
733 | | - with pytest.raises(ValueError, match="max_depth"): |
734 | | - [x async for x in group.members(max_depth=-1)] |
735 | | - |
736 | | - |
737 | | -async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
738 | | - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
739 | | - |
740 | | - # create foo group |
741 | | - _ = await root.create_group("foo", attributes={"foo": 100}) |
742 | | - |
743 | | - # test that we can get the group using require_group |
744 | | - foo_group = await root.require_group("foo") |
745 | | - assert foo_group.attrs == {"foo": 100} |
746 | | - |
747 | | - # test that we can get the group using require_group and overwrite=True |
748 | | - foo_group = await root.require_group("foo", overwrite=True) |
749 | | - |
750 | | - _ = await foo_group.create_array( |
751 | | - "bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100} |
| 683 | +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) |
| 684 | +@pytest.mark.parametrize("zarr_format", (2, 3)) |
| 685 | +async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrFormat) -> None: |
| 686 | + expected = await AsyncGroup.create( |
| 687 | + store=store, attributes={"foo": 999}, zarr_format=zarr_format |
752 | 688 | ) |
| 689 | + p = pickle.dumps(expected) |
| 690 | + actual = pickle.loads(p) |
| 691 | + assert actual == expected |
753 | 692 |
|
754 | | - # test that overwriting a group w/ children fails |
755 | | - # TODO: figure out why ensure_no_existing_node is not catching the foo.bar array |
756 | | - # |
757 | | - # with pytest.raises(ContainsArrayError): |
758 | | - # await root.require_group("foo", overwrite=True) |
759 | | - |
760 | | - # test that requiring a group where an array is fails |
761 | | - with pytest.raises(TypeError): |
762 | | - await foo_group.require_group("bar") |
763 | | - |
764 | | - |
765 | | -async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
766 | | - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
767 | | - # create foo group |
768 | | - _ = await root.create_group("foo", attributes={"foo": 100}) |
769 | | - # create bar group |
770 | | - _ = await root.create_group("bar", attributes={"bar": 200}) |
771 | | - |
772 | | - foo_group, bar_group = await root.require_groups("foo", "bar") |
773 | | - assert foo_group.attrs == {"foo": 100} |
774 | | - assert bar_group.attrs == {"bar": 200} |
775 | | - |
776 | | - # get a mix of existing and new groups |
777 | | - foo_group, spam_group = await root.require_groups("foo", "spam") |
778 | | - assert foo_group.attrs == {"foo": 100} |
779 | | - assert spam_group.attrs == {} |
780 | | - |
781 | | - # no names |
782 | | - no_group = await root.require_groups() |
783 | | - assert no_group == () |
784 | | - |
785 | | - |
786 | | -async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
787 | | - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
788 | | - with pytest.warns(DeprecationWarning): |
789 | | - foo = await root.create_dataset("foo", shape=(10,), dtype="uint8") |
790 | | - assert foo.shape == (10,) |
791 | | - |
792 | | - with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning): |
793 | | - await root.create_dataset("foo", shape=(100,), dtype="int8") |
794 | | - |
795 | | - _ = await root.create_group("bar") |
796 | | - with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning): |
797 | | - await root.create_dataset("bar", shape=(100,), dtype="int8") |
798 | | - |
799 | | - |
800 | | -async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
801 | | - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
802 | | - foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101}) |
803 | | - assert foo1.attrs == {"foo": 101} |
804 | | - foo2 = await root.require_array("foo", shape=(10,), dtype="i8") |
805 | | - assert foo2.attrs == {"foo": 101} |
806 | | - |
807 | | - # exact = False |
808 | | - _ = await root.require_array("foo", shape=10, dtype="f8") |
809 | | - |
810 | | - # errors w/ exact True |
811 | | - with pytest.raises(TypeError, match="Incompatible dtype"): |
812 | | - await root.require_array("foo", shape=(10,), dtype="f8", exact=True) |
813 | | - |
814 | | - with pytest.raises(TypeError, match="Incompatible shape"): |
815 | | - await root.require_array("foo", shape=(100, 100), dtype="i8") |
816 | | - |
817 | | - with pytest.raises(TypeError, match="Incompatible dtype"): |
818 | | - await root.require_array("foo", shape=(10,), dtype="f4") |
819 | | - |
820 | | - _ = await root.create_group("bar") |
821 | | - with pytest.raises(TypeError, match="Incompatible object"): |
822 | | - await root.require_array("bar", shape=(10,), dtype="int8") |
823 | | - |
824 | | - |
825 | | -async def test_open_mutable_mapping(): |
826 | | - group = await zarr.api.asynchronous.open_group(store={}, mode="w") |
827 | | - assert isinstance(group.store_path.store, MemoryStore) |
828 | 693 |
|
| 694 | +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) |
| 695 | +@pytest.mark.parametrize("zarr_format", (2, 3)) |
| 696 | +def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None: |
| 697 | + expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) |
| 698 | + p = pickle.dumps(expected) |
| 699 | + actual = pickle.loads(p) |
829 | 700 |
|
830 | | -def test_open_mutable_mapping_sync(): |
831 | | - group = open_group(store={}, mode="w") |
832 | | - assert isinstance(group.store_path.store, MemoryStore) |
| 701 | + assert actual == expected |
0 commit comments