Skip to content

Commit e14279d

Browse files
committed
remove __post_init__ magic in favor of more explicit declaration
1 parent 6df84a9 commit e14279d

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

src/zarr/codecs/bytes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ async def _decode_single(
7171
chunk_spec: ArraySpec,
7272
) -> NDBuffer:
7373
assert isinstance(chunk_bytes, Buffer)
74-
75-
dtype = chunk_spec.dtype.with_endianness(self.endian).unwrap()
74+
# TODO: remove endianness enum in favor of literal union
75+
endian_str = self.endian.value if self.endian is not None else None
76+
dtype = chunk_spec.dtype.with_endianness(endian_str).unwrap()
7677

7778
as_array_like = chunk_bytes.as_array_like()
7879
if isinstance(as_array_like, NDArrayLike):

src/zarr/core/metadata/dtype.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@
1616
TypeVar,
1717
cast,
1818
get_args,
19-
get_origin,
2019
)
2120

2221
import numpy as np
2322
import numpy.typing as npt
24-
from typing_extensions import get_original_bases
2523

2624
from zarr.abc.metadata import Metadata
2725
from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING
@@ -245,23 +243,6 @@ class DTypeWrapper(Generic[TDType, TScalar], ABC, Metadata):
245243
dtype_cls: ClassVar[type[TDType]] # this class will create a numpy dtype
246244
endianness: Endianness | None = "native"
247245

248-
def __init_subclass__(cls) -> None:
249-
# TODO: wrap this in some *very informative* error handling
250-
generic_args = get_args(get_original_bases(cls)[0])
251-
# the logic here is that if a subclass was created with generic type parameters
252-
# specified explicitly, then we bind that type parameter to the dtype_cls attribute
253-
if len(generic_args) > 0:
254-
cls.dtype_cls = generic_args[0]
255-
else:
256-
# but if the subclass was created without generic type parameters specified explicitly,
257-
# then we check the parent DTypeWrapper classes and retrieve their generic type parameters
258-
for base in cls.__orig_bases__:
259-
if get_origin(base) is DTypeWrapper:
260-
generic_args = get_args(base)
261-
cls.dtype_cls = generic_args[0]
262-
break
263-
return super().__init_subclass__()
264-
265246
def to_dict(self) -> dict[str, JSON]:
266247
return {"name": self.name}
267248

@@ -314,6 +295,7 @@ def from_json_value(self: Self, data: JSON, *, zarr_format: ZarrFormat) -> TScal
314295
@dataclass(frozen=True, kw_only=True)
315296
class Bool(DTypeWrapper[np.dtypes.BoolDType, np.bool_]):
316297
name = "bool"
298+
dtype_cls: ClassVar[type[np.dtypes.BoolDType]] = np.dtypes.BoolDType
317299

318300
def default_value(self) -> np.bool_:
319301
return np.False_
@@ -350,41 +332,49 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
350332

351333
@dataclass(frozen=True, kw_only=True)
352334
class Int8(IntWrapperBase[np.dtypes.Int8DType, np.int8]):
335+
dtype_cls = np.dtypes.Int8DType
353336
name = "int8"
354337

355338

356339
@dataclass(frozen=True, kw_only=True)
357340
class UInt8(IntWrapperBase[np.dtypes.UInt8DType, np.uint8]):
341+
dtype_cls = np.dtypes.UInt8DType
358342
name = "uint8"
359343

360344

361345
@dataclass(frozen=True, kw_only=True)
362346
class Int16(IntWrapperBase[np.dtypes.Int16DType, np.int16]):
347+
dtype_cls = np.dtypes.Int16DType
363348
name = "int16"
364349

365350

366351
@dataclass(frozen=True, kw_only=True)
367352
class UInt16(IntWrapperBase[np.dtypes.UInt16DType, np.uint16]):
353+
dtype_cls = np.dtypes.UInt16DType
368354
name = "uint16"
369355

370356

371357
@dataclass(frozen=True, kw_only=True)
372358
class Int32(IntWrapperBase[np.dtypes.Int32DType, np.int32]):
359+
dtype_cls = np.dtypes.Int32DType
373360
name = "int32"
374361

375362

376363
@dataclass(frozen=True, kw_only=True)
377364
class UInt32(IntWrapperBase[np.dtypes.UInt32DType, np.uint32]):
365+
dtype_cls = np.dtypes.UInt32DType
378366
name = "uint32"
379367

380368

381369
@dataclass(frozen=True, kw_only=True)
382370
class Int64(IntWrapperBase[np.dtypes.Int64DType, np.int64]):
371+
dtype_cls = np.dtypes.Int64DType
383372
name = "int64"
384373

385374

