Skip to content

Commit 4ceb6ed

Browse files
committed
refactor: use inheritance to remove boilerplate in dtype definitions
1 parent 60b2e9d commit 4ceb6ed

File tree

8 files changed

+575
-617
lines changed

8 files changed

+575
-617
lines changed

src/zarr/core/dtype/_numpy.py

Lines changed: 379 additions & 521 deletions
Large diffs are not rendered by default.

src/zarr/core/dtype/common.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def check_json_bool(data: JSON) -> TypeGuard[bool]:
3131
bool
3232
True if the data is a boolean, False otherwise.
3333
"""
34-
return bool(isinstance(data, bool))
34+
return isinstance(data, bool)
3535

3636

3737
def check_json_str(data: JSON) -> TypeGuard[str]:
@@ -293,7 +293,7 @@ def complex_to_json_v3(data: complex | np.complexfloating[Any, Any]) -> tuple[JS
293293
return float_to_json_v3(data.real), float_to_json_v3(data.imag)
294294

295295

296-
def complex_to_json(
296+
def complex_float_to_json(
297297
data: complex | np.complexfloating[Any, Any], zarr_format: ZarrFormat
298298
) -> tuple[JSONFloat, JSONFloat]:
299299
"""
@@ -424,60 +424,48 @@ def float_from_json(data: JSONFloat, zarr_format: ZarrFormat) -> float:
424424
return float_from_json_v3(data)
425425

426426

427-
def complex_from_json_v2(
428-
data: tuple[JSONFloat, JSONFloat], dtype: np.dtypes.Complex64DType | np.dtypes.Complex128DType
429-
) -> np.complexfloating[Any, Any]:
427+
def complex_float_from_json_v2(data: tuple[JSONFloat, JSONFloat]) -> complex:
430428
"""
431429
Convert a JSON complex float to a complex number (v2).
432430
433431
Parameters
434432
----------
435433
data : tuple[JSONFloat, JSONFloat]
436434
The JSON complex float to convert.
437-
dtype : Any
438-
The numpy dtype.
439435
440436
Returns
441437
-------
442438
complex
443439
The complex number.
444440
"""
445-
return dtype.type(complex(float_from_json_v2(data[0]), float_from_json_v2(data[1])))
441+
return complex(float_from_json_v2(data[0]), float_from_json_v2(data[1]))
446442

447443

448-
def complex_from_json_v3(
449-
data: tuple[JSONFloat, JSONFloat], dtype: np.dtypes.Complex64DType | np.dtypes.Complex128DType
450-
) -> np.complexfloating[Any, Any]:
444+
def complex_float_from_json_v3(data: tuple[JSONFloat, JSONFloat]) -> complex:
451445
"""
452446
Convert a JSON complex float to a complex number (v3).
453447
454448
Parameters
455449
----------
456450
data : tuple[JSONFloat, JSONFloat]
457451
The JSON complex float to convert.
458-
dtype : Any
459-
The numpy dtype.
460452
461453
Returns
462454
-------
463455
complex
464456
The complex number.
465457
"""
466-
return dtype.type(complex(float_from_json_v3(data[0]), float_from_json_v3(data[1])))
458+
return complex(float_from_json_v3(data[0]), float_from_json_v3(data[1]))
467459

468460

469-
def complex_from_json(
470-
data: tuple[JSONFloat, JSONFloat], dtype: Any, zarr_format: ZarrFormat
471-
) -> np.complexfloating[Any, Any]:
461+
def complex_float_from_json(data: tuple[JSONFloat, JSONFloat], zarr_format: ZarrFormat) -> complex:
472462
"""
473463
Convert a JSON complex float to a complex number based on zarr format.
474464
475465
Parameters
476466
----------
477467
data : tuple[JSONFloat, JSONFloat]
478468
The JSON complex float to convert.
479-
dtype : Any
480-
The numpy dtype.
481469
zarr_format : ZarrFormat
482470
The zarr format version.
483471
@@ -487,12 +475,9 @@ def complex_from_json(
487475
The complex number.
488476
"""
489477
if zarr_format == 2:
490-
return complex_from_json_v2(data, dtype)
478+
return complex_float_from_json_v2(data)
491479
else:
492-
if check_json_complex_float_v3(data):
493-
return complex_from_json_v3(data, dtype)
494-
else:
495-
raise TypeError(f"Invalid type: {data}. Expected a sequence of two numbers.")
480+
return complex_float_from_json_v3(data)
496481
raise ValueError(f"Invalid zarr format: {zarr_format}. Expected 2 or 3.")
497482

