66if TYPE_CHECKING :
77 from typing import Self
88
9- import numpy .typing as npt
10-
119 from zarr .core .buffer import Buffer , BufferPrototype
1210 from zarr .core .chunk_grids import ChunkGrid
1311 from zarr .core .common import JSON , ChunkCoords
2018
2119import numcodecs .abc
2220import numpy as np
21+ import numpy .typing as npt
2322
2423from zarr .abc .codec import ArrayArrayCodec , ArrayBytesCodec , BytesBytesCodec , Codec
2524from zarr .core .array_spec import ArraySpec
@@ -152,7 +151,7 @@ def _replace_special_floats(obj: object) -> Any:
152151@dataclass (frozen = True , kw_only = True )
153152class ArrayV3Metadata (ArrayMetadata ):
154153 shape : ChunkCoords
155- data_type : np . dtype [ Any ]
154+ data_type : DataType
156155 chunk_grid : ChunkGrid
157156 chunk_key_encoding : ChunkKeyEncoding
158157 fill_value : Any
@@ -167,7 +166,7 @@ def __init__(
167166 self ,
168167 * ,
169168 shape : Iterable [int ],
170- data_type : npt .DTypeLike ,
169+ data_type : npt .DTypeLike | DataType ,
171170 chunk_grid : dict [str , JSON ] | ChunkGrid ,
172171 chunk_key_encoding : dict [str , JSON ] | ChunkKeyEncoding ,
173172 fill_value : Any ,
@@ -180,18 +179,18 @@ def __init__(
180179 Because the class is a frozen dataclass, we set attributes using object.__setattr__
181180 """
182181 shape_parsed = parse_shapelike (shape )
183- data_type_parsed = parse_dtype (data_type )
182+ data_type_parsed = DataType . parse (data_type )
184183 chunk_grid_parsed = ChunkGrid .from_dict (chunk_grid )
185184 chunk_key_encoding_parsed = ChunkKeyEncoding .from_dict (chunk_key_encoding )
186185 dimension_names_parsed = parse_dimension_names (dimension_names )
187- fill_value_parsed = parse_fill_value (fill_value , dtype = data_type_parsed )
186+ fill_value_parsed = parse_fill_value (fill_value , dtype = data_type_parsed . to_numpy_dtype () )
188187 attributes_parsed = parse_attributes (attributes )
189188 codecs_parsed_partial = parse_codecs (codecs )
190189 storage_transformers_parsed = parse_storage_transformers (storage_transformers )
191190
192191 array_spec = ArraySpec (
193192 shape = shape_parsed ,
194- dtype = data_type_parsed ,
193+ dtype = data_type_parsed . to_numpy_dtype () ,
195194 fill_value = fill_value_parsed ,
196195 order = "C" , # TODO: order is not needed here.
197196 prototype = default_buffer_prototype (), # TODO: prototype is not needed here.
@@ -224,11 +223,14 @@ def _validate_metadata(self) -> None:
224223 if self .fill_value is None :
225224 raise ValueError ("`fill_value` is required." )
226225 for codec in self .codecs :
227- codec .validate (shape = self .shape , dtype = self .data_type , chunk_grid = self .chunk_grid )
226+ codec .validate (
227+ shape = self .shape , dtype = self .data_type .to_numpy_dtype (), chunk_grid = self .chunk_grid
228+ )
228229
229230 @property
230231 def dtype (self ) -> np .dtype [Any ]:
231- return self .data_type
232+ """Interpret Zarr dtype as NumPy dtype"""
233+ return self .data_type .to_numpy_dtype ()
232234
233235 @property
234236 def ndim (self ) -> int :
@@ -266,13 +268,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
266268 _ = parse_node_type_array (_data .pop ("node_type" ))
267269
268270 # check that the data_type attribute is valid
269- _ = DataType (_data [ "data_type" ] )
271+ data_type = DataType . parse (_data . pop ( "data_type" ) )
270272
271273 # dimension_names key is optional, normalize missing to `None`
272274 _data ["dimension_names" ] = _data .pop ("dimension_names" , None )
273275 # attributes key is optional, normalize missing to `None`
274276 _data ["attributes" ] = _data .pop ("attributes" , None )
275- return cls (** _data ) # type: ignore[arg-type]
277+ return cls (** _data , data_type = data_type ) # type: ignore[arg-type]
276278
277279 def to_dict (self ) -> dict [str , JSON ]:
278280 out_dict = super ().to_dict ()
@@ -490,8 +492,11 @@ def to_numpy_shortname(self) -> str:
490492 }
491493 return data_type_to_numpy [self ]
492494
495+ def to_numpy_dtype (self ) -> np .dtype [Any ]:
496+ return np .dtype (self .to_numpy_shortname ())
497+
493498 @classmethod
494- def from_dtype (cls , dtype : np .dtype [Any ]) -> DataType :
499+ def from_numpy_dtype (cls , dtype : np .dtype [Any ]) -> DataType :
495500 dtype_to_data_type = {
496501 "|b1" : "bool" ,
497502 "bool" : "bool" ,
@@ -511,16 +516,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
511516 }
512517 return DataType [dtype_to_data_type [dtype .str ]]
513518
514-
515- def parse_dtype (data : npt .DTypeLike ) -> np .dtype [Any ]:
516- try :
517- dtype = np .dtype (data )
518- except (ValueError , TypeError ) as e :
519- raise ValueError (f"Invalid V3 data_type: { data } " ) from e
520- # check that this is a valid v3 data_type
521- try :
522- _ = DataType .from_dtype (dtype )
523- except KeyError as e :
524- raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
525-
526- return dtype
519+ @classmethod
520+ def parse (cls , dtype : None | DataType | Any ) -> DataType :
521+ if dtype is None :
522+ # the default dtype
523+ return DataType .float64
524+ if isinstance (dtype , DataType ):
525+ return dtype
526+ else :
527+ try :
528+ dtype = np .dtype (dtype )
529+ except (ValueError , TypeError ) as e :
530+ raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
531+ # check that this is a valid v3 data_type
532+ try :
533+ data_type = DataType .from_numpy_dtype (dtype )
534+ except KeyError as e :
535+ raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
536+ return data_type
0 commit comments