Skip to content

Commit 94e460e

Browse files
committed
Update GPU handling
This updates how we handle GPU buffers. See the new docs page for a simple example. The basic idea, as discussed in ..., is to use host buffers for all metadata objects and device buffers for data. Zarr has two types of buffers: plain buffers (used for a stream of bytes) and ndbuffers (used for bytes that represent ndarrays). To make it easier for users, I've added a new config option `zarr.config.enable_gpu()` that can be used to update those both. If we need additional customizations in the future, we can add them here.
1 parent 0c154c3 commit 94e460e

File tree

7 files changed

+91
-9
lines changed

7 files changed

+91
-9
lines changed

docs/user-guide/gpu.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
.. _user-guide-gpu:
2+
3+
Using GPUs with Zarr
4+
====================
5+
6+
Zarr can be used along with GPUs to accelerate your workload. Currently,
7+
Zarr supports reading data into GPU memory. In the future, Zarr will
8+
support GPU-accelerated codecs and file IO.
9+
10+
Reading data into device memory
11+
-------------------------------
12+
13+
.. code-block:: python
14+
15+
>>> import zarr
16+
>>> import cupy as cp
17+
>>> zarr.config.enable_cuda()
18+
>>> store = zarr.storage.MemoryStore()
19+
>>> type(z[:10, :10])
20+
cupy.ndarray
21+
22+
:meth:`zarr.config.enable_cuda` updates the Zarr configuration to use device
23+
memory for all data buffers used by Zarr. This means that any reads from a Zarr
24+
store will return a CuPy ndarray rather than a NumPy ndarray. Any buffers used
25+
for metadata will be on the host.

docs/user-guide/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Advanced Topics
2323
performance
2424
consolidated_metadata
2525
extending
26+
gpu
2627

2728

2829
.. Coming soon

src/zarr/core/array.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
NDBuffer,
3939
default_buffer_prototype,
4040
)
41+
from zarr.core.buffer.cpu import buffer_prototype as cpu_buffer_prototype
4142
from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks
4243
from zarr.core.chunk_key_encodings import (
4344
ChunkKeyEncoding,
@@ -163,19 +164,20 @@ async def get_array_metadata(
163164
) -> dict[str, JSON]:
164165
if zarr_format == 2:
165166
zarray_bytes, zattrs_bytes = await gather(
166-
(store_path / ZARRAY_JSON).get(), (store_path / ZATTRS_JSON).get()
167+
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
168+
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
167169
)
168170
if zarray_bytes is None:
169171
raise FileNotFoundError(store_path)
170172
elif zarr_format == 3:
171-
zarr_json_bytes = await (store_path / ZARR_JSON).get()
173+
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
172174
if zarr_json_bytes is None:
173175
raise FileNotFoundError(store_path)
174176
elif zarr_format is None:
175177
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
176-
(store_path / ZARR_JSON).get(),
177-
(store_path / ZARRAY_JSON).get(),
178-
(store_path / ZATTRS_JSON).get(),
178+
(store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
179+
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
180+
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
179181
)
180182
if zarr_json_bytes is not None and zarray_bytes is not None:
181183
# warn and favor v3
@@ -1295,7 +1297,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
12951297
"""
12961298
Asynchronously save the array metadata.
12971299
"""
1298-
to_save = metadata.to_buffer_dict(default_buffer_prototype())
1300+
to_save = metadata.to_buffer_dict(cpu_buffer_prototype)
12991301
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
13001302

13011303
if ensure_parents:
@@ -1307,7 +1309,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
13071309
[
13081310
(parent.store_path / key).set_if_not_exists(value)
13091311
for key, value in parent.metadata.to_buffer_dict(
1310-
default_buffer_prototype()
1312+
cpu_buffer_prototype
13111313
).items()
13121314
]
13131315
)

src/zarr/core/buffer/gpu.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
from zarr.core.buffer import core
1515
from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayLike
16+
from zarr.registry import (
17+
register_buffer,
18+
register_ndbuffer,
19+
)
1620

1721
if TYPE_CHECKING:
1822
from collections.abc import Iterable
@@ -215,3 +219,6 @@ def __setitem__(self, key: Any, value: Any) -> None:
215219

216220

217221
buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)
222+
223+
register_buffer(Buffer)
224+
register_ndbuffer(NDBuffer)

src/zarr/core/config.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@
2929

3030
from __future__ import annotations
3131

32-
from typing import Any, Literal, cast
32+
from typing import TYPE_CHECKING, Any, Literal, cast
3333

3434
from donfig import Config as DConfig
3535

36+
if TYPE_CHECKING:
37+
from donfig.config_obj import ConfigSet
38+
3639

3740
class BadConfigError(ValueError):
3841
_msg = "bad Config: %r"
@@ -56,6 +59,14 @@ def reset(self) -> None:
5659
self.clear()
5760
self.refresh()
5861

62+
def enable_gpu(self) -> ConfigSet:
63+
"""
64+
Configure Zarr to use GPUs where possible.
65+
"""
66+
return self.set(
67+
{"buffer": "zarr.core.buffer.gpu.Buffer", "ndbuffer": "zarr.core.buffer.gpu.NDBuffer"}
68+
)
69+
5970

6071
# The default configuration for zarr
6172
config = Config(

src/zarr/testing/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def has_cupy() -> bool:
3838
return False
3939

4040

41-
T_Callable = TypeVar("T_Callable", bound=Callable[[], Coroutine[Any, Any, None]])
41+
T_Callable = TypeVar("T_Callable", bound=Callable[..., Coroutine[Any, Any, None] | None])
4242

4343

4444
# Decorator for GPU tests

tests/test_api.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from zarr.errors import MetadataValidationError
2828
from zarr.storage import MemoryStore
2929
from zarr.storage._utils import normalize_path
30+
from zarr.testing.utils import gpu_test
3031

3132

3233
def test_create(memory_store: Store) -> None:
@@ -1121,3 +1122,38 @@ def test_open_array_with_mode_r_plus(store: Store) -> None:
11211122
assert isinstance(z2, Array)
11221123
assert (z2[:] == 1).all()
11231124
z2[:] = 3
1125+
1126+
1127+
@gpu_test
1128+
@pytest.mark.parametrize(
1129+
"store",
1130+
["local", "memory", "zip"],
1131+
indirect=True,
1132+
)
1133+
@pytest.mark.parametrize("zarr_format", [None, 2, 3])
1134+
def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None) -> None:
1135+
import cupy as cp
1136+
1137+
if zarr_format == 2:
1138+
# Without this, the zstd codec attempts to convert the cupy
1139+
# array to bytes.
1140+
compressors = None
1141+
else:
1142+
compressors = "auto"
1143+
1144+
with zarr.config.enable_gpu():
1145+
src = cp.random.uniform(size=(100, 100)) # allocate on the device
1146+
z = zarr.create_array(
1147+
store,
1148+
name="a",
1149+
shape=src.shape,
1150+
chunks=(10, 10),
1151+
dtype=src.dtype,
1152+
overwrite=True,
1153+
zarr_format=zarr_format,
1154+
compressors=compressors,
1155+
)
1156+
z[:10, :10] = src[:10, :10]
1157+
1158+
result = z[:10, :10]
1159+
cp.testing.assert_array_equal(result, src[:10, :10])

0 commit comments

Comments
 (0)