Skip to content

Commit d74e7a4

Browse files
committed
new dtypes
1 parent 3c50f54 commit d74e7a4

File tree

11 files changed

+703
-31
lines changed

11 files changed

+703
-31
lines changed

src/zarr/core/_info.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec
99
from zarr.core.common import ZarrFormat
10-
from zarr.core.metadata.v3 import DataType
10+
from zarr.core.metadata.dtype import BaseDataType
11+
12+
# from zarr.core.metadata.v3 import DataType
1113

1214

1315
@dataclasses.dataclass(kw_only=True)
@@ -78,7 +80,7 @@ class ArrayInfo:
7880

7981
_type: Literal["Array"] = "Array"
8082
_zarr_format: ZarrFormat
81-
_data_type: np.dtype[Any] | DataType
83+
_data_type: np.dtype[Any] | BaseDataType
8284
_shape: tuple[int, ...]
8385
_shard_shape: tuple[int, ...] | None = None
8486
_chunk_shape: tuple[int, ...] | None = None

src/zarr/core/array.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,21 @@
9898
ArrayV3MetadataDict,
9999
T_ArrayMetadata,
100100
)
101+
from zarr.core.metadata.dtype import BaseDataType
101102
from zarr.core.metadata.v2 import (
102103
_default_compressor,
103104
_default_filters,
104105
parse_compressor,
105106
parse_filters,
106107
)
107-
from zarr.core.metadata.v3 import DataType, parse_node_type_array
108+
from zarr.core.metadata.v3 import parse_node_type_array
108109
from zarr.core.sync import sync
109110
from zarr.errors import MetadataValidationError
110111
from zarr.registry import (
111112
_parse_array_array_codec,
112113
_parse_array_bytes_codec,
113114
_parse_bytes_bytes_codec,
115+
get_data_type_from_numpy,
114116
get_pipeline_class,
115117
)
116118
from zarr.storage._common import StorePath, ensure_no_existing_node, make_store_path
@@ -1682,7 +1684,7 @@ async def info_complete(self) -> Any:
16821684
def _info(
16831685
self, count_chunks_initialized: int | None = None, count_bytes_stored: int | None = None
16841686
) -> Any:
1685-
_data_type: np.dtype[Any] | DataType
1687+
_data_type: np.dtype[Any] | BaseDataType
16861688
if isinstance(self.metadata, ArrayV2Metadata):
16871689
_data_type = self.metadata.dtype
16881690
else:
@@ -4203,17 +4205,11 @@ def _get_default_chunk_encoding_v3(
42034205
"""
42044206
Get the default ArrayArrayCodecs, ArrayBytesCodec, and BytesBytesCodec for a given dtype.
42054207
"""
4206-
dtype = DataType.from_numpy(np_dtype)
4207-
if dtype == DataType.string:
4208-
dtype_key = "string"
4209-
elif dtype == DataType.bytes:
4210-
dtype_key = "bytes"
4211-
else:
4212-
dtype_key = "numeric"
4208+
dtype = get_data_type_from_numpy(np_dtype)
42134209

4214-
default_filters = zarr_config.get("array.v3_default_filters").get(dtype_key)
4215-
default_serializer = zarr_config.get("array.v3_default_serializer").get(dtype_key)
4216-
default_compressors = zarr_config.get("array.v3_default_compressors").get(dtype_key)
4210+
default_filters = zarr_config.get("array.v3_default_filters").get(dtype.type)
4211+
default_serializer = zarr_config.get("array.v3_default_serializer").get(dtype.type)
4212+
default_compressors = zarr_config.get("array.v3_default_compressors").get(dtype.type)
42174213

42184214
filters = tuple(_parse_array_array_codec(codec_dict) for codec_dict in default_filters)
42194215
serializer = _parse_array_bytes_codec(default_serializer)

src/zarr/core/dtype/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from zarr.core.dtype.core import ZarrDType
2+
3+
__all__ = ["ZarrDType"]

src/zarr/core/dtype/core.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""
2+
# Overview
3+
4+
This module provides a proof-of-concept standalone interface for managing dtypes in the zarr-python codebase.
5+
6+
The `ZarrDType` class introduced in this module effectively acts as a replacement for `np.dtype` throughout the
7+
zarr-python codebase. It attempts to encapsulate all relevant runtime information necessary for working with
8+
dtypes in the context of the Zarr V3 specification (e.g. is this a core dtype or not, how many bytes and what
9+
endianness is the dtype etc). By providing this abstraction, the module aims to:
10+
11+
- Simplify dtype management within zarr-python
12+
- Support runtime flexibility and custom extensions
13+
- Remove unnecessary dependencies on the numpy API
14+
15+
## Extensibility
16+
17+
The module attempts to support user-driven extensions, allowing developers to introduce custom dtypes
18+
without requiring immediate changes to zarr-python. Extensions can leverage the current entrypoint mechanism,
19+
enabling integration of experimental features. Over time, widely adopted extensions may be formalized through
20+
inclusion in zarr-python or standardized via a Zarr Enhancement Proposal (ZEP), but this is not essential.
21+
22+
## Examples
23+
24+
### Core `dtype` Registration
25+
26+
The following example demonstrates how to register a built-in `dtype` in the core codebase:
27+
28+
```python
29+
from zarr.core.dtype import ZarrDType
30+
from zarr.registry import register_v3dtype
31+
32+
class Float16(ZarrDType):
33+
zarr_spec_format = "3"
34+
experimental = False
35+
endianness = "little"
36+
byte_count = 2
37+
to_numpy = np.dtype('float16')
38+
39+
register_v3dtype(Float16)
40+
```
41+
42+
### Entrypoint Extension
43+
44+
The following example demonstrates how users can register a new `bfloat16` dtype for Zarr.
45+
This approach adheres to the existing Zarr entrypoint pattern as much as possible, ensuring
46+
consistency with other extensions. The code below would typically be part of a Python package
47+
that specifies the entrypoints for the extension:
48+
49+
```python
50+
import ml_dtypes
51+
from zarr.core.dtype import ZarrDType # User inherits from ZarrDType when creating their dtype
52+
53+
class Bfloat16(ZarrDType):
54+
zarr_spec_format = "3"
55+
experimental = True
56+
endianness = "little"
57+
byte_count = 2
58+
to_numpy = np.dtype('bfloat16') # Enabled by importing ml_dtypes
59+
configuration_v3 = {
60+
"version": "example_value",
61+
"author": "example_value",
62+
"ml_dtypes_version": "example_value"
63+
}
64+
```
65+
66+
### dtype lookup
67+
68+
The following examples demonstrate how to perform a lookup for the relevant ZarrDType, given
69+
a string that matches the dtype Zarr specification ID, or a numpy dtype object:
70+
71+
```
72+
from zarr.registry import get_v3dtype_class, get_v3dtype_class_from_numpy
73+
74+
get_v3dtype_class('complex64') # returns little-endian Complex64 ZarrDType
75+
get_v3dtype_class('not_registered_dtype') # ValueError
76+
77+
get_v3dtype_class_from_numpy('>i2') # returns big-endian Int16 ZarrDType
78+
get_v3dtype_class_from_numpy(np.dtype('float32')) # returns little-endian Float32 ZarrDType
79+
get_v3dtype_class_from_numpy('i10') # ValueError
80+
```
81+
82+
### String dtypes
83+
84+
The following indicates one possibility for supporting variable-length strings. It is via the
85+
entrypoint mechanism as in a previous example. The Apache Arrow specification does not currently
86+
include a dtype for fixed-length strings (only for fixed-length bytes) and so I am using string
87+
here to implicitly refer to a variable-length string data (there may be some subtleties with codecs
88+
that means this needs to be refined further):
89+
90+
```python
91+
import numpy as np
92+
from zarr.core.dtype import ZarrDType # User inherits from ZarrDType when creating their dtype
93+
94+
try:
95+
to_numpy = np.dtypes.StringDType()
96+
except AttributeError:
97+
to_numpy = np.dtypes.ObjectDType()
98+
99+
class String(ZarrDType):
100+
zarr_spec_format = "3"
101+
experimental = True
102+
endianness = 'little'
103+
byte_count = None # None is defined to mean variable
104+
to_numpy = to_numpy
105+
```
106+
107+
### int4 dtype
108+
109+
There is currently considerable interest in the AI community in 'quantising' models - storing
110+
models at reduced precision, while minimising loss of information content. There are a number
111+
of sub-byte dtypes that the community are using e.g. int4. Unfortunately numpy does not
112+
currently have support for handling such sub-byte dtypes in an easy way. However, they can
113+
still be held in a numpy array and then passed (in a zero-copy way) to something like pytorch
114+
which can handle appropriately:
115+
116+
```python
117+
import numpy as np
118+
from zarr.core.dtype import ZarrDType # User inherits from ZarrDType when creating their dtype
119+
120+
class Int4(ZarrDType):
121+
zarr_spec_format = "3"
122+
experimental = True
123+
endianness = 'little'
124+
byte_count = 1 # this is ugly, but I could change this from byte_count to bit_count if there was consensus
125+
to_numpy = np.dtype('B') # could also be np.dtype('V1'), but this would prevent bit-twiddling
126+
configuration_v3 = {
127+
"version": "example_value",
128+
"author": "example_value",
129+
}
130+
```
131+
"""
132+
133+
from __future__ import annotations
134+
135+
from typing import Any, Literal
136+
137+
import numpy as np
138+
139+
140+
class FrozenClassVariables(type):
141+
def __setattr__(cls, attr: str, value: object) -> None:
142+
if hasattr(cls, attr):
143+
raise ValueError(f"Attribute {attr} on ZarrDType class can not be changed once set.")
144+
else:
145+
raise AttributeError(f"'{cls}' object has no attribute '{attr}'")
146+
147+
148+
class ZarrDType(metaclass=FrozenClassVariables):
149+
zarr_spec_format: Literal["2", "3"] # the version of the zarr spec used
150+
experimental: bool # is this in the core spec or not
151+
endianness: Literal[
152+
"big", "little", None
153+
] # None indicates not defined i.e. single byte or byte strings
154+
byte_count: int | None # None indicates variable count
155+
to_numpy: np.dtype[Any] # may involve installing a a numpy extension e.g. ml_dtypes;
156+
157+
configuration_v3: dict | None # TODO: understand better how this is recommended by the spec
158+
159+
_zarr_spec_identifier: str # implementation detail used to map to core spec
160+
161+
def __init_subclass__( # enforces all required fields are set and basic sanity checks
162+
cls,
163+
**kwargs,
164+
) -> None:
165+
required_attrs = [
166+
"zarr_spec_format",
167+
"experimental",
168+
"endianness",
169+
"byte_count",
170+
"to_numpy",
171+
]
172+
for attr in required_attrs:
173+
if not hasattr(cls, attr):
174+
raise ValueError(f"{attr} is a required attribute for a Zarr dtype.")
175+
176+
if not hasattr(cls, "configuration_v3"):
177+
cls.configuration_v3 = None
178+
179+
cls._zarr_spec_identifier = (
180+
"big_" + cls.__qualname__.lower()
181+
if cls.endianness == "big"
182+
else cls.__qualname__.lower()
183+
) # how this dtype is identified in core spec; convention is prefix with big_ for big-endian
184+
185+
cls._validate() # sanity check on basic requirements
186+
187+
super().__init_subclass__(**kwargs)
188+
189+
# TODO: add further checks
190+
@classmethod
191+
def _validate(cls):
192+
if cls.byte_count is not None and cls.byte_count <= 0:
193+
raise ValueError("byte_count must be a positive integer.")
194+
195+
if cls.byte_count == 1 and cls.endianness is not None:
196+
raise ValueError("Endianness must be None for single-byte types.")

0 commit comments

Comments
 (0)