Skip to content

Commit d6608e7

Browse files
author
asubramaniam
committed
First working version of Zstd codec on the GPU
1 parent e8bfb64 commit d6608e7

File tree

4 files changed

+219
-6
lines changed

4 files changed

+219
-6
lines changed

src/zarr/codecs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle
44
from zarr.codecs.bytes import BytesCodec, Endian
55
from zarr.codecs.crc32c_ import Crc32cCodec
6+
from zarr.codecs.gpu import NvcompZstdCodec
67
from zarr.codecs.gzip import GzipCodec
78
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
89
from zarr.codecs.transpose import TransposeCodec
@@ -17,6 +18,7 @@
1718
"Crc32cCodec",
1819
"Endian",
1920
"GzipCodec",
21+
"NvcompZstdCodec",
2022
"ShardingCodec",
2123
"ShardingCodecIndexLocation",
2224
"TransposeCodec",

src/zarr/codecs/gpu.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import Awaitable
5+
from dataclasses import dataclass
6+
from functools import cached_property
7+
from typing import TYPE_CHECKING
8+
9+
import numpy as np
10+
11+
from zarr.abc.codec import BytesBytesCodec
12+
from zarr.core.common import JSON, parse_named_configuration
13+
from zarr.registry import register_codec
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Generator, Iterable
17+
from typing import Any, Self
18+
19+
from zarr.core.array_spec import ArraySpec
20+
from zarr.core.buffer import Buffer
21+
22+
try:
23+
import cupy as cp
24+
except ImportError:
25+
cp = None
26+
27+
try:
28+
from nvidia import nvcomp
29+
except ImportError:
30+
nvcomp = None
31+
32+
33+
class AsyncCUDAEvent(Awaitable[None]):
34+
"""An awaitable wrapper around a CuPy CUDA event for asynchronous waiting."""
35+
36+
def __init__(
37+
self, event: cp.cuda.Event, initial_delay: float = 0.001, max_delay: float = 0.1
38+
) -> None:
39+
"""
40+
Initialize the async CUDA event.
41+
42+
Args:
43+
event (cp.cuda.Event): The CuPy CUDA event to wait on.
44+
initial_delay (float): Initial polling delay in seconds (default: 0.001s).
45+
max_delay (float): Maximum polling delay in seconds (default: 0.1s).
46+
"""
47+
self.event = event
48+
self.initial_delay = initial_delay
49+
self.max_delay = max_delay
50+
51+
def __await__(self) -> Generator[Any, None, None]:
52+
"""Makes the event awaitable by yielding control until the event is complete."""
53+
return self._wait().__await__()
54+
55+
async def _wait(self) -> None:
56+
"""Polls the CUDA event asynchronously with exponential backoff until it completes."""
57+
delay = self.initial_delay
58+
while not self.event.query(): # `query()` returns True if the event is complete
59+
await asyncio.sleep(delay) # Yield control to other async tasks
60+
delay = min(delay * 2, self.max_delay) # Exponential backoff
61+
62+
63+
def parse_zstd_level(data: JSON) -> int:
64+
if isinstance(data, int):
65+
if data >= 23:
66+
raise ValueError(f"Value must be less than or equal to 22. Got {data} instead.")
67+
return data
68+
raise TypeError(f"Got value with type {type(data)}, but expected an int.")
69+
70+
71+
def parse_checksum(data: JSON) -> bool:
72+
if isinstance(data, bool):
73+
return data
74+
raise TypeError(f"Expected bool. Got {type(data)}.")
75+
76+
77+
@dataclass(frozen=True)
78+
class NvcompZstdCodec(BytesBytesCodec):
79+
is_fixed_size = True
80+
81+
level: int = 0
82+
checksum: bool = False
83+
84+
def __init__(self, *, level: int = 0, checksum: bool = False) -> None:
85+
# TODO: Set CUDA device appropriately here and also set CUDA stream
86+
87+
level_parsed = parse_zstd_level(level)
88+
checksum_parsed = parse_checksum(checksum)
89+
90+
object.__setattr__(self, "level", level_parsed)
91+
object.__setattr__(self, "checksum", checksum_parsed)
92+
93+
@classmethod
94+
def from_dict(cls, data: dict[str, JSON]) -> Self:
95+
_, configuration_parsed = parse_named_configuration(data, "zstd")
96+
return cls(**configuration_parsed) # type: ignore[arg-type]
97+
98+
def to_dict(self) -> dict[str, JSON]:
99+
return {
100+
"name": "zstd",
101+
"configuration": {"level": self.level, "checksum": self.checksum},
102+
}
103+
104+
@cached_property
105+
def _zstd_codec(self) -> nvcomp.Codec:
106+
# config_dict = {algorithm = "Zstd", "level": self.level, "checksum": self.checksum}
107+
# return Zstd.from_config(config_dict)
108+
device = cp.cuda.Device() # Select the current default device
109+
stream = cp.cuda.get_current_stream() # Use the current default stream
110+
return nvcomp.Codec(
111+
algorithm="Zstd",
112+
bitstream_kind=nvcomp.BitstreamKind.RAW,
113+
device_id=device.id,
114+
cuda_stream=stream.ptr,
115+
)
116+
117+
async def _convert_to_nvcomp_arrays(
118+
self,
119+
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
120+
) -> tuple[list[nvcomp.Array], list[int]]:
121+
none_indices = [i for i, (b, _) in enumerate(chunks_and_specs) if b is None]
122+
filtered_inputs = [b.as_array_like() for b, _ in chunks_and_specs if b is not None]
123+
# TODO: add CUDA stream here
124+
return nvcomp.as_arrays(filtered_inputs), none_indices
125+
126+
async def _convert_from_nvcomp_arrays(
127+
self,
128+
arrays: Iterable[nvcomp.Array],
129+
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
130+
) -> Iterable[Buffer | None]:
131+
return [
132+
spec.prototype.buffer.from_array_like(cp.asarray(a, dtype=np.dtype("b"))) if a else None
133+
for a, (_, spec) in zip(arrays, chunks_and_specs, strict=True)
134+
]
135+
136+
async def decode(
137+
self,
138+
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
139+
) -> Iterable[Buffer | None]:
140+
"""Decodes a batch of chunks.
141+
Chunks can be None in which case they are ignored by the codec.
142+
143+
Parameters
144+
----------
145+
chunks_and_specs : Iterable[tuple[Buffer | None, ArraySpec]]
146+
Ordered set of encoded chunks with their accompanying chunk spec.
147+
148+
Returns
149+
-------
150+
Iterable[Buffer | None]
151+
"""
152+
chunks_and_specs = list(chunks_and_specs)
153+
154+
# Convert to nvcomp arrays
155+
filtered_inputs, none_indices = await self._convert_to_nvcomp_arrays(chunks_and_specs)
156+
157+
outputs = self._zstd_codec.decode(filtered_inputs) if len(filtered_inputs) > 0 else []
158+
for index in none_indices:
159+
outputs.insert(index, None)
160+
161+
return await self._convert_from_nvcomp_arrays(outputs, chunks_and_specs)
162+
163+
async def encode(
164+
self,
165+
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
166+
) -> Iterable[Buffer | None]:
167+
"""Encodes a batch of chunks.
168+
Chunks can be None in which case they are ignored by the codec.
169+
170+
Parameters
171+
----------
172+
chunks_and_specs : Iterable[tuple[Buffer | None, ArraySpec]]
173+
Ordered set of to-be-encoded chunks with their accompanying chunk spec.
174+
175+
Returns
176+
-------
177+
Iterable[Buffer | None]
178+
"""
179+
# TODO: Make this actually async
180+
chunks_and_specs = list(chunks_and_specs)
181+
182+
# Convert to nvcomp arrays
183+
filtered_inputs, none_indices = await self._convert_to_nvcomp_arrays(chunks_and_specs)
184+
185+
outputs = self._zstd_codec.encode(filtered_inputs) if len(filtered_inputs) > 0 else []
186+
for index in none_indices:
187+
outputs.insert(index, None)
188+
189+
return await self._convert_from_nvcomp_arrays(outputs, chunks_and_specs)
190+
191+
def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
192+
raise NotImplementedError
193+
194+
195+
register_codec("zstd", NvcompZstdCodec)

