Skip to content

Commit 22741dd

Browse files
committed
Set GPU config in test
1 parent f90f606 commit 22741dd

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

tests/test_buffer.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -151,25 +151,29 @@ async def test_codecs_use_of_gpu_prototype() -> None:
151151
@gpu_test
152152
@pytest.mark.asyncio
153153
async def test_sharding_use_of_gpu_prototype() -> None:
154-
expect = cp.zeros((10, 10), dtype="uint16", order="F")
155-
a = await zarr.api.asynchronous.create_array(
156-
StorePath(MemoryStore()) / "test_codecs_use_of_gpu_prototype",
157-
shape=expect.shape,
158-
chunks=(5, 5),
159-
shards=(10, 10),
160-
dtype=expect.dtype,
161-
fill_value=0,
162-
)
163-
expect[:] = cp.arange(100).reshape(10, 10)
164-
165-
await a.setitem(
166-
selection=(slice(0, 10), slice(0, 10)),
167-
value=expect[:],
168-
prototype=gpu.buffer_prototype,
169-
)
170-
got = await a.getitem(selection=(slice(0, 10), slice(0, 10)), prototype=gpu.buffer_prototype)
171-
assert isinstance(got, cp.ndarray)
172-
assert cp.array_equal(expect, got)
154+
with zarr.config.enable_gpu():
155+
expect = cp.zeros((10, 10), dtype="uint16", order="F")
156+
157+
a = await zarr.api.asynchronous.create_array(
158+
StorePath(MemoryStore()) / "test_codecs_use_of_gpu_prototype",
159+
shape=expect.shape,
160+
chunks=(5, 5),
161+
shards=(10, 10),
162+
dtype=expect.dtype,
163+
fill_value=0,
164+
)
165+
expect[:] = cp.arange(100).reshape(10, 10)
166+
167+
await a.setitem(
168+
selection=(slice(0, 10), slice(0, 10)),
169+
value=expect[:],
170+
prototype=gpu.buffer_prototype,
171+
)
172+
got = await a.getitem(
173+
selection=(slice(0, 10), slice(0, 10)), prototype=gpu.buffer_prototype
174+
)
175+
assert isinstance(got, cp.ndarray)
176+
assert cp.array_equal(expect, got)
173177

174178

175179
def test_numpy_buffer_prototype() -> None:

0 commit comments

Comments
 (0)