386375
@dataclass(frozen=True, kw_only=True)
387376
class UInt64(IntWrapperBase[np.dtypes.UInt64DType, np.uint64]):
377+
dtype_cls = np.dtypes.UInt64DType
388378
name = "uint64"
389379

390380

@@ -407,21 +397,25 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
407397

408398
@dataclass(frozen=True, kw_only=True)
409399
class Float16(FloatWrapperBase[np.dtypes.Float16DType, np.float16]):
400+
dtype_cls = np.dtypes.Float16DType
410401
name = "float16"
411402

412403

413404
@dataclass(frozen=True, kw_only=True)
414405
class Float32(FloatWrapperBase[np.dtypes.Float32DType, np.float32]):
406+
dtype_cls = np.dtypes.Float32DType
415407
name = "float32"
416408

417409

418410
@dataclass(frozen=True, kw_only=True)
419411
class Float64(FloatWrapperBase[np.dtypes.Float64DType, np.float64]):
412+
dtype_cls = np.dtypes.Float64DType
420413
name = "float64"
421414

422415

423416
@dataclass(frozen=True, kw_only=True)
424417
class Complex64(DTypeWrapper[np.dtypes.Complex64DType, np.complex64]):
418+
dtype_cls = np.dtypes.Complex64DType
425419
name = "complex64"
426420

427421
def default_value(self) -> np.complex64:
@@ -444,6 +438,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.complex6
444438

445439
@dataclass(frozen=True, kw_only=True)
446440
class Complex128(DTypeWrapper[np.dtypes.Complex128DType, np.complex128]):
441+
dtype_cls = np.dtypes.Complex128DType
447442
name = "complex128"
448443

449444
def default_value(self) -> np.complex128:
@@ -480,7 +475,8 @@ def unwrap(self) -> TDType:
480475

481476
@dataclass(frozen=True, kw_only=True)
482477
class FixedLengthAsciiString(FlexibleWrapperBase[np.dtypes.BytesDType, np.bytes_]):
483-
name = "numpy/static_byte_string"
478+
dtype_cls = np.dtypes.BytesDType
479+
name = "numpy.static_byte_string"
484480
item_size_bits = 8
485481

486482
def default_value(self) -> np.bytes_:
@@ -500,6 +496,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bytes_:
500496

501497
@dataclass(frozen=True, kw_only=True)
502498
class StaticRawBytes(FlexibleWrapperBase[np.dtypes.VoidDType, np.void]):
499+
dtype_cls = np.dtypes.VoidDType
503500
name = "r*"
504501
item_size_bits = 8
505502

@@ -532,7 +529,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:
532529

533530
@dataclass(frozen=True, kw_only=True)
534531
class FixedLengthUnicodeString(FlexibleWrapperBase[np.dtypes.StrDType, np.str_]):
535-
name = "numpy/static_unicode_string"
532+
dtype_cls = np.dtypes.StrDType
533+
name = "numpy.static_unicode_string"
536534
item_size_bits = 32 # UCS4 is 32 bits per code point
537535

538536
def default_value(self) -> np.str_:
@@ -554,7 +552,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.str_:
554552

555553
@dataclass(frozen=True, kw_only=True)
556554
class VariableLengthString(DTypeWrapper[np.dtypes.StringDType, str]):
557-
name = "numpy/vlen_string"
555+
dtype_cls = np.dtypes.StringDType
556+
name = "numpy.vlen_string"
558557

559558
def default_value(self) -> str:
560559
return ""
@@ -582,7 +581,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
582581

583582
@dataclass(frozen=True, kw_only=True)
584583
class VariableLengthString(DTypeWrapper[np.dtypes.ObjectDType, str]):
585-
name = "numpy/vlen_string"
584+
dtype_cls = np.dtypes.ObjectDType
585+
name = "numpy.vlen_string"
586586
endianness: Endianness = field(default=None)
587587

588588
def default_value(self) -> str:
@@ -617,6 +617,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
617617

618618
@dataclass(frozen=True, kw_only=True)
619619
class DateTime64(DTypeWrapper[np.dtypes.DateTime64DType, np.datetime64]):
620+
dtype_cls = np.dtypes.DateTime64DType
620621
name = "numpy/datetime64"
621622
unit: DateUnit | TimeUnit = "s"
622623

@@ -647,6 +648,7 @@ def to_json_value(self, data: np.datetime64, *, zarr_format: ZarrFormat) -> int:
647648

648649
@dataclass(frozen=True, kw_only=True)
649650
class Structured(DTypeWrapper[np.dtypes.VoidDType, np.void]):
651+
dtype_cls = np.dtypes.VoidDType
650652
name = "numpy/struct"
651653
fields: tuple[tuple[str, DTypeWrapper[Any, Any], int], ...]
652654

0 commit comments

Comments
 (0)