498483

src/zarr/core/dtype/wrapper.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
# This is the bound for the dtypes that we support. If we support non-numpy dtypes,
1818
# then this bound will need to be widened.
1919
_BaseDType = np.dtype[np.generic]
20-
TScalar = TypeVar("TScalar", bound=_BaseScalar)
20+
TScalar_co = TypeVar("TScalar_co", bound=_BaseScalar, covariant=True)
2121
# TODO: figure out an interface or protocol that non-numpy dtypes can use
22-
TDType = TypeVar("TDType", bound=_BaseDType)
22+
# These two type parameters are covariant so that static type checkers treat ZDType[SubDType, SubScalar] as assignable to ZDType[_BaseDType, _BaseScalar]
23+
TDType_co = TypeVar("TDType_co", bound=_BaseDType, covariant=True)
2324

2425

2526
@dataclass(frozen=True, kw_only=True, slots=True)
26-
class ZDType(Generic[TDType, TScalar], ABC):
27+
class ZDType(Generic[TDType_co, TScalar_co], ABC):
2728
"""
2829
Abstract base class for wrapping native array data types, e.g. numpy dtypes
2930
@@ -41,11 +42,11 @@ class ZDType(Generic[TDType, TScalar], ABC):
4142
# mypy currently disallows class variables to contain type parameters
4243
# but it seems OK for us to use it here:
4344
# https://github.com/python/typing/discussions/1424#discussioncomment-7989934
44-
dtype_cls: ClassVar[type[TDType]] # type: ignore[misc]
45+
dtype_cls: ClassVar[type[TDType_co]] # type: ignore[misc]
4546
_zarr_v3_name: ClassVar[str]
4647

4748
@classmethod
48-
def check_dtype(cls: type[Self], dtype: _BaseDType) -> TypeGuard[TDType]:
49+
def check_dtype(cls: type[Self], dtype: _BaseDType) -> TypeGuard[TDType_co]:
4950
"""
5051
Check that a data type matches the dtype_cls class attribute. Used as a type guard.
5152
@@ -89,7 +90,7 @@ def from_dtype(cls: type[Self], dtype: _BaseDType) -> Self:
8990

9091
@classmethod
9192
@abstractmethod
92-
def _from_dtype_unsafe(cls: type[Self], dtype: TDType) -> Self:
93+
def _from_dtype_unsafe(cls: type[Self], dtype: _BaseDType) -> Self:
9394
"""
9495
Wrap a native dtype without checking.
9596
@@ -106,7 +107,7 @@ def _from_dtype_unsafe(cls: type[Self], dtype: TDType) -> Self:
106107
...
107108

108109
@abstractmethod
109-
def to_dtype(self: Self) -> TDType:
110+
def to_dtype(self: Self) -> TDType_co:
110111
"""
111112
Return an instance of the wrapped dtype.
112113
@@ -117,8 +118,61 @@ def to_dtype(self: Self) -> TDType:
117118
"""
118119
...
119120

121+
def cast_value(self, data: object) -> TScalar_co:
122+
"""
123+
Cast a value to the wrapped scalar type. The type is first checked for compatibility. If it's
124+
incompatible with the associated scalar type, a ``TypeError`` will be raised.
125+
126+
Parameters
127+
----------
128+
data : object
129+
The scalar value to cast.
130+
131+
Returns
132+
-------
133+
TScalar
134+
The cast value.
135+
"""
136+
if self.check_value(data):
137+
return self._cast_value_unsafe(data)
138+
raise TypeError(f"Invalid value: {data}")
139+
140+
@abstractmethod
141+
def check_value(self, data: object) -> bool:
142+
"""
143+
Check that a value is a valid value for the wrapped data type.
144+
145+
Parameters
146+
----------
147+
data : object
148+
A value to check.
149+
150+
Returns
151+
-------
152+
bool
153+
True if the value is valid, False otherwise.
154+
"""
155+
...
156+
157+
@abstractmethod
158+
def _cast_value_unsafe(self, data: object) -> TScalar_co:
159+
"""
160+
Cast a value to the wrapped data type. This method should not perform any input validation.
161+
162+
Parameters
163+
----------
164+
data : object
165+
The scalar value to cast.
166+
167+
Returns
168+
-------
169+
TScalar
170+
The cast value.
171+
"""
172+
...
173+
120174
@abstractmethod
121-
def default_value(self) -> TScalar:
175+
def default_value(self) -> TScalar_co:
122176
"""
123177
Get the default value for the wrapped data type. This is a method, rather than an attribute,
124178
because the default value for some data types may depend on parameters that are not known
@@ -216,7 +270,7 @@ def _from_json_unsafe(cls: type[Self], data: JSON, zarr_format: ZarrFormat) -> S
216270
...
217271

