Skip to content

Commit b7fe986

Browse files
committed
remove endianness kwarg to methods, make it an instance variable instead
1 parent 3c232a4 commit b7fe986

File tree

1 file changed

+42
-65
lines changed

1 file changed

+42
-65
lines changed

src/zarr/core/metadata/dtype.py

Lines changed: 42 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from collections.abc import Sequence
3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, replace
44
from typing import Any, ClassVar, Generic, Literal, Self, TypeGuard, TypeVar, cast, get_args
55

66
import numpy as np
@@ -199,11 +199,13 @@ def complex_from_json(
199199
TScalar = TypeVar("TScalar", bound=np.generic)
200200

201201

202+
@dataclass(frozen=True, kw_only=True)
202203
class DTypeWrapper(Generic[TDType, TScalar], ABC, Metadata):
203204
name: ClassVar[str]
204205
dtype_cls: ClassVar[type[TDType]] # this class will create a numpy dtype
205206
kind: ClassVar[DataTypeFlavor]
206-
default_value: TScalar
207+
default_value: ClassVar[TScalar]
208+
endianness: Endianness = "native"
207209

208210
def __init_subclass__(cls) -> None:
209211
# Subclasses will bind the first generic type parameter to an attribute of the class
@@ -215,18 +217,21 @@ def __init_subclass__(cls) -> None:
215217
def to_dict(self) -> dict[str, JSON]:
216218
return {"name": self.name}
217219

218-
def cast_value(self: Self, value: object, *, endianness: Endianness | None = None) -> TScalar:
219-
return cast(np.generic, self.unwrap(endianness=endianness).type(value))
220+
def cast_value(self: Self, value: object) -> TScalar:
221+
return cast(np.generic, self.unwrap().type(value))
220222

221223
@classmethod
222224
@abstractmethod
223225
def wrap(cls: type[Self], dtype: TDType) -> Self:
224226
raise NotImplementedError
225227

226-
def unwrap(self: Self, *, endianness: Endianness | None = None) -> TDType:
227-
endian_str = endianness_to_numpy_str(endianness)
228+
def unwrap(self: Self) -> TDType:
229+
endian_str = endianness_to_numpy_str(self.endianness)
228230
return self.dtype_cls().newbyteorder(endian_str)
229231

232+
def with_endianness(self: Self, endianness: Endianness) -> Self:
233+
return replace(self, endianness=endianness)
234+
230235
@abstractmethod
231236
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> JSON:
232237
"""
@@ -235,9 +240,7 @@ def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> JSON:
235240
raise NotImplementedError
236241

237242
@abstractmethod
238-
def from_json_value(
239-
self: Self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
240-
) -> TScalar:
243+
def from_json_value(self: Self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
241244
"""
242245
Read a JSON-serializable value as a numpy scalar
243246
"""
@@ -257,11 +260,9 @@ def wrap(cls, dtype: np.dtypes.BoolDType) -> Self:
257260
def to_json_value(self, data: np.generic, zarr_format: ZarrFormat) -> bool:
258261
return bool(data)
259262

260-
def from_json_value(
261-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
262-
) -> np.bool_:
263+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bool_:
263264
if check_json_bool(data):
264-
return self.unwrap(endianness=endianness).type(data)
265+
return self.unwrap().type(data)
265266
raise TypeError(f"Invalid type: {data}. Expected a boolean.")
266267

267268

@@ -275,11 +276,9 @@ def wrap(cls, dtype: TDType) -> Self:
275276
def to_json_value(self, data: np.generic, zarr_format: ZarrFormat) -> int:
276277
return int(data)
277278

278-
def from_json_value(
279-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
280-
) -> TScalar:
279+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
281280
if check_json_int(data):
282-
return self.unwrap(endianness=endianness).type(data)
281+
return self.unwrap().type(data)
283282
raise TypeError(f"Invalid type: {data}. Expected an integer.")
284283

285284

@@ -341,11 +340,9 @@ def wrap(cls, dtype: TDType) -> Self:
341340
def to_json_value(self, data: np.generic, zarr_format: ZarrFormat) -> JSONFloat:
342341
return float_to_json(data, zarr_format)
343342

344-
def from_json_value(
345-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
346-
) -> TScalar:
343+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar:
347344
if check_json_float_v2(data):
348-
return self.unwrap(endianness=endianness).type(float_from_json(data, zarr_format))
345+
return self.unwrap().type(float_from_json(data, zarr_format))
349346
raise TypeError(f"Invalid type: {data}. Expected a float.")
350347

351348

@@ -382,13 +379,9 @@ def to_json_value(
382379
) -> tuple[JSONFloat, JSONFloat]:
383380
return complex_to_json(data, zarr_format)
384381

385-
def from_json_value(
386-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
387-
) -> np.complex64:
382+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.complex64:
388383
if check_json_complex_float_v3(data):
389-
return complex_from_json(
390-
data, dtype=self.unwrap(endianness=endianness), zarr_format=zarr_format
391-
)
384+
return complex_from_json(data, dtype=self.unwrap(), zarr_format=zarr_format)
392385
raise TypeError(f"Invalid type: {data}. Expected a complex float.")
393386

394387

@@ -407,13 +400,9 @@ def to_json_value(
407400
) -> tuple[JSONFloat, JSONFloat]:
408401
return complex_to_json(data, zarr_format)
409402

410-
def from_json_value(
411-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
412-
) -> np.complex128:
403+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.complex128:
413404
if check_json_complex_float_v3(data):
414-
return complex_from_json(
415-
data, dtype=self.unwrap(endianness=endianness), zarr_format=zarr_format
416-
)
405+
return complex_from_json(data, dtype=self.unwrap(), zarr_format=zarr_format)
417406
raise TypeError(f"Invalid type: {data}. Expected a complex float.")
418407

419408

@@ -426,31 +415,27 @@ class FlexibleWrapperBase(DTypeWrapper[TDType, TScalar]):
426415
def wrap(cls, dtype: TDType) -> Self:
427416
return cls(length=dtype.itemsize // (cls.item_size_bits // 8))
428417

429-
def unwrap(self, endianness: Endianness | None = None) -> TDType:
430-
endianness_code = endianness_to_numpy_str(endianness)
418+
def unwrap(self) -> TDType:
419+
endianness_code = endianness_to_numpy_str(self.endianness)
431420
return self.dtype_cls(self.length).newbyteorder(endianness_code)
432421

433422

434423
@dataclass(frozen=True, kw_only=True)
435424
class StaticByteString(FlexibleWrapperBase[np.dtypes.BytesDType, np.bytes_]):
436425
name = "numpy/static_byte_string"
437426
kind = "string"
438-
default_value = b""
427+
default_value = np.bytes_(0)
439428
item_size_bits = 8
440429

441430
def to_dict(self) -> dict[str, JSON]:
442431
return {"name": self.name, "configuration": {"length": self.length}}
443432

444-
def to_json_value(
445-
self, data: np.generic, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
446-
) -> str:
433+
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> str:
447434
return data.tobytes().decode("ascii")
448435

449-
def from_json_value(
450-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
451-
) -> np.bytes_:
436+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bytes_:
452437
if check_json_bool(data):
453-
return self.unwrap(endianness=endianness).type(data.encode("ascii"))
438+
return self.unwrap().type(data.encode("ascii"))
454439
raise TypeError(f"Invalid type: {data}. Expected a string.")
455440

456441

@@ -464,20 +449,18 @@ class StaticRawBytes(FlexibleWrapperBase[np.dtypes.VoidDType, np.void]):
464449
def to_dict(self) -> dict[str, JSON]:
465450
return {"name": f"r{self.length * self.item_size_bits}"}
466451

467-
def unwrap(self, endianness: Endianness | None = None) -> np.dtypes.VoidDType:
452+
def unwrap(self) -> np.dtypes.VoidDType:
468453
# this needs to be overridden because numpy does not allow creating a void type
469454
# by invoking np.dtypes.VoidDType directly
470-
endianness_code = endianness_to_numpy_str(endianness)
455+
endianness_code = endianness_to_numpy_str(self.endianness)
471456
return np.dtype(f"{endianness_code}V{self.length}")
472457

473458
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> tuple[int, ...]:
474459
return tuple(*data.tobytes())
475460

476-
def from_json_value(
477-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
478-
) -> np.void:
461+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:
479462
# todo: check that this is well-formed
480-
return self.unwrap(endianness=endianness).type(bytes(data))
463+
return self.unwrap().type(bytes(data))
481464

482465

483466
@dataclass(frozen=True, kw_only=True)
@@ -493,12 +476,10 @@ def to_dict(self) -> dict[str, JSON]:
493476
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> str:
494477
return str(data)
495478

496-
def from_json_value(
497-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
498-
) -> np.str_:
479+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.str_:
499480
if not check_json_str(data):
500481
raise TypeError(f"Invalid type: {data}. Expected a string.")
501-
return self.unwrap(endianness=endianness).type(data)
482+
return self.unwrap().type(data)
502483

503484

504485
if _NUMPY_SUPPORTS_VLEN_STRING:
@@ -516,17 +497,15 @@ def wrap(cls, dtype: np.dtypes.StringDType) -> Self:
516497
def to_dict(self) -> dict[str, JSON]:
517498
return {"name": self.name}
518499

519-
def unwrap(self, endianness: Endianness | None = None) -> np.dtypes.StringDType:
520-
endianness_code = endianness_to_numpy_str(endianness)
500+
def unwrap(self) -> np.dtypes.StringDType:
501+
endianness_code = endianness_to_numpy_str(self.endianness)
521502
return np.dtype(endianness_code + self.numpy_character_code)
522503

523504
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> str:
524505
return str(data)
525506

526-
def from_json_value(
527-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
528-
) -> str:
529-
return self.unwrap(endianness=endianness).type(data)
507+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
508+
return self.unwrap().type(data)
530509

531510
else:
532511

@@ -543,16 +522,14 @@ def to_dict(self) -> dict[str, JSON]:
543522
def wrap(cls, dtype: np.dtypes.ObjectDType) -> Self:
544523
return cls()
545524

546-
def unwrap(self, endianness: Endianness | None = None) -> np.dtype[np.dtypes.ObjectDType]:
547-
return super().unwrap(endianness=endianness)
525+
def unwrap(self) -> np.dtype[np.dtypes.ObjectDType]:
526+
return super().unwrap()
548527

549528
def to_json_value(self, data: np.generic, *, zarr_format: ZarrFormat) -> str:
550529
return str(data)
551530

552-
def from_json_value(
553-
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
554-
) -> str:
555-
return self.unwrap(endianness=endianness).type(data)
531+
def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
532+
return self.unwrap().type(data)
556533

557534

558535
def resolve_dtype(dtype: npt.DTypeLike | DTypeWrapper | dict[str, JSON]) -> DTypeWrapper:

0 commit comments

Comments
 (0)