Skip to content

Commit c1a8566

Browse files
committed
dtype-specific tests
1 parent bf24d69 commit c1a8566

File tree

2 files changed

+312
-52
lines changed

2 files changed

+312
-52
lines changed

src/zarr/core/metadata/dtype.py

Lines changed: 109 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def structured_scalar_to_json(data: bytes, zarr_format: ZarrFormat) -> str:
172172
raise NotImplementedError(f"Invalid zarr format: {zarr_format}. Expected 2.")
173173

174174

175-
def structured_scalar_from_json(data: JSON, zarr_format: ZarrFormat) -> bytes:
175+
def structured_scalar_from_json(data: str, zarr_format: ZarrFormat) -> bytes:
    """
    Decode the JSON form of a structured scalar back into raw bytes.

    Parameters
    ----------
    data
        Base64 text; it is ASCII-encoded before being decoded.
    zarr_format
        Zarr format version. Only ``2`` is handled here.

    Raises
    ------
    NotImplementedError
        If ``zarr_format`` is anything other than 2.
    """
    # Reject unsupported formats up front so the happy path reads straight through.
    if zarr_format != 2:
        raise NotImplementedError(f"Invalid zarr format: {zarr_format}. Expected 2.")
    ascii_payload = data.encode("ascii")
    return base64.b64decode(ascii_payload)
@@ -202,11 +202,13 @@ def float_from_json(data: JSONFloat, zarr_format: ZarrFormat) -> float:
202202
return float_from_json_v3(data)
203203

204204

205-
def complex_from_json_v2(data: JSONFloat, dtype: Any) -> np.complexfloating:
206-
return dtype.type(data)
205+
def complex_from_json_v2(
    data: tuple[JSONFloat, JSONFloat], dtype: Any
) -> np.complexfloating[Any, Any]:
    """
    Build a numpy complex scalar from a JSON ``(real, imag)`` pair.

    Parameters
    ----------
    data
        Sequence holding the real and imaginary parts. It is unpacked into
        ``complex(...)``, so it must be iterable, not a single float.
    dtype
        Object whose ``type`` attribute is called with the Python ``complex``
        to produce the returned scalar (e.g. a numpy complex dtype).

    Returns
    -------
    The scalar produced by ``dtype.type``.
    """
    # Each part goes through float() first: complex() refuses string arguments
    # when given two parts, while float() parses "NaN" / "Infinity" /
    # "-Infinity". NOTE(review): assumes JSONFloat admits those string
    # spellings -- confirm against its definition.
    return dtype.type(complex(*map(float, data)))
207207

208208

209-
def complex_from_json_v3(data: tuple[JSONFloat, JSONFloat], dtype: Any) -> np.complexfloating:
209+
def complex_from_json_v3(
    data: tuple[JSONFloat, JSONFloat], dtype: Any
) -> np.complexfloating[Any, Any]:
    """
    Build a numpy complex scalar from a JSON ``(real, imag)`` pair.

    Parameters
    ----------
    data
        Two-element sequence holding the real and imaginary parts.
    dtype
        Object whose ``type`` attribute is called with the Python ``complex``
        to produce the returned scalar (e.g. a numpy complex dtype).

    Returns
    -------
    The scalar produced by ``dtype.type``.
    """
    # complex("NaN", 1.0) raises TypeError, so string-encoded parts are
    # converted with float() before the pair is combined. float() parses
    # "NaN", "Infinity" and "-Infinity".
    # NOTE(review): assumes JSONFloat admits those string spellings -- confirm.
    return dtype.type(complex(*map(float, data)))
211213

212214

