1616 TypeVar ,
1717 cast ,
1818 get_args ,
19+ get_origin ,
1920)
2021
2122import numpy as np
@@ -133,7 +134,7 @@ def float_to_json_v2(data: float | np.floating[Any]) -> JSONFloat:
133134
134135def float_to_json_v3 (data : float | np .floating [Any ]) -> JSONFloat :
135136 # v3 can in principle handle distinct NaN values, but numpy does not represent these explicitly
136- # so we just re-use the v2 routine here
137+ # so we just reuse the v2 routine here
137138 return float_to_json_v2 (data )
138139
139140
@@ -148,11 +149,11 @@ def float_to_json(data: float | np.floating[Any], zarr_format: ZarrFormat) -> JS
148149 raise ValueError (f"Invalid zarr format: { zarr_format } . Expected 2 or 3." )
149150
150151
151- def complex_to_json_v2 (data : complex | np .complexfloating [Any ]) -> tuple [JSONFloat , JSONFloat ]:
152+ def complex_to_json_v2 (data : complex | np .complexfloating [Any , Any ]) -> tuple [JSONFloat , JSONFloat ]:
152153 return float_to_json_v2 (data .real ), float_to_json_v2 (data .imag )
153154
154155
155- def complex_to_json_v3 (data : complex | np .complexfloating [Any ]) -> tuple [JSONFloat , JSONFloat ]:
156+ def complex_to_json_v3 (data : complex | np .complexfloating [Any , Any ]) -> tuple [JSONFloat , JSONFloat ]:
156157 return float_to_json_v3 (data .real ), float_to_json_v3 (data .imag )
157158
158159
@@ -226,15 +227,16 @@ def complex_from_json(
226227
227228
228229def datetime_to_json (data : np .datetime64 [Any ]) -> int :
229- return data .view ("int" ).item ()
230+ return data .view (np . int64 ).item ()
230231
231232
232233def datetime_from_json (data : int , unit : DateUnit | TimeUnit ) -> np .datetime64 [Any ]:
233234 return np .int64 (data ).view (f"datetime64[{ unit } ]" )
234235
235236
237+ TScalar = TypeVar ("TScalar" , bound = np .generic | str , covariant = True )
238+ # TODO: figure out an interface or protocol that non-numpy dtypes can
236239TDType = TypeVar ("TDType" , bound = np .dtype [Any ])
237- TScalar = TypeVar ("TScalar" , bound = np .generic | str )
238240
239241
240242@dataclass (frozen = True , kw_only = True )
@@ -244,17 +246,27 @@ class DTypeWrapper(Generic[TDType, TScalar], ABC, Metadata):
244246 endianness : Endianness | None = "native"
245247
246248 def __init_subclass__ (cls ) -> None :
247- # Subclasses will bind the first generic type parameter to an attribute of the class
248249 # TODO: wrap this in some *very informative* error handling
249250 generic_args = get_args (get_original_bases (cls )[0 ])
250- cls .dtype_cls = generic_args [0 ]
251+ # the logic here is that if a subclass was created with generic type parameters
252+ # specified explicitly, then we bind that type parameter to the dtype_cls attribute
253+ if len (generic_args ) > 0 :
254+ cls .dtype_cls = generic_args [0 ]
255+ else :
256+ # but if the subclass was created without generic type parameters specified explicitly,
257+ # then we check the parent DTypeWrapper classes and retrieve their generic type parameters
258+ for base in cls .__orig_bases__ :
259+ if get_origin (base ) is DTypeWrapper :
260+ generic_args = get_args (base )
261+ cls .dtype_cls = generic_args [0 ]
262+ break
251263 return super ().__init_subclass__ ()
252264
253265 def to_dict (self ) -> dict [str , JSON ]:
254266 return {"name" : self .name }
255267
256268 def cast_value (self : Self , value : object ) -> TScalar :
257- return cast (np . generic , self .unwrap ().type (value ))
269+ return cast (TScalar , self .unwrap ().type (value ))
258270
259271 @abstractmethod
260272 def default_value (self ) -> TScalar : ...
@@ -455,7 +467,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.complex1
455467@dataclass (frozen = True , kw_only = True )
456468class FlexibleWrapperBase (DTypeWrapper [TDType , TScalar ]):
457469 item_size_bits : ClassVar [int ]
458- length : int
470+ length : int = 0
459471
460472 @classmethod
461473 def _wrap_unsafe (cls , dtype : TDType ) -> Self :
@@ -467,7 +479,7 @@ def unwrap(self) -> TDType:
467479
468480
469481@dataclass (frozen = True , kw_only = True )
470- class StaticByteString (FlexibleWrapperBase [np .dtypes .BytesDType , np .bytes_ ]):
482+ class FixedLengthAsciiString (FlexibleWrapperBase [np .dtypes .BytesDType , np .bytes_ ]):
471483 name = "numpy/static_byte_string"
472484 item_size_bits = 8
473485
@@ -492,11 +504,18 @@ class StaticRawBytes(FlexibleWrapperBase[np.dtypes.VoidDType, np.void]):
492504 item_size_bits = 8
493505
494506 def default_value (self ) -> np .void :
495- return np . void ( b"" )
507+ return self . cast_value (( " \x00 " * self . length ). encode ( "ascii" ) )
496508
497509 def to_dict (self ) -> dict [str , JSON ]:
498510 return {"name" : f"r{ self .length * self .item_size_bits } " }
499511
512+ @classmethod
513+ def check_dtype (cls : type [Self ], dtype : TDType ) -> TypeGuard [TDType ]:
514+ """
515+ Reject structured dtypes by ensuring that dtype.fields is None
516+ """
517+ return type (dtype ) is cls .dtype_cls and dtype .fields is None
518+
500519 def unwrap (self ) -> np .dtypes .VoidDType :
501520 # this needs to be overridden because numpy does not allow creating a void type
502521 # by invoking np.dtypes.VoidDType directly
@@ -512,7 +531,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:
512531
513532
514533@dataclass (frozen = True , kw_only = True )
515- class StaticUnicodeString (FlexibleWrapperBase [np .dtypes .StrDType , np .str_ ]):
534+ class FixedLengthUnicodeString (FlexibleWrapperBase [np .dtypes .StrDType , np .str_ ]):
516535 name = "numpy/static_unicode_string"
517536 item_size_bits = 32 # UCS4 is 32 bits per code point
518537
@@ -599,7 +618,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
599618@dataclass (frozen = True , kw_only = True )
600619class DateTime64 (DTypeWrapper [np .dtypes .DateTime64DType , np .datetime64 ]):
601620 name = "numpy/datetime64"
602- unit : DateUnit | TimeUnit
621+ unit : DateUnit | TimeUnit = "s"
603622
604623 def default_value (self ) -> np .datetime64 :
605624 return np .datetime64 ("NaT" )
@@ -609,6 +628,9 @@ def _wrap_unsafe(cls, dtype: np.dtypes.DateTime64DType) -> Self:
609628 unit = dtype .name [dtype .name .rfind ("[" ) + 1 : dtype .name .rfind ("]" )]
610629 return cls (unit = unit )
611630
631+ def cast_value (self , value : object ) -> np .datetime64 :
632+ return self .unwrap ().type (value , self .unit )
633+
612634 def unwrap (self ) -> np .dtypes .DateTime64DType :
613635 return np .dtype (f"datetime64[{ self .unit } ]" ).newbyteorder (
614636 endianness_to_numpy_str (self .endianness )
@@ -651,6 +673,26 @@ def _wrap_unsafe(cls, dtype: np.dtypes.VoidDType) -> Self:
651673
652674 return cls (fields = tuple (fields ))
653675
676+ def to_dict (self ) -> dict [str , JSON ]:
677+ base_dict = super ().to_dict ()
678+ if base_dict .get ("configuration" , {}) != {}:
679+ raise ValueError (
680+ "This data type wrapper cannot inherit from a data type wrapper that defines a configuration for its dict serialization"
681+ )
682+ field_configs = [
683+ (f_name , f_dtype .to_dict (), f_offset ) for f_name , f_dtype , f_offset in self .fields
684+ ]
685+ base_dict ["configuration" ] = {"fields" : field_configs }
686+ return base_dict
687+
688+ @classmethod
689+ def from_dict (cls , data : dict [str , JSON ]) -> Self :
690+ fields = tuple (
691+ (f_name , get_data_type_from_dict (f_dtype ), f_offset )
692+ for f_name , f_dtype , f_offset in data ["fields" ]
693+ )
694+ return cls (fields = fields )
695+
654696 def unwrap (self ) -> np .dtypes .VoidDType :
655697 return np .dtype ([(key , dtype .unwrap ()) for (key , dtype , _ ) in self .fields ])
656698
@@ -665,7 +707,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:
665707 return np .array ([as_bytes ], dtype = dtype .str ).view (dtype )[0 ]
666708
667709
668- def get_data_type_from_numpy (dtype : npt .DTypeLike ) -> DTypeWrapper :
710+ def get_data_type_from_numpy (dtype : npt .DTypeLike ) -> DTypeWrapper [ Any , Any ] :
669711 if dtype in (str , "str" ):
670712 if _NUMPY_SUPPORTS_VLEN_STRING :
671713 np_dtype = np .dtype ("T" )
@@ -674,17 +716,10 @@ def get_data_type_from_numpy(dtype: npt.DTypeLike) -> DTypeWrapper:
674716 else :
675717 np_dtype = np .dtype (dtype )
676718 data_type_registry .lazy_load ()
677- for val in data_type_registry .contents .values ():
678- try :
679- return val .wrap (np_dtype )
680- except TypeError :
681- pass
682- raise ValueError (
683- f"numpy dtype '{ dtype } ' does not have a corresponding Zarr dtype in: { list (data_type_registry .contents )} ."
684- )
719+ return data_type_registry .match_dtype (np_dtype )
685720
686721
687- def get_data_type_from_dict (dtype : dict [str , JSON ]) -> DTypeWrapper :
722+ def get_data_type_from_dict (dtype : dict [str , JSON ]) -> DTypeWrapper [ Any . Any ] :
688723 data_type_registry .lazy_load ()
689724 dtype_name = dtype ["name" ]
690725 dtype_cls = data_type_registry .get (dtype_name )
@@ -737,14 +772,14 @@ def register(self: Self, cls: type[DTypeWrapper[Any, Any]]) -> None:
737772 def get (self , key : str ) -> type [DTypeWrapper [Any , Any ]]:
738773 return self .contents [key ]
739774
740- def match_dtype (self , dtype : npt . DTypeLike ) -> DTypeWrapper [Any , Any ]:
775+ def match_dtype (self , dtype : TDType ) -> DTypeWrapper [Any , Any ]:
741776 self .lazy_load ()
742777 for val in self .contents .values ():
743778 try :
744779 return val .wrap (dtype )
745780 except TypeError :
746781 pass
747- raise ValueError (f"No data type wrapper found that matches { dtype } " )
782+ raise ValueError (f"No data type wrapper found that matches dtype ' { dtype } ' " )
748783
749784
750785def register_data_type (cls : type [DTypeWrapper [Any , Any ]]) -> None :
@@ -756,7 +791,7 @@ def register_data_type(cls: type[DTypeWrapper[Any, Any]]) -> None:
756791INTEGER_DTYPE = Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
757792FLOAT_DTYPE = Float16 | Float32 | Float64
758793COMPLEX_DTYPE = Complex64 | Complex128
759- STRING_DTYPE = StaticUnicodeString | VariableLengthString | StaticByteString
794+ STRING_DTYPE = FixedLengthUnicodeString | VariableLengthString | FixedLengthAsciiString
760795DTYPE = (
761796 Bool
762797 | INTEGER_DTYPE
0 commit comments