Skip to content

Commit b588f70

Browse files
committed
fix dtype sizes, adjust fill value parsing in from_dict, fix tests
1 parent 556e390 commit b588f70

File tree

3 files changed: +32 -24 lines changed

src/zarr/core/metadata/dtype.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,15 @@ def endianness_to_numpy_str(endianness: Endianness | None) -> Literal[">", "<",
3434
def check_json_bool(data: JSON) -> TypeGuard[bool]:
3535
return bool(isinstance(data, bool))
3636

37+
3738
def check_json_str(data: JSON) -> TypeGuard[str]:
3839
return bool(isinstance(data, str))
3940

41+
4042
def check_json_int(data: JSON) -> TypeGuard[int]:
4143
return bool(isinstance(data, int))
4244

45+
4346
def check_json_float(data: JSON) -> TypeGuard[float]:
4447
if data == "NaN" or data == "Infinity" or data == "-Infinity":
4548
return True
@@ -254,7 +257,7 @@ def from_json_value(
254257
@dataclass(frozen=True, kw_only=True)
255258
class UInt8(DTypeBase):
256259
name = "uint8"
257-
item_size = 2
260+
item_size = 1
258261
kind = "numeric"
259262
numpy_character_code = "B"
260263
default = 0
@@ -488,13 +491,15 @@ def to_numpy(self, *, endianness: Endianness | None = None) -> np.dtypes.Float64
488491
return super().to_numpy(endianness=endianness)
489492

490493
def to_json_value(self, data: np.generic, zarr_format: ZarrFormat) -> float:
491-
return float(data)
494+
return float_to_json(data, zarr_format)
492495

493496
def from_json_value(
494497
self, data: JSON, *, zarr_format: ZarrFormat, endianness: Endianness | None = None
495498
) -> np.float64:
496499
if check_json_float(data):
497-
return float_from_json(data, dtype=self.to_numpy(endianness=endianness))
500+
return float_from_json(
501+
data, dtype=self.to_numpy(endianness=endianness), zarr_format=zarr_format
502+
)
498503
raise TypeError(f"Invalid type: {data}. Expected a float.")
499504

500505

@@ -504,7 +509,7 @@ def from_json_value(
504509
@dataclass(frozen=True, kw_only=True)
505510
class Complex64(DTypeBase):
506511
name = "complex64"
507-
item_size = 16
512+
item_size = 8
508513
kind = "numeric"
509514
numpy_character_code = "F"
510515
default = 0.0 + 0.0j
@@ -533,7 +538,7 @@ def from_json_value(
533538
@dataclass(frozen=True, kw_only=True)
534539
class Complex128(DTypeBase):
535540
name = "complex128"
536-
item_size = 32
541+
item_size = 16
537542
kind = "numeric"
538543
numpy_character_code = "D"
539544
default = 0.0 + 0.0j

src/zarr/core/metadata/v3.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,28 +263,26 @@ def __init__(
263263
Because the class is a frozen dataclass, we set attributes using object.__setattr__
264264
"""
265265
shape_parsed = parse_shapelike(shape)
266-
data_type_parsed = data_type
267266
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
268267
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
269268
dimension_names_parsed = parse_dimension_names(dimension_names)
270-
# we pass a string here rather than an enum to make mypy happy
271-
fill_value_parsed = data_type_parsed.to_numpy().type(fill_value)
269+
fill_value_parsed = data_type.to_numpy().type(fill_value)
272270
attributes_parsed = parse_attributes(attributes)
273271
codecs_parsed_partial = parse_codecs(codecs)
274272
storage_transformers_parsed = parse_storage_transformers(storage_transformers)
275273

276274
array_spec = ArraySpec(
277275
shape=shape_parsed,
278-
dtype=data_type_parsed.to_numpy(),
276+
dtype=data_type.to_numpy(),
279277
fill_value=fill_value_parsed,
280278
config=ArrayConfig.from_dict({}), # TODO: config is not needed here.
281279
prototype=default_buffer_prototype(), # TODO: prototype is not needed here.
282280
)
283281
codecs_parsed = tuple(c.evolve_from_array_spec(array_spec) for c in codecs_parsed_partial)
284-
validate_codecs(codecs_parsed_partial, data_type_parsed)
282+
validate_codecs(codecs_parsed_partial, data_type)
285283

286284
object.__setattr__(self, "shape", shape_parsed)
287-
object.__setattr__(self, "data_type", data_type_parsed)
285+
object.__setattr__(self, "data_type", data_type)
288286
object.__setattr__(self, "chunk_grid", chunk_grid_parsed)
289287
object.__setattr__(self, "chunk_key_encoding", chunk_key_encoding_parsed)
290288
object.__setattr__(self, "codecs", codecs_parsed)
@@ -405,11 +403,16 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
405403
else:
406404
data_type = get_data_type_from_dict(data_type_json)
407405

406+
# check that the fill value is consistent with the data type
407+
fill_value_parsed = data_type.from_json_value(_data.pop("fill_value"), zarr_format=3)
408+
408409
# dimension_names key is optional, normalize missing to `None`
409410
_data["dimension_names"] = _data.pop("dimension_names", None)
411+
410412
# attributes key is optional, normalize missing to `None`
411413
_data["attributes"] = _data.pop("attributes", None)
412-
return cls(**_data, data_type=data_type) # type: ignore[arg-type]
414+
415+
return cls(**_data, fill_value=fill_value_parsed, data_type=data_type) # type: ignore[arg-type]
413416

414417
def to_dict(self) -> dict[str, JSON]:
415418
out_dict = super().to_dict()

tests/test_metadata/test_v3.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
1313
from zarr.core.config import config
1414
from zarr.core.group import GroupMetadata, parse_node_type
15-
from zarr.core.metadata.dtype import complex_from_json
15+
from zarr.core.metadata.dtype import Flexible, complex_from_json
1616
from zarr.core.metadata.v3 import (
1717
ArrayV3Metadata,
1818
parse_dimension_names,
@@ -278,7 +278,7 @@ async def test_datetime_metadata(fill_value: int, precision: str) -> None:
278278

279279

280280
@pytest.mark.parametrize(
281-
("data_type", "fill_value"), [("uint8", -1), ("int32", 22.5), ("float32", "foo")]
281+
("data_type", "fill_value"), [("uint8", {}), ("int32", [0, 1]), ("float32", "foo")]
282282
)
283283
async def test_invalid_fill_value_raises(data_type: str, fill_value: float) -> None:
284284
metadata_dict = {
@@ -288,10 +288,11 @@ async def test_invalid_fill_value_raises(data_type: str, fill_value: float) -> N
288288
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
289289
"data_type": data_type,
290290
"chunk_key_encoding": {"name": "default", "separator": "."},
291-
"codecs": (),
291+
"codecs": ({"name": "bytes"},),
292292
"fill_value": fill_value, # this is not a valid fill value for uint8
293293
}
294-
with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"):
294+
# multiple things can go wrong here, so we don't match on the error message.
295+
with pytest.raises(TypeError):
295296
ArrayV3Metadata.from_dict(metadata_dict)
296297

297298

@@ -323,13 +324,12 @@ async def test_special_float_fill_values(fill_value: str) -> None:
323324

324325
@pytest.mark.parametrize("dtype_str", dtypes)
325326
def test_dtypes(dtype_str: str) -> None:
326-
dt = DataType(dtype_str)
327+
dt = get_data_type_from_numpy(dtype_str)
327328
np_dtype = dt.to_numpy()
328-
if dtype_str not in vlen_dtypes:
329-
# we can round trip "normal" dtypes
330-
assert dt == DataType.from_numpy(np_dtype)
331-
assert dt.byte_count == np_dtype.itemsize
332-
assert dt.has_endianness == (dt.byte_count > 1)
329+
330+
if not isinstance(dt, Flexible):
331+
assert dt.item_size == np_dtype.itemsize
333332
else:
334-
# return type for vlen types may vary depending on numpy version
335-
assert dt.byte_count is None
333+
assert dt.length == np_dtype.itemsize
334+
335+
assert dt.numpy_character_code == np_dtype.char

0 commit comments

Comments
 (0)