1616 TypeVar ,
1717 cast ,
1818 get_args ,
19- get_origin ,
2019)
2120
2221import numpy as np
2322import numpy .typing as npt
24- from typing_extensions import get_original_bases
2523
2624from zarr .abc .metadata import Metadata
2725from zarr .core .strings import _NUMPY_SUPPORTS_VLEN_STRING
@@ -245,23 +243,6 @@ class DTypeWrapper(Generic[TDType, TScalar], ABC, Metadata):
245243 dtype_cls : ClassVar [type [TDType ]] # this class will create a numpy dtype
246244 endianness : Endianness | None = "native"
247245
248- def __init_subclass__ (cls ) -> None :
249- # TODO: wrap this in some *very informative* error handling
250- generic_args = get_args (get_original_bases (cls )[0 ])
251- # the logic here is that if a subclass was created with generic type parameters
252- # specified explicitly, then we bind that type parameter to the dtype_cls attribute
253- if len (generic_args ) > 0 :
254- cls .dtype_cls = generic_args [0 ]
255- else :
256- # but if the subclass was created without generic type parameters specified explicitly,
257- # then we check the parent DTypeWrapper classes and retrieve their generic type parameters
258- for base in cls .__orig_bases__ :
259- if get_origin (base ) is DTypeWrapper :
260- generic_args = get_args (base )
261- cls .dtype_cls = generic_args [0 ]
262- break
263- return super ().__init_subclass__ ()
264-
265246 def to_dict (self ) -> dict [str , JSON ]:
266247 return {"name" : self .name }
267248
@@ -314,6 +295,7 @@ def from_json_value(self: Self, data: JSON, *, zarr_format: ZarrFormat) -> TScal
314295@dataclass (frozen = True , kw_only = True )
315296class Bool (DTypeWrapper [np .dtypes .BoolDType , np .bool_ ]):
316297 name = "bool"
298+ dtype_cls : ClassVar [type [np .dtypes .BoolDType ]] = np .dtypes .BoolDType
317299
318300 def default_value (self ) -> np .bool_ :
319301 return np .False_
@@ -350,41 +332,49 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
350332
351333@dataclass (frozen = True , kw_only = True )
352334class Int8 (IntWrapperBase [np .dtypes .Int8DType , np .int8 ]):
335+ dtype_cls = np .dtypes .Int8DType
353336 name = "int8"
354337
355338
356339@dataclass (frozen = True , kw_only = True )
357340class UInt8 (IntWrapperBase [np .dtypes .UInt8DType , np .uint8 ]):
341+ dtype_cls = np .dtypes .UInt8DType
358342 name = "uint8"
359343
360344
361345@dataclass (frozen = True , kw_only = True )
362346class Int16 (IntWrapperBase [np .dtypes .Int16DType , np .int16 ]):
347+ dtype_cls = np .dtypes .Int16DType
363348 name = "int16"
364349
365350
366351@dataclass (frozen = True , kw_only = True )
367352class UInt16 (IntWrapperBase [np .dtypes .UInt16DType , np .uint16 ]):
353+ dtype_cls = np .dtypes .UInt16DType
368354 name = "uint16"
369355
370356
371357@dataclass (frozen = True , kw_only = True )
372358class Int32 (IntWrapperBase [np .dtypes .Int32DType , np .int32 ]):
359+ dtype_cls = np .dtypes .Int32DType
373360 name = "int32"
374361
375362
376363@dataclass (frozen = True , kw_only = True )
377364class UInt32 (IntWrapperBase [np .dtypes .UInt32DType , np .uint32 ]):
365+ dtype_cls = np .dtypes .UInt32DType
378366 name = "uint32"
379367
380368
381369@dataclass (frozen = True , kw_only = True )
382370class Int64 (IntWrapperBase [np .dtypes .Int64DType , np .int64 ]):
371+ dtype_cls = np .dtypes .Int64DType
383372 name = "int64"
384373
385374
386375@dataclass (frozen = True , kw_only = True )
387376class UInt64 (IntWrapperBase [np .dtypes .UInt64DType , np .uint64 ]):
377+ dtype_cls = np .dtypes .UInt64DType
388378 name = "uint64"
389379
390380
@@ -407,21 +397,25 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
407397
408398@dataclass (frozen = True , kw_only = True )
409399class Float16 (FloatWrapperBase [np .dtypes .Float16DType , np .float16 ]):
400+ dtype_cls = np .dtypes .Float16DType
410401 name = "float16"
411402
412403
413404@dataclass (frozen = True , kw_only = True )
414405class Float32 (FloatWrapperBase [np .dtypes .Float32DType , np .float32 ]):
406+ dtype_cls = np .dtypes .Float32DType
415407 name = "float32"
416408
417409
418410@dataclass (frozen = True , kw_only = True )
419411class Float64 (FloatWrapperBase [np .dtypes .Float64DType , np .float64 ]):
412+ dtype_cls = np .dtypes .Float64DType
420413 name = "float64"
421414
422415
423416@dataclass (frozen = True , kw_only = True )
424417class Complex64 (DTypeWrapper [np .dtypes .Complex64DType , np .complex64 ]):
418+ dtype_cls = np .dtypes .Complex64DType
425419 name = "complex64"
426420
427421 def default_value (self ) -> np .complex64 :
@@ -444,6 +438,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.complex6
444438
445439@dataclass (frozen = True , kw_only = True )
446440class Complex128 (DTypeWrapper [np .dtypes .Complex128DType , np .complex128 ]):
441+ dtype_cls = np .dtypes .Complex128DType
447442 name = "complex128"
448443
449444 def default_value (self ) -> np .complex128 :
@@ -480,7 +475,8 @@ def unwrap(self) -> TDType:
480475
481476@dataclass (frozen = True , kw_only = True )
482477class FixedLengthAsciiString (FlexibleWrapperBase [np .dtypes .BytesDType , np .bytes_ ]):
483- name = "numpy/static_byte_string"
478+ dtype_cls = np .dtypes .BytesDType
479+ name = "numpy.static_byte_string"
484480 item_size_bits = 8
485481
486482 def default_value (self ) -> np .bytes_ :
@@ -500,6 +496,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bytes_:
500496
501497@dataclass (frozen = True , kw_only = True )
502498class StaticRawBytes (FlexibleWrapperBase [np .dtypes .VoidDType , np .void ]):
499+ dtype_cls = np .dtypes .VoidDType
503500 name = "r*"
504501 item_size_bits = 8
505502
@@ -532,7 +529,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:
532529
533530@dataclass (frozen = True , kw_only = True )
534531class FixedLengthUnicodeString (FlexibleWrapperBase [np .dtypes .StrDType , np .str_ ]):
535- name = "numpy/static_unicode_string"
532+ dtype_cls = np .dtypes .StrDType
533+ name = "numpy.static_unicode_string"
536534 item_size_bits = 32 # UCS4 is 32 bits per code point
537535
538536 def default_value (self ) -> np .str_ :
@@ -554,7 +552,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.str_:
554552
555553 @dataclass (frozen = True , kw_only = True )
556554 class VariableLengthString (DTypeWrapper [np .dtypes .StringDType , str ]):
557- name = "numpy/vlen_string"
555+ dtype_cls = np .dtypes .StringDType
556+ name = "numpy.vlen_string"
558557
559558 def default_value (self ) -> str :
560559 return ""
@@ -582,7 +581,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
582581
583582 @dataclass (frozen = True , kw_only = True )
584583 class VariableLengthString (DTypeWrapper [np .dtypes .ObjectDType , str ]):
585- name = "numpy/vlen_string"
584+ dtype_cls = np .dtypes .ObjectDType
585+ name = "numpy.vlen_string"
586586 endianness : Endianness = field (default = None )
587587
588588 def default_value (self ) -> str :
@@ -617,6 +617,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
617617
618618@dataclass (frozen = True , kw_only = True )
619619class DateTime64 (DTypeWrapper [np .dtypes .DateTime64DType , np .datetime64 ]):
620+ dtype_cls = np .dtypes .DateTime64DType
620621 name = "numpy/datetime64"
621622 unit : DateUnit | TimeUnit = "s"
622623
@@ -647,6 +648,7 @@ def to_json_value(self, data: np.datetime64, *, zarr_format: ZarrFormat) -> int:
647648
648649@dataclass (frozen = True , kw_only = True )
649650class Structured (DTypeWrapper [np .dtypes .VoidDType , np .void ]):
651+ dtype_cls = np .dtypes .VoidDType
650652 name = "numpy/struct"
651653 fields : tuple [tuple [str , DTypeWrapper [Any , Any ], int ], ...]
652654
0 commit comments