Skip to content

Commit f90f606

Browse files
committed
Add more unit tests for GPU buffer
1 parent 018f61d commit f90f606

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

src/zarr/testing/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def has_cupy() -> bool:
4444
# Decorator for GPU tests
4545
def gpu_test(func: T_Callable) -> T_Callable:
4646
return cast(
47-
T_Callable,
47+
"T_Callable",
4848
pytest.mark.gpu(
4949
pytest.mark.skipif(not has_cupy(), reason="CuPy not installed or no GPU available")(
5050
func

tests/test_buffer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,44 @@ async def test_codecs_use_of_gpu_prototype() -> None:
148148
assert cp.array_equal(expect, got)
149149

150150

151+
@gpu_test
152+
@pytest.mark.asyncio
153+
async def test_sharding_use_of_gpu_prototype() -> None:
154+
expect = cp.zeros((10, 10), dtype="uint16", order="F")
155+
a = await zarr.api.asynchronous.create_array(
156+
StorePath(MemoryStore()) / "test_codecs_use_of_gpu_prototype",
157+
shape=expect.shape,
158+
chunks=(5, 5),
159+
shards=(10, 10),
160+
dtype=expect.dtype,
161+
fill_value=0,
162+
)
163+
expect[:] = cp.arange(100).reshape(10, 10)
164+
165+
await a.setitem(
166+
selection=(slice(0, 10), slice(0, 10)),
167+
value=expect[:],
168+
prototype=gpu.buffer_prototype,
169+
)
170+
got = await a.getitem(selection=(slice(0, 10), slice(0, 10)), prototype=gpu.buffer_prototype)
171+
assert isinstance(got, cp.ndarray)
172+
assert cp.array_equal(expect, got)
173+
174+
151175
def test_numpy_buffer_prototype() -> None:
152176
buffer = cpu.buffer_prototype.buffer.create_zero_length()
153177
ndbuffer = cpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=np.dtype("int64"))
154178
assert isinstance(buffer.as_array_like(), np.ndarray)
155179
assert isinstance(ndbuffer.as_ndarray_like(), np.ndarray)
156180
with pytest.raises(ValueError, match="Buffer does not contain a single scalar value"):
157181
ndbuffer.as_scalar()
182+
183+
184+
@gpu_test
185+
def test_gpu_buffer_prototype() -> None:
186+
buffer = gpu.buffer_prototype.buffer.create_zero_length()
187+
ndbuffer = gpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=cp.dtype("int64"))
188+
assert isinstance(buffer.as_array_like(), cp.ndarray)
189+
assert isinstance(ndbuffer.as_ndarray_like(), cp.ndarray)
190+
with pytest.raises(ValueError, match="Buffer does not contain a single scalar value"):
191+
ndbuffer.as_scalar()

0 commit comments

Comments
 (0)