|
2 | 2 |
|
3 | 3 | import json |
4 | 4 | from abc import ABC, abstractmethod |
5 | | -from collections.abc import Iterable |
| 5 | +from collections.abc import Iterable, Sequence |
6 | 6 | from dataclasses import dataclass, field, replace |
7 | 7 | from enum import Enum |
8 | | -from typing import TYPE_CHECKING, Any, Literal |
| 8 | +from typing import TYPE_CHECKING, Any, Literal, cast, overload |
9 | 9 |
|
10 | 10 | import numpy as np |
11 | 11 | import numpy.typing as npt |
|
32 | 32 | ChunkCoords, |
33 | 33 | ZarrFormat, |
34 | 34 | parse_dtype, |
35 | | - parse_fill_value, |
36 | 35 | parse_named_configuration, |
37 | 36 | parse_shapelike, |
38 | 37 | ) |
@@ -189,7 +188,7 @@ def __init__( |
189 | 188 | chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) |
190 | 189 | chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding) |
191 | 190 | dimension_names_parsed = parse_dimension_names(dimension_names) |
192 | | - fill_value_parsed = parse_fill_value(fill_value) |
| 191 | + fill_value_parsed = parse_fill_value_v3(fill_value, dtype=data_type_parsed) |
193 | 192 | attributes_parsed = parse_attributes(attributes) |
194 | 193 | codecs_parsed_partial = parse_codecs(codecs) |
195 | 194 |
|
@@ -255,9 +254,18 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: |
255 | 254 | return self.chunk_key_encoding.encode_chunk_key(chunk_coords) |
256 | 255 |
|
257 | 256 | def to_buffer_dict(self) -> dict[str, Buffer]: |
258 | | - def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]: |
| 257 | + def _json_convert(o: Any) -> Any: |
259 | 258 | if isinstance(o, np.dtype): |
260 | 259 | return str(o) |
| 260 | + if np.isscalar(o): |
| 261 | + # convert numpy scalar to python type, and pass |
| 262 | + # python types through |
| 263 | + out = getattr(o, "item", lambda: o)() |
| 264 | + if isinstance(out, complex): |
| 265 | + # python complex types are not JSON serializable, so we use the |
| 266 | + # serialization defined in the zarr v3 spec |
| 267 | + return [out.real, out.imag] |
| 268 | + return out |
261 | 269 | if isinstance(o, Enum): |
262 | 270 | return o.name |
263 | 271 | # this serializes numcodecs compressors |
@@ -341,7 +349,7 @@ def __init__( |
341 | 349 | order_parsed = parse_indexing_order(order) |
342 | 350 | dimension_separator_parsed = parse_separator(dimension_separator) |
343 | 351 | filters_parsed = parse_filters(filters) |
344 | | - fill_value_parsed = parse_fill_value(fill_value) |
| 352 | + fill_value_parsed = parse_fill_value_v2(fill_value, dtype=data_type_parsed) |
345 | 353 | attributes_parsed = parse_attributes(attributes) |
346 | 354 |
|
347 | 355 | object.__setattr__(self, "shape", shape_parsed) |
@@ -371,13 +379,17 @@ def chunks(self) -> ChunkCoords: |
371 | 379 |
|
372 | 380 | def to_buffer_dict(self) -> dict[str, Buffer]: |
373 | 381 | def _json_convert( |
374 | | - o: np.dtype[Any], |
375 | | - ) -> str | list[tuple[str, str] | tuple[str, str, tuple[int, ...]]]: |
| 382 | + o: Any, |
| 383 | + ) -> Any: |
376 | 384 | if isinstance(o, np.dtype): |
377 | 385 | if o.fields is None: |
378 | 386 | return o.str |
379 | 387 | else: |
380 | 388 | return o.descr |
| 389 | + if np.isscalar(o): |
| 390 | + # convert numpy scalar to python type, and pass |
| 391 | + # python types through |
| 392 | + return getattr(o, "item", lambda: o)() |
381 | 393 | raise TypeError |
382 | 394 |
|
383 | 395 | zarray_dict = self.to_dict() |
@@ -517,3 +529,105 @@ def parse_codecs(data: Iterable[Codec | dict[str, JSON]]) -> tuple[Codec, ...]: |
517 | 529 | out += (get_codec_class(name_parsed).from_dict(c),) |
518 | 530 |
|
519 | 531 | return out |
| 532 | + |
| 533 | + |
| 534 | +def parse_fill_value_v2(fill_value: Any, dtype: np.dtype[Any]) -> Any: |
| 535 | + """ |
| 536 | + Parse a potential fill value into a value that is compatible with the provided dtype. |
| 537 | +
|
| 538 | + This is a light wrapper around zarr.v2.util.normalize_fill_value. |
| 539 | +
|
| 540 | + Parameters |
| 541 | + ---------- |
| 542 | + fill_value: Any |
| 543 | + A potential fill value. |
| 544 | + dtype: np.dtype[Any] |
| 545 | + A numpy dtype. |
| 546 | +
|
| 547 | + Returns |
| 548 | + An instance of `dtype`, or `None`, or any python object (in the case of an object dtype) |
| 549 | + """ |
| 550 | + from zarr.v2.util import normalize_fill_value |
| 551 | + |
| 552 | + return normalize_fill_value(fill_value=fill_value, dtype=dtype) |
| 553 | + |
| 554 | + |
| 555 | +BOOL = np.bool_ |
| 556 | +BOOL_DTYPE = np.dtypes.BoolDType |
| 557 | + |
| 558 | +INTEGER_DTYPE = ( |
| 559 | + np.dtypes.Int8DType |
| 560 | + | np.dtypes.Int16DType |
| 561 | + | np.dtypes.Int32DType |
| 562 | + | np.dtypes.Int64DType |
| 563 | + | np.dtypes.UByteDType |
| 564 | + | np.dtypes.UInt16DType |
| 565 | + | np.dtypes.UInt32DType |
| 566 | + | np.dtypes.UInt64DType |
| 567 | +) |
| 568 | + |
| 569 | +INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 |
| 570 | +FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType |
| 571 | +FLOAT = np.float16 | np.float32 | np.float64 |
| 572 | +COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType |
| 573 | +COMPLEX = np.complex64 | np.complex128 |
| 574 | +# todo: r* dtypes |
| 575 | + |
| 576 | + |
| 577 | +@overload |
| 578 | +def parse_fill_value_v3(fill_value: Any, dtype: BOOL_DTYPE) -> BOOL: ... |
| 579 | + |
| 580 | + |
| 581 | +@overload |
| 582 | +def parse_fill_value_v3(fill_value: Any, dtype: INTEGER_DTYPE) -> INTEGER: ... |
| 583 | + |
| 584 | + |
| 585 | +@overload |
| 586 | +def parse_fill_value_v3(fill_value: Any, dtype: FLOAT_DTYPE) -> FLOAT: ... |
| 587 | + |
| 588 | + |
| 589 | +@overload |
| 590 | +def parse_fill_value_v3(fill_value: Any, dtype: COMPLEX_DTYPE) -> COMPLEX: ... |
| 591 | + |
| 592 | + |
| 593 | +def parse_fill_value_v3( |
| 594 | + fill_value: Any, dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE |
| 595 | +) -> BOOL | INTEGER | FLOAT | COMPLEX: |
| 596 | + """ |
| 597 | + Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type. |
| 598 | + If `fill_value` is `None`, then this function will return the result of casting the value 0 |
| 599 | + to the provided data type. Otherwise, `fill_value` will be cast to the provided data type. |
| 600 | +
|
| 601 | + Note that some numpy dtypes use very permissive casting rules. For example, |
| 602 | + `np.bool_({'not remotely a bool'})` returns `True`. Thus this function should not be used for |
| 603 | + validating that the provided fill value is a valid instance of the data type. |
| 604 | +
|
| 605 | + Parameters |
| 606 | + ---------- |
| 607 | + fill_value: Any |
| 608 | + A potential fill value. |
| 609 | + dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE |
| 610 | + A numpy data type that models a data type defined in the Zarr V3 specification. |
| 611 | +
|
| 612 | + Returns |
| 613 | + ------- |
| 614 | + A scalar instance of `dtype` |
| 615 | + """ |
| 616 | + if fill_value is None: |
| 617 | + return dtype.type(0) |
| 618 | + if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): |
| 619 | + if dtype in (np.complex64, np.complex128): |
| 620 | + dtype = cast(COMPLEX_DTYPE, dtype) |
| 621 | + if len(fill_value) == 2: |
| 622 | + # complex datatypes serialize to JSON arrays with two elements |
| 623 | + return dtype.type(complex(*fill_value)) |
| 624 | + else: |
| 625 | + msg = ( |
| 626 | + f"Got an invalid fill value for complex data type {dtype}." |
| 627 | + f"Expected a sequence with 2 elements, but {fill_value} has " |
| 628 | + f"length {len(fill_value)}." |
| 629 | + ) |
| 630 | + raise ValueError(msg) |
| 631 | + msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}." |
| 632 | + raise TypeError(msg) |
| 633 | + return dtype.type(fill_value) |
0 commit comments