Skip to content

Commit 61f7008

Browse files
committed
test for kwarg propagation through array-like routines
1 parent 5ddd68c commit 61f7008

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

tests/test_api.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,74 @@ def test_create(memory_store: Store) -> None:
6868
z = create(shape=(400, 100), chunks=(16, 16.5), store=store, overwrite=True) # type: ignore [arg-type]
6969

7070

71+
@pytest.mark.parametrize("func_name", ["zeros_like", "ones_like", "empty_like", "full_like"])
72+
@pytest.mark.parametrize("out_shape", ["keep", (10, 10)])
73+
@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)])
74+
@pytest.mark.parametrize("out_dtype", ["keep", "int8"])
75+
async def test_array_like_creation(
76+
zarr_format: ZarrFormat,
77+
func_name: str,
78+
out_shape: Literal["keep"] | tuple[int, ...],
79+
out_chunks: Literal["keep"] | tuple[int, ...],
80+
out_dtype: str,
81+
) -> None:
82+
"""
83+
Test zeros_like, ones_like, empty_like, full_like, ensuring that we can override the
84+
shape, chunks, and dtype of the array-like object provided to these functions with
85+
appropriate keyword arguments
86+
"""
87+
ref_arr = zarr.ones(
88+
store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12), zarr_format=zarr_format
89+
)
90+
kwargs: dict[str, object] = {}
91+
if func_name == "full_like":
92+
expect_fill = 4
93+
kwargs["fill_value"] = expect_fill
94+
func = zarr.api.asynchronous.full_like
95+
elif func_name == "zeros_like":
96+
expect_fill = 0
97+
func = zarr.api.asynchronous.zeros_like
98+
elif func_name == "ones_like":
99+
expect_fill = 1
100+
func = zarr.api.asynchronous.ones_like
101+
elif func_name == "empty_like":
102+
expect_fill = ref_arr.fill_value
103+
func = zarr.api.asynchronous.empty_like
104+
else:
105+
raise AssertionError
106+
if out_shape != "keep":
107+
kwargs["shape"] = out_shape
108+
expect_shape = out_shape
109+
else:
110+
expect_shape = ref_arr.shape
111+
if out_chunks != "keep":
112+
kwargs["chunks"] = out_chunks
113+
expect_chunks = out_chunks
114+
else:
115+
expect_chunks = ref_arr.chunks
116+
if out_dtype != "keep":
117+
kwargs["dtype"] = out_dtype
118+
expect_dtype = out_dtype
119+
else:
120+
expect_dtype = ref_arr.dtype # type: ignore[assignment]
121+
122+
new_arr = await func(ref_arr, path="foo", **kwargs)
123+
assert new_arr.shape == expect_shape
124+
assert new_arr.chunks == expect_chunks
125+
assert new_arr.dtype == expect_dtype
126+
assert np.all(Array(new_arr)[:] == expect_fill)
127+
128+
129+
async def test_invalid_full_like() -> None:
130+
"""
131+
Test that a fill value that is incompatible with the proposed dtype is rejected
132+
"""
133+
ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12))
134+
fill = 4
135+
with pytest.raises(ValueError, match=f"fill value {fill} is not valid for dtype DataType.bool"):
136+
await zarr.api.asynchronous.full_like(ref_arr, path="foo", fill_value=fill, dtype="bool")
137+
138+
71139
# TODO: parametrize over everything this function takes
72140
@pytest.mark.parametrize("store", ["memory"], indirect=True)
73141
def test_create_array(store: Store) -> None:

tests/test_group.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,63 @@ def test_group_create_array(
668668
assert np.array_equal(array[:], data)
669669

670670

671+
@pytest.mark.parametrize("method_name", ["zeros_like", "ones_like", "empty_like", "full_like"])
672+
@pytest.mark.parametrize("out_shape", ["keep", (10, 10)])
673+
@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)])
674+
@pytest.mark.parametrize("out_dtype", ["keep", "int8"])
675+
def test_group_array_like_creation(
676+
zarr_format: ZarrFormat,
677+
method_name: str,
678+
out_shape: Literal["keep"] | tuple[int, ...],
679+
out_chunks: Literal["keep"] | tuple[int, ...],
680+
out_dtype: str,
681+
) -> None:
682+
"""
683+
Test Group.{zeros_like, ones_like, empty_like, full_like}, ensuring that we can override the
684+
shape, chunks, and dtype of the array-like object provided to these functions with
685+
appropriate keyword arguments
686+
"""
687+
ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12))
688+
group = Group.from_store({}, zarr_format=zarr_format)
689+
kwargs = {}
690+
if method_name == "full_like":
691+
expect_fill = 4
692+
kwargs["fill_value"] = expect_fill
693+
meth = group.full_like
694+
elif method_name == "zeros_like":
695+
expect_fill = 0
696+
meth = group.zeros_like
697+
elif method_name == "ones_like":
698+
expect_fill = 1
699+
meth = group.ones_like
700+
elif method_name == "empty_like":
701+
expect_fill = ref_arr.fill_value
702+
meth = group.empty_like
703+
else:
704+
raise AssertionError
705+
if out_shape != "keep":
706+
kwargs["shape"] = out_shape
707+
expect_shape = out_shape
708+
else:
709+
expect_shape = ref_arr.shape
710+
if out_chunks != "keep":
711+
kwargs["chunks"] = out_chunks
712+
expect_chunks = out_chunks
713+
else:
714+
expect_chunks = ref_arr.chunks
715+
if out_dtype != "keep":
716+
kwargs["dtype"] = out_dtype
717+
expect_dtype = out_dtype
718+
else:
719+
expect_dtype = ref_arr.dtype
720+
721+
new_arr = meth(name="foo", data=ref_arr, **kwargs)
722+
assert new_arr.shape == expect_shape
723+
assert new_arr.chunks == expect_chunks
724+
assert new_arr.dtype == expect_dtype
725+
assert np.all(new_arr[:] == expect_fill)
726+
727+
671728
def test_group_array_creation(
672729
store: Store,
673730
zarr_format: ZarrFormat,

0 commit comments

Comments
 (0)