@@ -223,6 +225,14 @@ def complex_from_json(
223225
raise ValueError(f"Invalid zarr format: {zarr_format}. Expected 2 or 3.")
224226

225227

228+
def datetime_to_json(data: np.datetime64[Any]) -> int:
    """
    Convert a numpy datetime64 scalar to its integer JSON representation.

    The scalar's 8-byte payload is reinterpreted as a signed 64-bit integer
    and returned as a Python ``int``. ``NaT`` therefore maps to the minimum
    int64 value.
    """
    # View as np.int64 explicitly: the "int" dtype string resolves to the
    # platform default integer, which is not 64-bit everywhere, whereas
    # datetime64 is always 8 bytes wide.
    return data.view(np.int64).item()
230+
231+
232+
def datetime_from_json(data: int, unit: DateUnit | TimeUnit) -> np.datetime64[Any]:
    """
    Rebuild a numpy datetime64 scalar from its integer JSON representation.

    Parameters
    ----------
    data
        Integer payload; it is wrapped as an int64 and reinterpreted in place.
    unit
        numpy datetime unit code inserted into ``datetime64[<unit>]``.
    """
    target_dtype = f"datetime64[{unit}]"
    ticks = np.int64(data)
    # Reinterpret the 8 bytes rather than converting the value.
    return ticks.view(target_dtype)
234+
235+
226236
TDType = TypeVar("TDType", bound=np.dtype[Any])
227237
TScalar = TypeVar("TScalar", bound=np.generic | str)
228238

@@ -231,8 +241,6 @@ def complex_from_json(
231241
class DTypeWrapper(Generic[TDType, TScalar], ABC, Metadata):
232242
name: ClassVar[str]
233243
dtype_cls: ClassVar[type[TDType]] # this class will create a numpy dtype
234-
kind: ClassVar[DataTypeFlavor]
235-
default_value: ClassVar[TScalar]
236244
endianness: Endianness | None = "native"
237245

238246
def __init_subclass__(cls) -> None:
@@ -248,6 +256,9 @@ def to_dict(self) -> dict[str, JSON]:
248256
def cast_value(self: Self, value: object) -> TScalar:
249257
return cast(np.generic, self.unwrap().type(value))
250258

259+
@abstractmethod
260+
def default_value(self) -> TScalar: ...
261+
251262
@classmethod
252263
def check_dtype(cls: type[Self], dtype: TDType) -> TypeGuard[TDType]:
253264
"""
@@ -291,8 +302,9 @@ def from_json_value(self: Self, data: JSON, *, zarr_format: ZarrFormat) -> TScal
291302
@dataclass(frozen=True, kw_only=True)
292303
class Bool(DTypeWrapper[np.dtypes.BoolDType, np.bool_]):
293304
name = "bool"
294-
kind = "boolean"
295-
default_value = np.False_
305+
306+
def default_value(self) -> np.bool_:
307+
return np.False_
296308

297309
@classmethod
298310
def _wrap_unsafe(cls, dtype: np.dtypes.BoolDType) -> Self:
@@ -308,7 +320,8 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bool_:
308320

309321

310322
class IntWrapperBase(DTypeWrapper[TDType, TScalar]):
311-
kind = "numeric"
323+
def default_value(self) -> TScalar:
324+
return self.unwrap().type(0)
312325

313326
@classmethod
314327
def _wrap_unsafe(cls, dtype: TDType) -> Self:
@@ -326,53 +339,46 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
326339
@dataclass(frozen=True, kw_only=True)
327340
class Int8(IntWrapperBase[np.dtypes.Int8DType, np.int8]):
328341
name = "int8"
329-
default_value = np.int8(0)
330342

331343

332344
@dataclass(frozen=True, kw_only=True)
333345
class UInt8(IntWrapperBase[np.dtypes.UInt8DType, np.uint8]):
334346
name = "uint8"
335-
default_value = np.uint8(0)
336347

337348

338349
@dataclass(frozen=True, kw_only=True)
339350
class Int16(IntWrapperBase[np.dtypes.Int16DType, np.int16]):
340351
name = "int16"
341-
default_value = np.int16(0)
342352

343353

344354
@dataclass(frozen=True, kw_only=True)
345355
class UInt16(IntWrapperBase[np.dtypes.UInt16DType, np.uint16]):
346356
name = "uint16"
347-
default_value = np.uint16(0)
348357

349358

350359
@dataclass(frozen=True, kw_only=True)
351360
class Int32(IntWrapperBase[np.dtypes.Int32DType, np.int32]):
352361
name = "int32"
353-
default_value = np.int32(0)
354362

355363

356364
@dataclass(frozen=True, kw_only=True)
357365
class UInt32(IntWrapperBase[np.dtypes.UInt32DType, np.uint32]):
358366
name = "uint32"
359-
default_value = np.uint32(0)
360367

361368

362369
@dataclass(frozen=True, kw_only=True)
363370
class Int64(IntWrapperBase[np.dtypes.Int64DType, np.int64]):
364371
name = "int64"
365-
default_value = np.int64(0)
366372

367373

368374
@dataclass(frozen=True, kw_only=True)
369375
class UInt64(IntWrapperBase[np.dtypes.UInt64DType, np.uint64]):
370376
name = "uint64"
371-
default_value = np.uint64(0)
372377

373378

374379
class FloatWrapperBase(DTypeWrapper[TDType, TScalar]):
375-
kind = "numeric"
380+
def default_value(self) -> TScalar:
381+
return self.unwrap().type(0.0)
376382

377383
@classmethod
378384
def _wrap_unsafe(cls, dtype: TDType) -> Self:
@@ -390,26 +396,24 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
390396
@dataclass(frozen=True, kw_only=True)
391397
class Float16(FloatWrapperBase[np.dtypes.Float16DType, np.float16]):
392398
name = "float16"
393-
default_value = np.float16(0)
394399

395400

396401
@dataclass(frozen=True, kw_only=True)
397402
class Float32(FloatWrapperBase[np.dtypes.Float32DType, np.float32]):
398403
name = "float32"
399-
default_value = np.float32(0)
400404

401405

402406
@dataclass(frozen=True, kw_only=True)
403407
class Float64(FloatWrapperBase[np.dtypes.Float64DType, np.float64]):
404408
name = "float64"
405-
default_value = np.float64(0)
406409

407410

408411
@dataclass(frozen=True, kw_only=True)
409412
class Complex64(DTypeWrapper[np.dtypes.Complex64DType, np.complex64]):
410413
name = "complex64"
411-
kind = "numeric"
412-
default_value = np.complex64(0)
414+
415+
def default_value(self) -> np.complex64:
416+
return np.complex64(0.0)
413417

414418
@classmethod
415419
def _wrap_unsafe(cls, dtype: np.dtypes.Complex64DType) -> Self:
@@ -429,8 +433,9 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.complex6
429433
@dataclass(frozen=True, kw_only=True)
430434
class Complex128(DTypeWrapper[np.dtypes.Complex128DType, np.complex128]):
431435
name = "complex128"
432-
kind = "numeric"
433-
default_value = np.complex128(0)
436+
437+
def default_value(self) -> np.complex128:
438+
return np.complex128(0.0)
434439

435440
@classmethod
436441
def _wrap_unsafe(cls, dtype: np.dtypes.Complex128DType) -> Self:
@@ -464,10 +469,11 @@ def unwrap(self) -> TDType:
464469
@dataclass(frozen=True, kw_only=True)
465470
class StaticByteString(FlexibleWrapperBase[np.dtypes.BytesDType, np.bytes_]):
466471
name = "numpy/static_byte_string"
467-
kind = "string"
468-
default_value = np.bytes_(0)
469472
item_size_bits = 8
470473

474+
def default_value(self) -> np.bytes_:
475+
return np.bytes_(b"")
476+
471477
def to_dict(self) -> dict[str, JSON]:
472478
return {"name": self.name, "configuration": {"length": self.length}}
473479

@@ -476,17 +482,18 @@ def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> str:
476482

477483
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bytes_:
478484
if check_json_str(data):
479-
return self.unwrap().type(data.encode("ascii"))
485+
return self.unwrap().type(base64.standard_b64decode(data.encode("ascii")))
480486
raise TypeError(f"Invalid type: {data}. Expected a string.")
481487

482488

483489
@dataclass(frozen=True, kw_only=True)
484490
class StaticRawBytes(FlexibleWrapperBase[np.dtypes.VoidDType, np.void]):
485491
name = "r*"
486-
kind = "bytes"
487-
default_value = np.void(b"")
488492
item_size_bits = 8
489493

494+
def default_value(self) -> np.void:
495+
return np.void(b"")
496+
490497
def to_dict(self) -> dict[str, JSON]:
491498
return {"name": f"r{self.length * self.item_size_bits}"}
492499

@@ -496,21 +503,22 @@ def unwrap(self) -> np.dtypes.VoidDType:
496503
endianness_code = endianness_to_numpy_str(self.endianness)
497504
return np.dtype(f"{endianness_code}V{self.length}")
498505

499-
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> tuple[int, ...]:
506+
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> str:
500507
return base64.standard_b64encode(data).decode("ascii")
501508

502509
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:
503510
# todo: check that this is well-formed
504-
return self.unwrap().type(bytes(data))
511+
return self.unwrap().type(base64.standard_b64decode(data))
505512

506513

507514
@dataclass(frozen=True, kw_only=True)
508515
class StaticUnicodeString(FlexibleWrapperBase[np.dtypes.StrDType, np.str_]):
509516
name = "numpy/static_unicode_string"
510-
kind = "string"
511-
default_value = np.str_("")
512517
item_size_bits = 32 # UCS4 is 32 bits per code point
513518

519+
def default_value(self) -> np.str_:
520+
return np.str_("")
521+
514522
def to_dict(self) -> dict[str, JSON]:
515523
return {"name": self.name, "configuration": {"length": self.length}}
516524

@@ -528,8 +536,9 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.str_:
528536
@dataclass(frozen=True, kw_only=True)
529537
class VariableLengthString(DTypeWrapper[np.dtypes.StringDType, str]):
530538
name = "numpy/vlen_string"
531-
kind = "string"
532-
default_value = ""
539+
540+
def default_value(self) -> str:
541+
return ""
533542

534543
@classmethod
535544
def _wrap_unsafe(cls, dtype: np.dtypes.StringDType) -> Self:
@@ -555,10 +564,11 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
555564
@dataclass(frozen=True, kw_only=True)
556565
class VariableLengthString(DTypeWrapper[np.dtypes.ObjectDType, str]):
557566
name = "numpy/vlen_string"
558-
kind = "string"
559-
default_value = np.object_("")
560567
endianness: Endianness = field(default=None)
561568

569+
def default_value(self) -> str:
570+
return ""
571+
562572
def __post_init__(self) -> None:
563573
if self.endianness is not None:
564574
raise ValueError("VariableLengthString does not support endianness.")
@@ -570,24 +580,57 @@ def to_dict(self) -> dict[str, JSON]:
570580
def _wrap_unsafe(cls, dtype: np.dtypes.ObjectDType) -> Self:
571581
return cls()
572582

573-
def unwrap(self) -> np.dtypes.ObjectDType:
574-
return super().unwrap()
575-
576583
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> str:
577584
return str(data)
578585

579586
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
587+
"""
588+
String literals pass through
589+
"""
580590
if not check_json_str(data):
581591
raise TypeError(f"Invalid type: {data}. Expected a string.")
582-
return self.unwrap().type(data)
592+
return data
593+
594+
595+
DateUnit = Literal["Y", "M", "W", "D"]
596+
TimeUnit = Literal["h", "m", "s", "ms", "us", "μs", "ns", "ps", "fs", "as"]
583597

584598

585599
@dataclass(frozen=True, kw_only=True)
586-
class StructuredDtype(DTypeWrapper[np.dtypes.VoidDType, np.void]):
600+
class DateTime64(DTypeWrapper[np.dtypes.DateTime64DType, np.datetime64]):
    """
    Wrapper for numpy ``datetime64`` dtypes, parametrized by time unit.

    Scalars are exchanged with JSON as integers: ``to_json_value`` delegates
    to ``datetime_to_json`` and ``from_json_value`` to ``datetime_from_json``.
    """

    name = "numpy/datetime64"
    # numpy unit code; placed inside ``datetime64[...]`` when unwrapping
    unit: DateUnit | TimeUnit

    def default_value(self) -> np.datetime64:
        """Return not-a-time (``NaT``) as the default scalar."""
        return np.datetime64("NaT")

    @classmethod
    def _wrap_unsafe(cls, dtype: np.dtypes.DateTime64DType) -> Self:
        """
        Create a wrapper carrying the unit of ``dtype``.

        The unit is read with ``np.datetime_data`` instead of being sliced out
        of ``dtype.name``: a unit-less ``datetime64`` has no brackets in its
        name, so slicing between two failed ``rfind`` results yielded the
        bogus unit ``"datetime6"``.
        """
        base_unit, count = np.datetime_data(dtype)
        # Keep the "<count><unit>" spelling (e.g. "10s") that the bracket
        # contents of dtype.name carry for multi-step units.
        unit = base_unit if count == 1 else f"{count}{base_unit}"
        return cls(unit=unit)

    def unwrap(self) -> np.dtypes.DateTime64DType:
        """Return the numpy dtype for this unit with this wrapper's byte order."""
        return np.dtype(f"datetime64[{self.unit}]").newbyteorder(
            endianness_to_numpy_str(self.endianness)
        )

    def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.datetime64:
        """
        Convert a JSON integer to a datetime64 scalar in this wrapper's unit.

        Raises
        ------
        TypeError
            If ``data`` does not pass ``check_json_int``.
        """
        if check_json_int(data):
            return datetime_from_json(data, self.unit)
        raise TypeError(f"Invalid type: {data}. Expected an integer.")

    def to_json_value(self, data: np.datetime64, *, zarr_format: ZarrFormat) -> int:
        """Convert a datetime64 scalar to its integer JSON form; ``zarr_format`` is not consulted."""
        return datetime_to_json(data)
624+
625+
626+
@dataclass(frozen=True, kw_only=True)
627+
class Structured(DTypeWrapper[np.dtypes.VoidDType, np.void]):
587628
name = "numpy/struct"
588-
kind = "struct"
589629
fields: tuple[tuple[str, DTypeWrapper[Any, Any], int], ...]
590630

631+
def default_value(self) -> np.void:
632+
return np.array([0], dtype=self.unwrap())[0]
633+
591634
@classmethod
592635
def check_dtype(cls, dtype: np.dtypes.DTypeLike) -> TypeGuard[np.dtypes.VoidDType]:
593636
"""
@@ -608,6 +651,9 @@ def _wrap_unsafe(cls, dtype: np.dtypes.VoidDType) -> Self:
608651

609652
return cls(fields=tuple(fields))
610653

654+
def unwrap(self) -> np.dtypes.VoidDType:
655+
return np.dtype([(key, dtype.unwrap()) for (key, dtype, _) in self.fields])
656+
611657
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> str:
612658
return structured_scalar_to_json(data.tobytes(), zarr_format)
613659

@@ -629,7 +675,10 @@ def get_data_type_from_numpy(dtype: npt.DTypeLike) -> DTypeWrapper:
629675
np_dtype = np.dtype(dtype)
630676
data_type_registry.lazy_load()
631677
for val in data_type_registry.contents.values():
632-
return val.wrap(np_dtype)
678+
try:
679+
return val.wrap(np_dtype)
680+
except TypeError:
681+
pass
633682
raise ValueError(
634683
f"numpy dtype '{dtype}' does not have a corresponding Zarr dtype in: {list(data_type_registry.contents)}."
635684
)
@@ -689,11 +738,11 @@ def get(self, key: str) -> type[DTypeWrapper[Any, Any]]:
689738
return self.contents[key]
690739

691740
def match_dtype(self, dtype: npt.DTypeLike) -> DTypeWrapper[Any, Any]:
692-
data_type_registry.lazy_load()
693-
for val in data_type_registry.contents.values():
741+
self.lazy_load()
742+
for val in self.contents.values():
694743
try:
695-
return val._wrap_unsafe(dtype)
696-
except ValueError:
744+
return val.wrap(dtype)
745+
except TypeError:
697746
pass
698747
raise ValueError(f"No data type wrapper found that matches {dtype}")
699748

@@ -708,7 +757,15 @@ def register_data_type(cls: type[DTypeWrapper[Any, Any]]) -> None:
708757
FLOAT_DTYPE = Float16 | Float32 | Float64
709758
COMPLEX_DTYPE = Complex64 | Complex128
710759
STRING_DTYPE = StaticUnicodeString | VariableLengthString | StaticByteString
711-
for dtype in get_args(
712-
Bool | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | STRING_DTYPE | StaticRawBytes
713-
):
760+
DTYPE = (
761+
Bool
762+
| INTEGER_DTYPE
763+
| FLOAT_DTYPE
764+
| COMPLEX_DTYPE
765+
| STRING_DTYPE
766+
| StaticRawBytes
767+
| Structured
768+
| DateTime64
769+
)
770+
for dtype in get_args(DTYPE):
714771
register_data_type(dtype)

0 commit comments

Comments
 (0)