Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec:
config=ArrayConfig(
order="C", write_empty_chunks=False
), # Note: this is hard-coded for simplicity -- it is not surfaced into user code,
prototype=numpy_buffer_prototype(),
prototype=default_buffer_prototype(),
)

def _get_chunk_spec(self, shard_spec: ArraySpec) -> ArraySpec:
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def has_cupy() -> bool:
# Decorator for GPU tests
def gpu_test(func: T_Callable) -> T_Callable:
return cast(
T_Callable,
"T_Callable",
pytest.mark.gpu(
pytest.mark.skipif(not has_cupy(), reason="CuPy not installed or no GPU available")(
func
Expand Down
38 changes: 38 additions & 0 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,34 @@ async def test_codecs_use_of_gpu_prototype() -> None:
assert cp.array_equal(expect, got)


@gpu_test
@pytest.mark.asyncio
async def test_sharding_use_of_gpu_prototype() -> None:
with zarr.config.enable_gpu():
expect = cp.zeros((10, 10), dtype="uint16", order="F")

a = await zarr.api.asynchronous.create_array(
StorePath(MemoryStore()) / "test_codecs_use_of_gpu_prototype",
shape=expect.shape,
chunks=(5, 5),
shards=(10, 10),
dtype=expect.dtype,
fill_value=0,
)
expect[:] = cp.arange(100).reshape(10, 10)

await a.setitem(
selection=(slice(0, 10), slice(0, 10)),
value=expect[:],
prototype=gpu.buffer_prototype,
)
got = await a.getitem(
selection=(slice(0, 10), slice(0, 10)), prototype=gpu.buffer_prototype
)
assert isinstance(got, cp.ndarray)
assert cp.array_equal(expect, got)


def test_numpy_buffer_prototype() -> None:
buffer = cpu.buffer_prototype.buffer.create_zero_length()
ndbuffer = cpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=np.dtype("int64"))
Expand All @@ -157,6 +185,16 @@ def test_numpy_buffer_prototype() -> None:
ndbuffer.as_scalar()


@gpu_test
def test_gpu_buffer_prototype() -> None:
buffer = gpu.buffer_prototype.buffer.create_zero_length()
ndbuffer = gpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=cp.dtype("int64"))
assert isinstance(buffer.as_array_like(), cp.ndarray)
assert isinstance(ndbuffer.as_ndarray_like(), cp.ndarray)
with pytest.raises(ValueError, match="Buffer does not contain a single scalar value"):
ndbuffer.as_scalar()


# TODO: the same test for other buffer classes
def test_cpu_buffer_as_scalar() -> None:
buf = cpu.buffer_prototype.nd_buffer.create(shape=(), dtype="int64")
Expand Down