218272
@abstractmethod
219-
def to_json_value(self, data: TScalar, *, zarr_format: ZarrFormat) -> JSON:
273+
def to_json_value(self, data: object, *, zarr_format: ZarrFormat) -> JSON:
220274
"""
221275
Convert a single value to JSON-serializable format.
222276
@@ -235,7 +289,7 @@ def to_json_value(self, data: TScalar, *, zarr_format: ZarrFormat) -> JSON:
235289
...
236290

237291
@abstractmethod
238-
def from_json_value(self: Self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
292+
def from_json_value(self: Self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar_co:
239293
"""
240294
Read a JSON-serializable value as a scalar.
241295

src/zarr/core/metadata/v2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from zarr.abc.metadata import Metadata
1111
from zarr.core.dtype import get_data_type_from_native_dtype
12-
from zarr.core.dtype.wrapper import TDType, TScalar, ZDType, _BaseDType, _BaseScalar
12+
from zarr.core.dtype.wrapper import TDType_co, TScalar_co, ZDType, _BaseDType, _BaseScalar
1313

1414
if TYPE_CHECKING:
1515
from typing import Any, Literal, Self
@@ -58,7 +58,7 @@ def __init__(
5858
self,
5959
*,
6060
shape: ChunkCoords,
61-
dtype: ZDType[TDType, TScalar],
61+
dtype: ZDType[TDType_co, TScalar_co],
6262
chunks: ChunkCoords,
6363
fill_value: Any,
6464
order: MemoryOrder,
@@ -176,7 +176,7 @@ def to_dict(self) -> dict[str, JSON]:
176176
zarray_dict["filters"] = new_filters
177177

178178
if self.fill_value is not None:
179-
fill_value = self.dtype.to_json_value(self.fill_value, zarr_format=2) # type: ignore[arg-type]
179+
fill_value = self.dtype.to_json_value(self.fill_value, zarr_format=2)
180180
zarray_dict["fill_value"] = fill_value
181181

182182
zarray_dict["dtype"] = self.dtype.to_json(zarr_format=2)

tests/conftest.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition
2121
from zarr.core.common import JSON, parse_shapelike
2222
from zarr.core.config import config as zarr_config
23-
from zarr.core.dtype import get_data_type_from_native_dtype
23+
from zarr.core.dtype import data_type_registry, get_data_type_from_native_dtype
24+
from zarr.core.dtype._numpy import DateTime64, HasLength, Structured
2425
from zarr.core.metadata.v2 import ArrayV2Metadata
2526
from zarr.core.metadata.v3 import ArrayV3Metadata
2627
from zarr.core.sync import sync
@@ -36,6 +37,7 @@
3637
from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike
3738
from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike
3839
from zarr.core.common import ChunkCoords, MemoryOrder, ShapeLike, ZarrFormat
40+
from zarr.core.dtype.wrapper import ZDType
3941

4042

4143
async def parse_store(
@@ -404,3 +406,17 @@ def meta_from_array(
404406
chunk_key_encoding=chunk_key_encoding,
405407
dimension_names=dimension_names,
406408
)
409+
410+
411+
# Generate a collection of zdtype instances for use in testing.
412+
zdtype_examples: tuple[ZDType[Any, Any], ...] = ()
413+
for wrapper_cls in data_type_registry.contents.values():
414+
# The Structured dtype has to be constructed with some actual fields
415+
if wrapper_cls is Structured:
416+
zdtype_examples += (wrapper_cls.from_dtype(np.dtype([("a", np.float64), ("b", np.int8)])),)
417+
elif issubclass(wrapper_cls, HasLength):
418+
zdtype_examples += (wrapper_cls(length=1),)
419+
elif issubclass(wrapper_cls, DateTime64):
420+
zdtype_examples += (wrapper_cls(unit="s"),)
421+
else:
422+
zdtype_examples += (wrapper_cls(),)

tests/test_array.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,15 @@
4040
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
4141
from zarr.core.dtype import get_data_type_from_native_dtype
4242
from zarr.core.dtype._numpy import Float64
43+
from zarr.core.dtype.wrapper import ZDType
4344
from zarr.core.group import AsyncGroup
4445
from zarr.core.indexing import BasicIndexer, ceildiv
4546
from zarr.core.sync import sync
4647
from zarr.errors import ContainsArrayError, ContainsGroupError
4748
from zarr.storage import LocalStore, MemoryStore, StorePath
4849

50+
from .conftest import zdtype_examples
51+
4952
if TYPE_CHECKING:
5053
from zarr.core.array_spec import ArrayConfigLike
5154
from zarr.core.metadata.v2 import ArrayV2Metadata
@@ -177,32 +180,42 @@ def test_array_name_properties_with_group(
177180

178181
@pytest.mark.parametrize("store", ["memory"], indirect=True)
179182
@pytest.mark.parametrize("specifiy_fill_value", [True, False])
180-
@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "complex64"])
181-
def test_array_v3_fill_value_default(
182-
store: MemoryStore, specifiy_fill_value: bool, dtype_str: str
183+
@pytest.mark.parametrize(
184+
"zdtype", zdtype_examples, ids=tuple(str(type(v)) for v in zdtype_examples)
185+
)
186+
def test_array_fill_value_default(
187+
store: MemoryStore, specifiy_fill_value: bool, zdtype: ZDType[Any, Any]
183188
) -> None:
184189
"""
185190
Test that creating an array with the fill_value parameter set to None, or unspecified,
186191
results in the expected fill_value attribute of the array, i.e. the default value of the array's data type.
187192
"""
188193
shape = (10,)
189-
default_fill_value = 0
190194
if specifiy_fill_value:
191195
arr = zarr.create_array(
192196
store=store,
193197
shape=shape,
194-
dtype=dtype_str,
198+
dtype=zdtype,
195199
zarr_format=3,
196200
chunks=shape,
197201
fill_value=None,
198202
)
199203
else:
200-
arr = zarr.create_array(
201-
store=store, shape=shape, dtype=dtype_str, zarr_format=3, chunks=shape
202-
)
204+
arr = zarr.create_array(store=store, shape=shape, dtype=zdtype, zarr_format=3, chunks=shape)
205+
expected_fill_value = zdtype.default_value()
206+
if isinstance(expected_fill_value, np.datetime64 | np.timedelta64):
207+
if np.isnat(expected_fill_value):
208+
assert np.isnat(arr.fill_value)
209+
elif isinstance(expected_fill_value, np.floating | np.complexfloating):
210+
if np.isnan(expected_fill_value):
211+
assert np.isnan(arr.fill_value)
212+
else:
213+
assert arr.fill_value == expected_fill_value
214+
# A simpler check would be to ensure that arr.fill_value.dtype == arr.dtype
215+
# But for some numpy data types (namely, U), scalars might not have length. An empty string
216+
# scalar from a `>U4` array would have dtype `>U`, and arr.fill_value.dtype == arr.dtype will fail.
203217

204-
assert arr.fill_value == np.dtype(dtype_str).type(default_fill_value)
205-
assert arr.fill_value.dtype == arr.dtype
218+
assert type(arr.fill_value) is type(np.array([arr.fill_value], dtype=arr.dtype)[0])
206219

207220

208221
@pytest.mark.parametrize("store", ["memory"], indirect=True)
@@ -1004,7 +1017,7 @@ async def test_v3_chunk_encoding(
10041017
filters=filters,
10051018
compressors=compressors,
10061019
serializer="auto",
1007-
dtype=arr.metadata.data_type, # type: ignore[union-attr]
1020+
dtype=arr._zdtype,
10081021
)
10091022
assert arr.filters == filters_expected
10101023
assert arr.compressors == compressors_expected
@@ -1369,4 +1382,4 @@ async def test_sharding_coordinate_selection() -> None:
13691382
shards=(2, 4, 4),
13701383
)
13711384
arr[:] = np.arange(2 * 3 * 4).reshape((2, 3, 4))
1372-
assert (arr[1, [0, 1]] == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all() # type: ignore[index]
1385+
assert (arr[1, [0, 1]] == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all()

0 commit comments

Comments
 (0)