Skip to content

Commit 2868994

Browse files
committed
more tests, fix void type default value logic
1 parent c1a8566 commit 2868994

File tree

9 files changed

+185
-189
lines changed

9 files changed

+185
-189
lines changed

src/zarr/core/array.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@
103103
)
104104
from zarr.core.metadata.dtype import (
105105
DTypeWrapper,
106-
StaticByteString,
106+
FixedLengthAsciiString,
107+
FixedLengthUnicodeString,
107108
VariableLengthString,
108109
get_data_type_from_numpy,
109110
)
@@ -710,7 +711,7 @@ def _create_metadata_v3(
710711

711712
if fill_value is None:
712713
# v3 spec will not allow a null fill value
713-
fill_value_parsed = dtype.default_value
714+
fill_value_parsed = dtype.default_value()
714715
else:
715716
fill_value_parsed = fill_value
716717

@@ -4237,7 +4238,7 @@ def _get_default_chunk_encoding_v3(
42374238

42384239
if isinstance(dtype, VariableLengthString):
42394240
serializer = VLenUTF8Codec()
4240-
elif isinstance(dtype, StaticByteString):
4241+
elif isinstance(dtype, FixedLengthAsciiString):
42414242
serializer = VLenBytesCodec()
42424243
else:
42434244
if dtype.unwrap().itemsize == 1:
@@ -4257,9 +4258,9 @@ def _get_default_chunk_encoding_v2(
42574258
from numcodecs import VLenUTF8 as numcodecs_VLenUTF8
42584259
from numcodecs import Zstd as numcodecs_zstd
42594260

4260-
if isinstance(dtype, VariableLengthString):
4261+
if isinstance(dtype, VariableLengthString | FixedLengthUnicodeString):
42614262
filters = (numcodecs_VLenUTF8(),)
4262-
elif isinstance(dtype, StaticByteString):
4263+
elif isinstance(dtype, FixedLengthAsciiString):
42634264
filters = (numcodecs_VLenBytes(),)
42644265
else:
42654266
filters = None

src/zarr/core/buffer/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,9 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
472472
return np.array_equal(
473473
self._data,
474474
other,
475-
equal_nan=equal_nan if self._data.dtype.kind not in "USTOV" else False,
475+
equal_nan=equal_nan
476+
if self._data.dtype.kind not in ("U", "S", "T", "O", "V")
477+
else False,
476478
)
477479

478480
def fill(self, value: Any) -> None:

src/zarr/core/codec_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
6363
# validated when decoding the metadata, but we support reading
6464
# Zarr V2 data and need to support the case where fill_value
6565
# is None.
66-
return chunk_spec.dtype.default_value
66+
return chunk_spec.dtype.default_value()
6767
else:
6868
return fill_value
6969

src/zarr/core/metadata/dtype.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
TypeVar,
1717
cast,
1818
get_args,
19+
get_origin,
1920
)
2021

2122
import numpy as np
@@ -133,7 +134,7 @@ def float_to_json_v2(data: float | np.floating[Any]) -> JSONFloat:
133134

134135
def float_to_json_v3(data: float | np.floating[Any]) -> JSONFloat:
135136
# v3 can in principle handle distinct NaN values, but numpy does not represent these explicitly
136-
# so we just re-use the v2 routine here
137+
# so we just reuse the v2 routine here
137138
return float_to_json_v2(data)
138139

139140

@@ -148,11 +149,11 @@ def float_to_json(data: float | np.floating[Any], zarr_format: ZarrFormat) -> JS
148149
raise ValueError(f"Invalid zarr format: {zarr_format}. Expected 2 or 3.")
149150

150151

151-
def complex_to_json_v2(data: complex | np.complexfloating[Any]) -> tuple[JSONFloat, JSONFloat]:
152+
def complex_to_json_v2(data: complex | np.complexfloating[Any, Any]) -> tuple[JSONFloat, JSONFloat]:
152153
return float_to_json_v2(data.real), float_to_json_v2(data.imag)
153154

154155

155-
def complex_to_json_v3(data: complex | np.complexfloating[Any]) -> tuple[JSONFloat, JSONFloat]:
156+
def complex_to_json_v3(data: complex | np.complexfloating[Any, Any]) -> tuple[JSONFloat, JSONFloat]:
156157
return float_to_json_v3(data.real), float_to_json_v3(data.imag)
157158

158159

@@ -226,15 +227,16 @@ def complex_from_json(
226227

227228

228229
def datetime_to_json(data: np.datetime64[Any]) -> int:
229-
return data.view("int").item()
230+
return data.view(np.int64).item()
230231

231232

232233
def datetime_from_json(data: int, unit: DateUnit | TimeUnit) -> np.datetime64[Any]:
233234
return np.int64(data).view(f"datetime64[{unit}]")
234235

235236

237+
TScalar = TypeVar("TScalar", bound=np.generic | str, covariant=True)
238+
# TODO: figure out an interface or protocol that non-numpy dtypes can
236239
TDType = TypeVar("TDType", bound=np.dtype[Any])
237-
TScalar = TypeVar("TScalar", bound=np.generic | str)
238240

239241

240242
@dataclass(frozen=True, kw_only=True)
@@ -244,17 +246,27 @@ class DTypeWrapper(Generic[TDType, TScalar], ABC, Metadata):
244246
endianness: Endianness | None = "native"
245247

246248
def __init_subclass__(cls) -> None:
247-
# Subclasses will bind the first generic type parameter to an attribute of the class
248249
# TODO: wrap this in some *very informative* error handling
249250
generic_args = get_args(get_original_bases(cls)[0])
250-
cls.dtype_cls = generic_args[0]
251+
# the logic here is that if a subclass was created with generic type parameters
252+
# specified explicitly, then we bind that type parameter to the dtype_cls attribute
253+
if len(generic_args) > 0:
254+
cls.dtype_cls = generic_args[0]
255+
else:
256+
# but if the subclass was created without generic type parameters specified explicitly,
257+
# then we check the parent DTypeWrapper classes and retrieve their generic type parameters
258+
for base in cls.__orig_bases__:
259+
if get_origin(base) is DTypeWrapper:
260+
generic_args = get_args(base)
261+
cls.dtype_cls = generic_args[0]
262+
break
251263
return super().__init_subclass__()
252264

253265
def to_dict(self) -> dict[str, JSON]:
254266
return {"name": self.name}
255267

256268
def cast_value(self: Self, value: object) -> TScalar:
257-
return cast(np.generic, self.unwrap().type(value))
269+
return cast(TScalar, self.unwrap().type(value))
258270

259271
@abstractmethod
260272
def default_value(self) -> TScalar: ...
@@ -455,7 +467,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.complex1
455467
@dataclass(frozen=True, kw_only=True)
456468
class FlexibleWrapperBase(DTypeWrapper[TDType, TScalar]):
457469
item_size_bits: ClassVar[int]
458-
length: int
470+
length: int = 0
459471

460472
@classmethod
461473
def _wrap_unsafe(cls, dtype: TDType) -> Self:
@@ -467,7 +479,7 @@ def unwrap(self) -> TDType:
467479

468480

469481
@dataclass(frozen=True, kw_only=True)
470-
class StaticByteString(FlexibleWrapperBase[np.dtypes.BytesDType, np.bytes_]):
482+
class FixedLengthAsciiString(FlexibleWrapperBase[np.dtypes.BytesDType, np.bytes_]):
471483
name = "numpy/static_byte_string"
472484
item_size_bits = 8
473485

@@ -492,11 +504,18 @@ class StaticRawBytes(FlexibleWrapperBase[np.dtypes.VoidDType, np.void]):
492504
item_size_bits = 8
493505

494506
def default_value(self) -> np.void:
495-
return np.void(b"")
507+
return self.cast_value(("\x00" * self.length).encode("ascii"))
496508

497509
def to_dict(self) -> dict[str, JSON]:
498510
return {"name": f"r{self.length * self.item_size_bits}"}
499511

512+
@classmethod
513+
def check_dtype(cls: type[Self], dtype: TDType) -> TypeGuard[TDType]:
514+
"""
515+
Reject structured dtypes by ensuring that dtype.fields is None
516+
"""
517+
return type(dtype) is cls.dtype_cls and dtype.fields is None
518+
500519
def unwrap(self) -> np.dtypes.VoidDType:
501520
# this needs to be overridden because numpy does not allow creating a void type
502521
# by invoking np.dtypes.VoidDType directly
@@ -512,7 +531,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:
512531

513532

514533
@dataclass(frozen=True, kw_only=True)
515-
class StaticUnicodeString(FlexibleWrapperBase[np.dtypes.StrDType, np.str_]):
534+
class FixedLengthUnicodeString(FlexibleWrapperBase[np.dtypes.StrDType, np.str_]):
516535
name = "numpy/static_unicode_string"
517536
item_size_bits = 32 # UCS4 is 32 bits per code point
518537

@@ -599,7 +618,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> str:
599618
@dataclass(frozen=True, kw_only=True)
600619
class DateTime64(DTypeWrapper[np.dtypes.DateTime64DType, np.datetime64]):
601620
name = "numpy/datetime64"
602-
unit: DateUnit | TimeUnit
621+
unit: DateUnit | TimeUnit = "s"
603622

604623
def default_value(self) -> np.datetime64:
605624
return np.datetime64("NaT")
@@ -609,6 +628,9 @@ def _wrap_unsafe(cls, dtype: np.dtypes.DateTime64DType) -> Self:
609628
unit = dtype.name[dtype.name.rfind("[") + 1 : dtype.name.rfind("]")]
610629
return cls(unit=unit)
611630

631+
def cast_value(self, value: object) -> np.datetime64:
632+
return self.unwrap().type(value, self.unit)
633+
612634
def unwrap(self) -> np.dtypes.DateTime64DType:
613635
return np.dtype(f"datetime64[{self.unit}]").newbyteorder(
614636
endianness_to_numpy_str(self.endianness)
@@ -651,6 +673,26 @@ def _wrap_unsafe(cls, dtype: np.dtypes.VoidDType) -> Self:
651673

652674
return cls(fields=tuple(fields))
653675

676+
def to_dict(self) -> dict[str, JSON]:
677+
base_dict = super().to_dict()
678+
if base_dict.get("configuration", {}) != {}:
679+
raise ValueError(
680+
"This data type wrapper cannot inherit from a data type wrapper that defines a configuration for its dict serialization"
681+
)
682+
field_configs = [
683+
(f_name, f_dtype.to_dict(), f_offset) for f_name, f_dtype, f_offset in self.fields
684+
]
685+
base_dict["configuration"] = {"fields": field_configs}
686+
return base_dict
687+
688+
@classmethod
689+
def from_dict(cls, data: dict[str, JSON]) -> Self:
690+
fields = tuple(
691+
(f_name, get_data_type_from_dict(f_dtype), f_offset)
692+
for f_name, f_dtype, f_offset in data["fields"]
693+
)
694+
return cls(fields=fields)
695+
654696
def unwrap(self) -> np.dtypes.VoidDType:
655697
return np.dtype([(key, dtype.unwrap()) for (key, dtype, _) in self.fields])
656698

@@ -665,7 +707,7 @@ def from_json_value(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:
665707
return np.array([as_bytes], dtype=dtype.str).view(dtype)[0]
666708

667709

668-
def get_data_type_from_numpy(dtype: npt.DTypeLike) -> DTypeWrapper:
710+
def get_data_type_from_numpy(dtype: npt.DTypeLike) -> DTypeWrapper[Any, Any]:
669711
if dtype in (str, "str"):
670712
if _NUMPY_SUPPORTS_VLEN_STRING:
671713
np_dtype = np.dtype("T")
@@ -674,17 +716,10 @@ def get_data_type_from_numpy(dtype: npt.DTypeLike) -> DTypeWrapper:
674716
else:
675717
np_dtype = np.dtype(dtype)
676718
data_type_registry.lazy_load()
677-
for val in data_type_registry.contents.values():
678-
try:
679-
return val.wrap(np_dtype)
680-
except TypeError:
681-
pass
682-
raise ValueError(
683-
f"numpy dtype '{dtype}' does not have a corresponding Zarr dtype in: {list(data_type_registry.contents)}."
684-
)
719+
return data_type_registry.match_dtype(np_dtype)
685720

686721

687-
def get_data_type_from_dict(dtype: dict[str, JSON]) -> DTypeWrapper:
722+
def get_data_type_from_dict(dtype: dict[str, JSON]) -> DTypeWrapper[Any.Any]:
688723
data_type_registry.lazy_load()
689724
dtype_name = dtype["name"]
690725
dtype_cls = data_type_registry.get(dtype_name)
@@ -737,14 +772,14 @@ def register(self: Self, cls: type[DTypeWrapper[Any, Any]]) -> None:
737772
def get(self, key: str) -> type[DTypeWrapper[Any, Any]]:
738773
return self.contents[key]
739774

740-
def match_dtype(self, dtype: npt.DTypeLike) -> DTypeWrapper[Any, Any]:
775+
def match_dtype(self, dtype: TDType) -> DTypeWrapper[Any, Any]:
741776
self.lazy_load()
742777
for val in self.contents.values():
743778
try:
744779
return val.wrap(dtype)
745780
except TypeError:
746781
pass
747-
raise ValueError(f"No data type wrapper found that matches {dtype}")
782+
raise ValueError(f"No data type wrapper found that matches dtype '{dtype}'")
748783

749784

750785
def register_data_type(cls: type[DTypeWrapper[Any, Any]]) -> None:
@@ -756,7 +791,7 @@ def register_data_type(cls: type[DTypeWrapper[Any, Any]]) -> None:
756791
INTEGER_DTYPE = Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
757792
FLOAT_DTYPE = Float16 | Float32 | Float64
758793
COMPLEX_DTYPE = Complex64 | Complex128
759-
STRING_DTYPE = StaticUnicodeString | VariableLengthString | StaticByteString
794+
STRING_DTYPE = FixedLengthUnicodeString | VariableLengthString | FixedLengthAsciiString
760795
DTYPE = (
761796
Bool
762797
| INTEGER_DTYPE

src/zarr/core/metadata/v2.py

Lines changed: 15 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33
import base64
44
import warnings
55
from collections.abc import Iterable
6-
from enum import Enum
76
from typing import TYPE_CHECKING, TypedDict, cast
87

98
import numcodecs.abc
109

1110
from zarr.abc.metadata import Metadata
1211
from zarr.core.metadata.dtype import (
1312
DTypeWrapper,
14-
StaticByteString,
15-
StaticRawBytes,
13+
Structured,
1614
get_data_type_from_numpy,
1715
)
1816

@@ -109,49 +107,12 @@ def shards(self) -> ChunkCoords | None:
109107
return None
110108

111109
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
112-
def _json_convert(
113-
o: Any,
114-
) -> Any:
115-
if isinstance(o, np.dtype):
116-
if o.fields is None:
117-
return o.str
118-
else:
119-
return o.descr
120-
if isinstance(o, numcodecs.abc.Codec):
121-
codec_config = o.get_config()
122-
123-
# Hotfix for https://github.com/zarr-developers/zarr-python/issues/2647
124-
if codec_config["id"] == "zstd" and not codec_config.get("checksum", False):
125-
codec_config.pop("checksum", None)
126-
127-
return codec_config
128-
if np.isscalar(o):
129-
out: Any
130-
if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"):
131-
# https://github.com/zarr-developers/zarr-python/issues/2119
132-
# `.item()` on a datetime type might or might not return an
133-
# integer, depending on the value.
134-
# Explicitly cast to an int first, and then grab .item()
135-
out = o.view("i8").item()
136-
else:
137-
# convert numpy scalar to python type, and pass
138-
# python types through
139-
out = getattr(o, "item", lambda: o)()
140-
if isinstance(out, complex):
141-
# python complex types are not JSON serializable, so we use the
142-
# serialization defined in the zarr v3 spec
143-
return [out.real, out.imag]
144-
return out
145-
if isinstance(o, Enum):
146-
return o.name
147-
raise TypeError
148-
149110
zarray_dict = self.to_dict()
150111
zattrs_dict = zarray_dict.pop("attributes", {})
151112
json_indent = config.get("json_indent")
152113
return {
153114
ZARRAY_JSON: prototype.buffer.from_bytes(
154-
json.dumps(zarray_dict, default=_json_convert, indent=json_indent).encode()
115+
json.dumps(zarray_dict, indent=json_indent).encode()
155116
),
156117
ZATTRS_JSON: prototype.buffer.from_bytes(
157118
json.dumps(zattrs_dict, indent=json_indent).encode()
@@ -196,11 +157,19 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
196157

197158
def to_dict(self) -> dict[str, JSON]:
198159
zarray_dict = super().to_dict()
160+
if isinstance(zarray_dict["compressor"], numcodecs.abc.Codec):
161+
zarray_dict["compressor"] = zarray_dict["compressor"].get_config()
162+
if zarray_dict["filters"] is not None:
163+
raw_filters = zarray_dict["filters"]
164+
new_filters = []
165+
for f in raw_filters:
166+
if isinstance(f, numcodecs.abc.Codec):
167+
new_filters.append(f.get_config())
168+
else:
169+
new_filters.append(f)
170+
zarray_dict["filters"] = new_filters
199171

200-
if (
201-
isinstance(self.dtype, StaticByteString | StaticRawBytes)
202-
and self.fill_value is not None
203-
):
172+
if self.fill_value is not None:
204173
# There's a relationship between self.dtype and self.fill_value
205174
# that mypy isn't aware of. The fact that we have S or V dtype here
206175
# means we should have a bytes-type fill_value.
@@ -209,10 +178,7 @@ def to_dict(self) -> dict[str, JSON]:
209178

210179
_ = zarray_dict.pop("dtype")
211180
dtype_json: JSON
212-
# TODO: Replace this with per-dtype method
213-
# In the case of zarr v2, the simplest i.e., '|VXX' dtype is represented as a string
214-
dtype_descr = self.dtype.unwrap().descr
215-
if self.dtype.unwrap().kind == "V" and dtype_descr[0][0] != "" and len(dtype_descr) != 0:
181+
if isinstance(self.dtype, Structured):
216182
dtype_json = tuple(self.dtype.unwrap().descr)
217183
else:
218184
dtype_json = self.dtype.unwrap().str

0 commit comments

Comments
 (0)