Commit 24ef221

TomAugspurger and d-v-b authored
Update and document GPU buffer handling (#2751)
* Update GPU handling

  This updates how we handle GPU buffers. See the new docs page for a simple example.

  The basic idea, as discussed in ..., is to use host buffers for all metadata objects and device buffers for data.

  Zarr has two types of buffers: plain buffers (used for a stream of bytes) and ndbuffers (used for bytes that represent ndarrays). To make it easier for users, I've added a new config option `zarr.config.enable_gpu()` that can be used to update those both. If we need additional customizations in the future, we can add them here.

* fixed doc
* Fixup
* changelog
* doctest, skip
* removed not gpu
* assert that the type matches
* Added changelog notes

---------

Co-authored-by: Davis Bennett <[email protected]>
1 parent 47003d7 commit 24ef221

12 files changed

+130
-9
lines changed

changes/2751.bugfix.rst

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Fixed bug with Zarr using device memory, instead of host memory, for storing metadata when using GPUs.

changes/2751.doc.rst

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Added new user guide on :ref:`user-guide-gpu`.

changes/2751.feature.rst

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Added :meth:`zarr.config.enable_gpu` to update Zarr's configuration to use GPUs.

docs/developers/contributing.rst

Lines changed: 21 additions & 0 deletions
@@ -230,6 +230,27 @@ during development at `http://0.0.0.0:8000/ <http://0.0.0.0:8000/>`_. This can b

     $ hatch --env docs run serve

+.. _changelog:
+
+Changelog
+~~~~~~~~~
+
+zarr-python uses `towncrier`_ to manage release notes. Most pull requests should
+include at least one news fragment describing the changes. To add a release
+note, you'll need the GitHub issue or pull request number and the type of your
+change (``feature``, ``bugfix``, ``doc``, ``removal``, ``misc``). With that, run
+``towncrier create`` with your development environment, which will prompt you
+for the issue number, change type, and the news text::
+
+    towncrier create
+
+Alternatively, you can manually create the files in the ``changes`` directory
+using the naming convention ``{issue-number}.{change-type}.rst``.
+
+See the `towncrier`_ docs for more.
+
+.. _towncrier: https://towncrier.readthedocs.io/en/stable/tutorial.html
+
 Development best practices, policies and procedures
 ---------------------------------------------------

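
The naming convention described in the changelog section above can be sketched in a few lines of Python. This is only an illustration: `fragment_path` and `CHANGE_TYPES` are hypothetical names, not part of towncrier or zarr-python, and the helper builds a path without creating any file.

```python
from pathlib import Path

# Change types listed in the contributing guide.
CHANGE_TYPES = ("feature", "bugfix", "doc", "removal", "misc")


def fragment_path(issue: int, change_type: str, changes_dir: str = "changes") -> Path:
    """Build the news-fragment path for an issue or PR number and a change type."""
    if change_type not in CHANGE_TYPES:
        raise ValueError(f"unknown change type: {change_type!r}")
    return Path(changes_dir) / f"{issue}.{change_type}.rst"


# The three fragments added by this commit follow the same pattern.
for kind in ("bugfix", "doc", "feature"):
    print(fragment_path(2751, kind).as_posix())
```

Running `towncrier create` remains the documented route; the sketch only shows how the file names in this commit (`changes/2751.bugfix.rst` and its siblings) are derived.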

docs/user-guide/config.rst

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ Configuration options include the following:
 - Whether empty chunks are written to storage ``array.write_empty_chunks``
 - Async and threading options, e.g. ``async.concurrency`` and ``threading.max_workers``
 - Selections of implementations of codecs, codec pipelines and buffers
+- Enabling GPU support with ``zarr.config.enable_gpu()``. See :ref:`user-guide-gpu` for more.

 For selecting custom implementations of codecs, pipelines, buffers and ndbuffers,
 first register the implementations in the registry and then select them in the config.
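
The "register first, then select by name in the config" flow mentioned in the context lines above can be pictured with a toy registry. Everything here is a hypothetical stand-in written for illustration (`_registry`, `fully_qualified_name`, `get_buffer_class`, and the local `register_buffer`); it is not zarr's registry code, and the real lookup may behave differently.

```python
# Toy registry mapping a fully qualified class name to the class itself.
_registry: dict = {}


def fully_qualified_name(cls: type) -> str:
    """Return 'module.QualName' for a class."""
    return f"{cls.__module__}.{cls.__qualname__}"


def register_buffer(cls: type) -> None:
    """Make cls selectable by its fully qualified name."""
    _registry[fully_qualified_name(cls)] = cls


def get_buffer_class(config: dict) -> type:
    """Resolve the class named by config['buffer'], failing if it was never registered."""
    name = config["buffer"]
    try:
        return _registry[name]
    except KeyError:
        raise LookupError(f"buffer implementation {name!r} is not registered") from None


class MyBuffer:
    """Placeholder for a custom buffer implementation."""


register_buffer(MyBuffer)
selected = get_buffer_class({"buffer": fully_qualified_name(MyBuffer)})
print(selected is MyBuffer)
```

Under this reading, the config stores only a string, which is why the GPU module in this commit registers its `Buffer` and `NDBuffer` classes so that the names set by `enable_gpu()` can be resolved.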

docs/user-guide/gpu.rst

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+.. _user-guide-gpu:
+
+Using GPUs with Zarr
+====================
+
+Zarr can use GPUs to accelerate your workload by running
+:meth:`zarr.config.enable_gpu`.
+
+.. note::
+
+   `zarr-python` currently supports reading the ndarray data into device (GPU)
+   memory as the final stage of the codec pipeline. Data will still be read into
+   or copied to host (CPU) memory for encoding and decoding.
+
+   In the future, codecs will be available for compressing and decompressing data on
+   the GPU, avoiding the need to move data between the host and device for
+   compression and decompression.
+
+Reading data into device memory
+-------------------------------
+
+:meth:`zarr.config.enable_gpu` configures Zarr to use GPU memory for the data
+buffers used internally by Zarr.
+
+.. code-block:: python
+
+   >>> import zarr
+   >>> import cupy as cp  # doctest: +SKIP
+   >>> zarr.config.enable_gpu()  # doctest: +SKIP
+   >>> store = zarr.storage.MemoryStore()  # doctest: +SKIP
+   >>> z = zarr.create_array(  # doctest: +SKIP
+   ...     store=store, shape=(100, 100), chunks=(10, 10), dtype="float32",
+   ... )
+   >>> type(z[:10, :10])  # doctest: +SKIP
+   cupy.ndarray
+
+Note that the output type is a ``cupy.ndarray`` rather than a NumPy array.

docs/user-guide/index.rst

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ Advanced Topics
    performance
    consolidated_metadata
    extending
+   gpu


 .. Coming soon

src/zarr/core/array.py

Lines changed: 9 additions & 7 deletions
@@ -38,6 +38,7 @@
     NDBuffer,
     default_buffer_prototype,
 )
+from zarr.core.buffer.cpu import buffer_prototype as cpu_buffer_prototype
 from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks
 from zarr.core.chunk_key_encodings import (
     ChunkKeyEncoding,
@@ -163,19 +164,20 @@ async def get_array_metadata(
 ) -> dict[str, JSON]:
     if zarr_format == 2:
         zarray_bytes, zattrs_bytes = await gather(
-            (store_path / ZARRAY_JSON).get(), (store_path / ZATTRS_JSON).get()
+            (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
+            (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
         )
         if zarray_bytes is None:
             raise FileNotFoundError(store_path)
     elif zarr_format == 3:
-        zarr_json_bytes = await (store_path / ZARR_JSON).get()
+        zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
         if zarr_json_bytes is None:
             raise FileNotFoundError(store_path)
     elif zarr_format is None:
         zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
-            (store_path / ZARR_JSON).get(),
-            (store_path / ZARRAY_JSON).get(),
-            (store_path / ZATTRS_JSON).get(),
+            (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
+            (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
+            (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
         )
         if zarr_json_bytes is not None and zarray_bytes is not None:
             # warn and favor v3
@@ -1348,7 +1350,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
         """
         Asynchronously save the array metadata.
         """
-        to_save = metadata.to_buffer_dict(default_buffer_prototype())
+        to_save = metadata.to_buffer_dict(cpu_buffer_prototype)
         awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]

         if ensure_parents:
@@ -1360,7 +1362,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
                     [
                         (parent.store_path / key).set_if_not_exists(value)
                         for key, value in parent.metadata.to_buffer_dict(
-                            default_buffer_prototype()
+                            cpu_buffer_prototype
                         ).items()
                     ]
                 )
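
The pattern in the hunks above (metadata always goes through the host prototype, regardless of the configured default) can be illustrated with a toy model. The classes and the `get` helper below are hypothetical stand-ins, not zarr's `Buffer` types or store API; `DeviceBuffer` only pretends that its bytes live on a GPU and counts the copies a real device buffer would need.

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class HostBuffer:
    """Stand-in for a CPU buffer: bytes that Python can read directly."""
    data: bytes

    def to_bytes(self) -> bytes:
        return self.data


@dataclass
class DeviceBuffer:
    """Stand-in for a GPU buffer; reading it models a device-to-host copy."""
    data: bytes
    copies: int = 0

    def to_bytes(self) -> bytes:
        self.copies += 1  # a real device buffer would transfer memory here
        return self.data


def get(store: dict, key: str, prototype: type) -> Optional[object]:
    """Fetch key from store, wrapping the bytes in the requested buffer type."""
    raw = store.get(key)
    return None if raw is None else prototype(raw)


store = {"zarr.json": b'{"zarr_format": 3, "node_type": "array"}'}

# Metadata is requested with the host prototype even when the default
# buffer for chunk data is a device buffer, so json.loads reads host bytes.
buf = get(store, "zarr.json", prototype=HostBuffer)
metadata = json.loads(buf.to_bytes())
print(metadata["node_type"])
```

In the toy model, fetching the same key with `prototype=DeviceBuffer` would still work but would incur a copy on every read, which is the cost the commit avoids for small JSON documents.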

src/zarr/core/buffer/gpu.py

Lines changed: 7 additions & 0 deletions
@@ -13,6 +13,10 @@

 from zarr.core.buffer import core
 from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayLike
+from zarr.registry import (
+    register_buffer,
+    register_ndbuffer,
+)

 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -215,3 +219,6 @@ def __setitem__(self, key: Any, value: Any) -> None:


 buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)
+
+register_buffer(Buffer)
+register_ndbuffer(NDBuffer)

src/zarr/core/config.py

Lines changed: 12 additions & 1 deletion
@@ -29,10 +29,13 @@

 from __future__ import annotations

-from typing import Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Literal, cast

 from donfig import Config as DConfig

+if TYPE_CHECKING:
+    from donfig.config_obj import ConfigSet
+

 class BadConfigError(ValueError):
     _msg = "bad Config: %r"
@@ -56,6 +59,14 @@ def reset(self) -> None:
         self.clear()
         self.refresh()

+    def enable_gpu(self) -> ConfigSet:
+        """
+        Configure Zarr to use GPUs where possible.
+        """
+        return self.set(
+            {"buffer": "zarr.core.buffer.gpu.Buffer", "ndbuffer": "zarr.core.buffer.gpu.NDBuffer"}
+        )
+

 # The default configuration for zarr
 config = Config(
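
Because `enable_gpu` returns whatever `self.set` returns, and donfig's `ConfigSet` can, as far as I know, be used as a context manager, the GPU setting could presumably be scoped to a `with` block. The sketch below models that set-and-restore behaviour with the standard library only. `_config`, `override`, and this local `enable_gpu` are hypothetical stand-ins, the CPU class paths are assumed defaults, and unlike donfig the toy applies the change on entering the block rather than at call time.

```python
from contextlib import contextmanager

# Stand-in for the real configuration; only the two keys touched by enable_gpu.
# The CPU paths are assumed defaults, not read from zarr.
_config = {
    "buffer": "zarr.core.buffer.cpu.Buffer",
    "ndbuffer": "zarr.core.buffer.cpu.NDBuffer",
}


@contextmanager
def override(new_values: dict):
    """Apply new_values, then restore the previous values on exit."""
    previous = {key: _config[key] for key in new_values}
    _config.update(new_values)
    try:
        yield _config
    finally:
        _config.update(previous)


def enable_gpu():
    """Hypothetical analogue of Config.enable_gpu, built on the stand-in above."""
    return override(
        {"buffer": "zarr.core.buffer.gpu.Buffer", "ndbuffer": "zarr.core.buffer.gpu.NDBuffer"}
    )


with enable_gpu():
    inside = dict(_config)  # GPU classes are selected inside the block
after = dict(_config)       # previous classes are restored afterwards
print(inside["buffer"], after["buffer"])
```

Returning the setter object rather than `None` is what makes both styles possible: a bare call for a process-wide switch, or a `with` block for a temporary one.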
