88if TYPE_CHECKING :
99 from typing import Self
1010
11- import numpy .typing as npt
12-
1311 from zarr .core .buffer import Buffer , BufferPrototype
1412 from zarr .core .chunk_grids import ChunkGrid
1513 from zarr .core .common import JSON , ChunkCoords
2220
2321import numcodecs .abc
2422import numpy as np
23+ import numpy .typing as npt
2524
2625from zarr .abc .codec import ArrayArrayCodec , ArrayBytesCodec , BytesBytesCodec , Codec
2726from zarr .core .array_spec import ArraySpec
3837from zarr .core .metadata .common import ArrayMetadata , parse_attributes
3938from zarr .registry import get_codec_class
4039
40+ DEFAULT_DTYPE = "float64"
41+
4142
4243def parse_zarr_format (data : object ) -> Literal [3 ]:
4344 if data == 3 :
@@ -159,7 +160,7 @@ def _replace_special_floats(obj: object) -> Any:
159160@dataclass (frozen = True , kw_only = True )
160161class ArrayV3Metadata (ArrayMetadata ):
161162 shape : ChunkCoords
162- data_type : np . dtype [ Any ]
163+ data_type : DataType
163164 chunk_grid : ChunkGrid
164165 chunk_key_encoding : ChunkKeyEncoding
165166 fill_value : Any
@@ -174,7 +175,7 @@ def __init__(
174175 self ,
175176 * ,
176177 shape : Iterable [int ],
177- data_type : npt .DTypeLike ,
178+ data_type : npt .DTypeLike | DataType ,
178179 chunk_grid : dict [str , JSON ] | ChunkGrid ,
179180 chunk_key_encoding : dict [str , JSON ] | ChunkKeyEncoding ,
180181 fill_value : Any ,
@@ -187,18 +188,18 @@ def __init__(
187188 Because the class is a frozen dataclass, we set attributes using object.__setattr__
188189 """
189190 shape_parsed = parse_shapelike (shape )
190- data_type_parsed = parse_dtype (data_type )
191+ data_type_parsed = DataType . parse (data_type )
191192 chunk_grid_parsed = ChunkGrid .from_dict (chunk_grid )
192193 chunk_key_encoding_parsed = ChunkKeyEncoding .from_dict (chunk_key_encoding )
193194 dimension_names_parsed = parse_dimension_names (dimension_names )
194- fill_value_parsed = parse_fill_value (fill_value , dtype = data_type_parsed )
195+ fill_value_parsed = parse_fill_value (fill_value , dtype = data_type_parsed . to_numpy () )
195196 attributes_parsed = parse_attributes (attributes )
196197 codecs_parsed_partial = parse_codecs (codecs )
197198 storage_transformers_parsed = parse_storage_transformers (storage_transformers )
198199
199200 array_spec = ArraySpec (
200201 shape = shape_parsed ,
201- dtype = data_type_parsed ,
202+ dtype = data_type_parsed . to_numpy () ,
202203 fill_value = fill_value_parsed ,
203204 order = "C" , # TODO: order is not needed here.
204205 prototype = default_buffer_prototype (), # TODO: prototype is not needed here.
@@ -231,11 +232,14 @@ def _validate_metadata(self) -> None:
231232 if self .fill_value is None :
232233 raise ValueError ("`fill_value` is required." )
233234 for codec in self .codecs :
234- codec .validate (shape = self .shape , dtype = self .data_type , chunk_grid = self .chunk_grid )
235+ codec .validate (
236+ shape = self .shape , dtype = self .data_type .to_numpy (), chunk_grid = self .chunk_grid
237+ )
235238
236239 @property
237240 def dtype (self ) -> np .dtype [Any ]:
238- return self .data_type
241+ """Interpret Zarr dtype as NumPy dtype"""
242+ return self .data_type .to_numpy ()
239243
240244 @property
241245 def ndim (self ) -> int :
@@ -273,13 +277,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
273277 _ = parse_node_type_array (_data .pop ("node_type" ))
274278
275279 # check that the data_type attribute is valid
276- _ = DataType (_data [ "data_type" ] )
280+ data_type = DataType . parse (_data . pop ( "data_type" ) )
277281
278282 # dimension_names key is optional, normalize missing to `None`
279283 _data ["dimension_names" ] = _data .pop ("dimension_names" , None )
280284 # attributes key is optional, normalize missing to `None`
281285 _data ["attributes" ] = _data .pop ("attributes" , None )
282- return cls (** _data ) # type: ignore[arg-type]
286+ return cls (** _data , data_type = data_type ) # type: ignore[arg-type]
283287
284288 def to_dict (self ) -> dict [str , JSON ]:
285289 out_dict = super ().to_dict ()
@@ -497,8 +501,11 @@ def to_numpy_shortname(self) -> str:
497501 }
498502 return data_type_to_numpy [self ]
499503
def to_numpy(self) -> np.dtype[Any]:
    """Return the NumPy dtype corresponding to this Zarr data type.

    The dtype is built from the string returned by
    ``to_numpy_shortname()``.

    Returns
    -------
    np.dtype
        The result of passing the short name to ``np.dtype``.
    """
    shortname = self.to_numpy_shortname()
    return np.dtype(shortname)
506+
500507 @classmethod
501- def from_dtype (cls , dtype : np .dtype [Any ]) -> DataType :
508+ def from_numpy (cls , dtype : np .dtype [Any ]) -> DataType :
502509 dtype_to_data_type = {
503510 "|b1" : "bool" ,
504511 "bool" : "bool" ,
@@ -518,16 +525,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
518525 }
519526 return DataType [dtype_to_data_type [dtype .str ]]
520527
521-
522- def parse_dtype (data : npt .DTypeLike ) -> np .dtype [Any ]:
523- try :
524- dtype = np .dtype (data )
525- except (ValueError , TypeError ) as e :
526- raise ValueError (f"Invalid V3 data_type: { data } " ) from e
527- # check that this is a valid v3 data_type
528- try :
529- _ = DataType .from_dtype (dtype )
530- except KeyError as e :
531- raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
532-
533- return dtype
@classmethod
def parse(cls, dtype: None | DataType | Any) -> DataType:
    """Coerce ``dtype`` into a member of this data type enumeration.

    Parameters
    ----------
    dtype
        ``None`` selects the member named by ``DEFAULT_DTYPE``. An existing
        ``DataType`` is returned unchanged. Any other value is first passed
        to ``np.dtype`` and the result is mapped through ``from_numpy``.

    Returns
    -------
    DataType
        The matching data type member.

    Raises
    ------
    ValueError
        If ``np.dtype`` rejects the input (``ValueError`` or ``TypeError``),
        or if ``from_numpy`` raises ``KeyError`` for the resulting dtype.
    """
    if dtype is None:
        # the default dtype
        return cls[DEFAULT_DTYPE]
    if isinstance(dtype, cls):
        return dtype

    try:
        np_dtype = np.dtype(dtype)
    except (ValueError, TypeError) as e:
        raise ValueError(f"Invalid V3 data_type: {dtype}") from e

    # check that this is a valid v3 data_type
    try:
        return cls.from_numpy(np_dtype)
    except KeyError as e:
        raise ValueError(f"Invalid V3 data_type: {np_dtype}") from e
0 commit comments