src/zarr/core/config.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ def enable_gpu(self) -> ConfigSet:
6464
Configure Zarr to use GPUs where possible.
6565
"""
6666
return self.set(
67-
{"buffer": "zarr.core.buffer.gpu.Buffer", "ndbuffer": "zarr.core.buffer.gpu.NDBuffer"}
67+
{
68+
"buffer": "zarr.core.buffer.gpu.Buffer",
69+
"ndbuffer": "zarr.core.buffer.gpu.NDBuffer",
70+
"codecs": {"zstd": "zarr.codecs.gpu.NvcompZstdCodec"},
71+
}
6872
)
6973

7074

@@ -96,13 +100,22 @@ def enable_gpu(self) -> ConfigSet:
96100
},
97101
"v3_default_compressors": {
98102
"numeric": [
99-
{"name": "zstd", "configuration": {"level": 0, "checksum": False}},
103+
{
104+
"name": "zstd",
105+
"configuration": {"level": 0, "checksum": False},
106+
},
100107
],
101108
"string": [
102-
{"name": "zstd", "configuration": {"level": 0, "checksum": False}},
109+
{
110+
"name": "zstd",
111+
"configuration": {"level": 0, "checksum": False},
112+
},
103113
],
104114
"bytes": [
105-
{"name": "zstd", "configuration": {"level": 0, "checksum": False}},
115+
{
116+
"name": "zstd",
117+
"configuration": {"level": 0, "checksum": False},
118+
},
106119
],
107120
},
108121
},

tests/test_api.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import zarr.api.asynchronous
1111
import zarr.core.group
1212
from zarr import Array, Group
13+
from zarr.abc.codec import Codec
1314
from zarr.abc.store import Store
1415
from zarr.api.synchronous import (
1516
create,
@@ -23,6 +24,7 @@
2324
save_array,
2425
save_group,
2526
)
27+
from zarr.codecs import NvcompZstdCodec
2628
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
2729
from zarr.errors import MetadataValidationError
2830
from zarr.storage import MemoryStore
@@ -1131,15 +1133,16 @@ def test_open_array_with_mode_r_plus(store: Store) -> None:
11311133
indirect=True,
11321134
)
11331135
@pytest.mark.parametrize("zarr_format", [None, 2, 3])
1134-
def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None) -> None:
1136+
@pytest.mark.parametrize("codec", ["auto", NvcompZstdCodec()])
1137+
def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None, codec: str | Codec) -> None:
11351138
import cupy as cp
11361139

11371140
if zarr_format == 2:
11381141
# Without this, the zstd codec attempts to convert the cupy
11391142
# array to bytes.
11401143
compressors = None
11411144
else:
1142-
compressors = "auto"
1145+
compressors = codec
11431146

11441147
with zarr.config.enable_gpu():
11451148
src = cp.random.uniform(size=(100, 100)) # allocate on the device

0 commit comments

Comments
 (0)