66if TYPE_CHECKING :
77 from typing import Self
88
9- import numpy .typing as npt
10-
119 from zarr .core .buffer import Buffer , BufferPrototype
1210 from zarr .core .chunk_grids import ChunkGrid
1311 from zarr .core .common import JSON , ChunkCoords
2018
2119import numcodecs .abc
2220import numpy as np
21+ import numpy .typing as npt
2322
2423from zarr .abc .codec import ArrayArrayCodec , ArrayBytesCodec , BytesBytesCodec , Codec
2524from zarr .core .array_spec import ArraySpec
3130from zarr .core .metadata .common import ArrayMetadata , parse_attributes
3231from zarr .registry import get_codec_class
3332
33+ DEFAULT_DTYPE = "float64"
34+
3435
3536def parse_zarr_format (data : object ) -> Literal [3 ]:
3637 if data == 3 :
@@ -152,7 +153,7 @@ def _replace_special_floats(obj: object) -> Any:
152153@dataclass (frozen = True , kw_only = True )
153154class ArrayV3Metadata (ArrayMetadata ):
154155 shape : ChunkCoords
155- data_type : np . dtype [ Any ]
156+ data_type : DataType
156157 chunk_grid : ChunkGrid
157158 chunk_key_encoding : ChunkKeyEncoding
158159 fill_value : Any
@@ -167,7 +168,7 @@ def __init__(
167168 self ,
168169 * ,
169170 shape : Iterable [int ],
170- data_type : npt .DTypeLike ,
171+ data_type : npt .DTypeLike | DataType ,
171172 chunk_grid : dict [str , JSON ] | ChunkGrid ,
172173 chunk_key_encoding : dict [str , JSON ] | ChunkKeyEncoding ,
173174 fill_value : Any ,
@@ -180,18 +181,18 @@ def __init__(
180181 Because the class is a frozen dataclass, we set attributes using object.__setattr__
181182 """
182183 shape_parsed = parse_shapelike (shape )
183- data_type_parsed = parse_dtype (data_type )
184+ data_type_parsed = DataType . parse (data_type )
184185 chunk_grid_parsed = ChunkGrid .from_dict (chunk_grid )
185186 chunk_key_encoding_parsed = ChunkKeyEncoding .from_dict (chunk_key_encoding )
186187 dimension_names_parsed = parse_dimension_names (dimension_names )
187- fill_value_parsed = parse_fill_value (fill_value , dtype = data_type_parsed )
188+ fill_value_parsed = parse_fill_value (fill_value , dtype = data_type_parsed . to_numpy () )
188189 attributes_parsed = parse_attributes (attributes )
189190 codecs_parsed_partial = parse_codecs (codecs )
190191 storage_transformers_parsed = parse_storage_transformers (storage_transformers )
191192
192193 array_spec = ArraySpec (
193194 shape = shape_parsed ,
194- dtype = data_type_parsed ,
195+ dtype = data_type_parsed . to_numpy () ,
195196 fill_value = fill_value_parsed ,
196197 order = "C" , # TODO: order is not needed here.
197198 prototype = default_buffer_prototype (), # TODO: prototype is not needed here.
@@ -224,11 +225,14 @@ def _validate_metadata(self) -> None:
224225 if self .fill_value is None :
225226 raise ValueError ("`fill_value` is required." )
226227 for codec in self .codecs :
227- codec .validate (shape = self .shape , dtype = self .data_type , chunk_grid = self .chunk_grid )
228+ codec .validate (
229+ shape = self .shape , dtype = self .data_type .to_numpy (), chunk_grid = self .chunk_grid
230+ )
228231
229232 @property
230233 def dtype (self ) -> np .dtype [Any ]:
231- return self .data_type
234+ """Interpret Zarr dtype as NumPy dtype"""
235+ return self .data_type .to_numpy ()
232236
233237 @property
234238 def ndim (self ) -> int :
@@ -266,13 +270,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
266270 _ = parse_node_type_array (_data .pop ("node_type" ))
267271
268272 # check that the data_type attribute is valid
269- _ = DataType (_data [ "data_type" ] )
273+ data_type = DataType . parse (_data . pop ( "data_type" ) )
270274
271275 # dimension_names key is optional, normalize missing to `None`
272276 _data ["dimension_names" ] = _data .pop ("dimension_names" , None )
273277 # attributes key is optional, normalize missing to `None`
274278 _data ["attributes" ] = _data .pop ("attributes" , None )
275- return cls (** _data ) # type: ignore[arg-type]
279+ return cls (** _data , data_type = data_type ) # type: ignore[arg-type]
276280
277281 def to_dict (self ) -> dict [str , JSON ]:
278282 out_dict = super ().to_dict ()
@@ -490,8 +494,11 @@ def to_numpy_shortname(self) -> str:
490494 }
491495 return data_type_to_numpy [self ]
492496
497+ def to_numpy (self ) -> np .dtype [Any ]:
498+ return np .dtype (self .to_numpy_shortname ())
499+
493500 @classmethod
494- def from_dtype (cls , dtype : np .dtype [Any ]) -> DataType :
501+ def from_numpy (cls , dtype : np .dtype [Any ]) -> DataType :
495502 dtype_to_data_type = {
496503 "|b1" : "bool" ,
497504 "bool" : "bool" ,
@@ -511,16 +518,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
511518 }
512519 return DataType [dtype_to_data_type [dtype .str ]]
513520
514-
515- def parse_dtype (data : npt .DTypeLike ) -> np .dtype [Any ]:
516- try :
517- dtype = np .dtype (data )
518- except (ValueError , TypeError ) as e :
519- raise ValueError (f"Invalid V3 data_type: { data } " ) from e
520- # check that this is a valid v3 data_type
521- try :
522- _ = DataType .from_dtype (dtype )
523- except KeyError as e :
524- raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
525-
526- return dtype
521+ @classmethod
522+ def parse (cls , dtype : None | DataType | Any ) -> DataType :
523+ if dtype is None :
524+ # the default dtype
525+ return DataType [DEFAULT_DTYPE ]
526+ if isinstance (dtype , DataType ):
527+ return dtype
528+ else :
529+ try :
530+ dtype = np .dtype (dtype )
531+ except (ValueError , TypeError ) as e :
532+ raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
533+ # check that this is a valid v3 data_type
534+ try :
535+ data_type = DataType .from_numpy (dtype )
536+ except KeyError as e :
537+ raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
538+ return data_type
0 commit comments