diff --git a/changes/2874.feature.rst b/changes/2874.feature.rst new file mode 100644 index 0000000000..4c50532ae0 --- /dev/null +++ b/changes/2874.feature.rst @@ -0,0 +1,9 @@ +Adds zarr-specific data type classes. This replaces the internal use of numpy data types for zarr +v2 and a fixed set of string enums for zarr v3. This change is largely internal, but it does +change the type of the ``dtype`` and ``data_type`` fields on the ``ArrayV2Metadata`` and +``ArrayV3Metadata`` classes. It also changes the JSON metadata representation of the +variable-length string data type, but the old metadata representation can still be +used when reading arrays. The logic for automatically choosing the chunk encoding for a given data +type has also changed, and this necessitated changes to the ``config`` API. + +For more on this new feature, see the `documentation <https://zarr.readthedocs.io/en/stable/user-guide/data_types.html>`_ \ No newline at end of file diff --git a/docs/user-guide/arrays.rst b/docs/user-guide/arrays.rst index 5bd6b1500f..c27f1296b9 100644 --- a/docs/user-guide/arrays.rst +++ b/docs/user-guide/arrays.rst @@ -182,7 +182,7 @@ which can be used to print useful diagnostics, e.g.:: >>> z.info Type : Array Zarr format : 3 - Data type : DataType.int32 + Data type : Int32(endianness='little') Fill value : 0 Shape : (10000, 10000) Chunk shape : (1000, 1000) @@ -200,7 +200,7 @@ prints additional diagnostics, e.g.:: >>> z.info_complete() Type : Array Zarr format : 3 - Data type : DataType.int32 + Data type : Int32(endianness='little') Fill value : 0 Shape : (10000, 10000) Chunk shape : (1000, 1000) @@ -248,7 +248,7 @@ built-in delta filter:: The default compressor can be changed using Zarr's :ref:`user-guide-config`, e.g.:: - >>> with zarr.config.set({'array.v2_default_compressor.numeric': {'id': 'blosc'}}): + >>> with zarr.config.set({'array.v2_default_compressor.default': {'id': 'blosc'}}): ... z = zarr.create_array(store={}, shape=(100000000,), chunks=(1000000,), dtype='int32', zarr_format=2) >>> z.filters () @@ -288,7 +288,7 @@ Here is an example using a delta filter with the Blosc compressor:: >>> z.info Type : Array Zarr format : 3 - Data type : DataType.int32 + Data type : Int32(endianness='little') Fill value : 0 Shape : (10000, 10000) Chunk shape : (1000, 1000) @@ -603,7 +603,7 @@ Sharded arrays can be created by providing the ``shards`` parameter to :func:`za >>> a.info_complete() Type : Array Zarr format : 3 - Data type : DataType.uint8 + Data type : UInt8() Fill value : 0 Shape : (10000, 10000) Shard shape : (1000, 1000) @@ -612,10 +612,10 @@ Sharded arrays can be created by providing the ``shards`` parameter to :func:`za Read-only : False Store type : LocalStore Filters : () - Serializer : BytesCodec(endian=<Endian.little: 'little'>) + Serializer : BytesCodec(endian=None) Compressors : (ZstdCodec(level=0, checksum=False),) No. bytes : 100000000 (95.4M) - No. bytes stored : 3981552 + No. 
bytes stored : 3981473 Storage ratio : 25.1 Shards Initialized : 100 diff --git a/docs/user-guide/config.rst b/docs/user-guide/config.rst index 91ffe50b91..4479e30619 100644 --- a/docs/user-guide/config.rst +++ b/docs/user-guide/config.rst @@ -43,39 +43,30 @@ This is the current default configuration:: >>> zarr.config.pprint() {'array': {'order': 'C', - 'v2_default_compressor': {'bytes': {'checksum': False, - 'id': 'zstd', - 'level': 0}, - 'numeric': {'checksum': False, - 'id': 'zstd', - 'level': 0}, - 'string': {'checksum': False, + 'v2_default_compressor': {'default': {'checksum': False, 'id': 'zstd', - 'level': 0}}, - 'v2_default_filters': {'bytes': [{'id': 'vlen-bytes'}], - 'numeric': None, - 'raw': None, - 'string': [{'id': 'vlen-utf8'}]}, - 'v3_default_compressors': {'bytes': [{'configuration': {'checksum': False, - 'level': 0}, - 'name': 'zstd'}], - 'numeric': [{'configuration': {'checksum': False, + 'level': 0}, + 'variable-length-string': {'checksum': False, + 'id': 'zstd', + 'level': 0}}, + 'v2_default_filters': {'default': None, + 'variable-length-string': [{'id': 'vlen-utf8'}]}, + 'v3_default_compressors': {'default': [{'configuration': {'checksum': False, 'level': 0}, 'name': 'zstd'}], - 'string': [{'configuration': {'checksum': False, - 'level': 0}, - 'name': 'zstd'}]}, - 'v3_default_filters': {'bytes': [], 'numeric': [], 'string': []}, - 'v3_default_serializer': {'bytes': {'name': 'vlen-bytes'}, - 'numeric': {'configuration': {'endian': 'little'}, - 'name': 'bytes'}, - 'string': {'name': 'vlen-utf8'}}, - 'write_empty_chunks': False}, - 'async': {'concurrency': 10, 'timeout': None}, - 'buffer': 'zarr.core.buffer.cpu.Buffer', - 'codec_pipeline': {'batch_size': 1, - 'path': 'zarr.core.codec_pipeline.BatchedCodecPipeline'}, - 'codecs': {'blosc': 'zarr.codecs.blosc.BloscCodec', + 'variable-length-string': [{'configuration': {'checksum': False, + 'level': 0}, + 'name': 'zstd'}]}, + 'v3_default_filters': {'default': [], 'variable-length-string': []}, + 'v3_default_serializer': {'default': {'configuration': {'endian': 'little'}, + 'name': 'bytes'}, + 'variable-length-string': {'name': 'vlen-utf8'}}, + 'write_empty_chunks': False}, + 'async': {'concurrency': 10, 'timeout': None}, + 'buffer': 'zarr.core.buffer.cpu.Buffer', + 'codec_pipeline': {'batch_size': 1, + 'path': 'zarr.core.codec_pipeline.BatchedCodecPipeline'}, + 'codecs': {'blosc': 'zarr.codecs.blosc.BloscCodec', 'bytes': 'zarr.codecs.bytes.BytesCodec', 'crc32c': 'zarr.codecs.crc32c_.Crc32cCodec', 'endian': 'zarr.codecs.bytes.BytesCodec', @@ -85,7 +76,7 @@ This is the current default configuration:: 'vlen-bytes': 'zarr.codecs.vlen_utf8.VLenBytesCodec', 'vlen-utf8': 'zarr.codecs.vlen_utf8.VLenUTF8Codec', 'zstd': 'zarr.codecs.zstd.ZstdCodec'}, - 'default_zarr_format': 3, - 'json_indent': 2, - 'ndbuffer': 'zarr.core.buffer.cpu.NDBuffer', - 'threading': {'max_workers': None}} + 'default_zarr_format': 3, + 'json_indent': 2, + 'ndbuffer': 'zarr.core.buffer.cpu.NDBuffer', + 'threading': {'max_workers': None}} diff --git a/docs/user-guide/consolidated_metadata.rst b/docs/user-guide/consolidated_metadata.rst index edd5bafc8d..4cd72dbc74 100644 --- a/docs/user-guide/consolidated_metadata.rst +++ b/docs/user-guide/consolidated_metadata.rst @@ -47,7 +47,7 @@ that can be used.: >>> from pprint import pprint >>> pprint(dict(sorted(consolidated_metadata.items()))) {'a': ArrayV3Metadata(shape=(1,), - data_type=<DataType.float64: 'float64'>, + data_type=Float64(endianness='little'), chunk_grid=RegularChunkGrid(chunk_shape=(1,)),
chunk_key_encoding=DefaultChunkKeyEncoding(name='default', separator='/'), @@ -60,7 +60,7 @@ that can be used.: node_type='array', storage_transformers=()), 'b': ArrayV3Metadata(shape=(2, 2), - data_type=<DataType.float64: 'float64'>, + data_type=Float64(endianness='little'), chunk_grid=RegularChunkGrid(chunk_shape=(2, 2)), chunk_key_encoding=DefaultChunkKeyEncoding(name='default', separator='/'), @@ -73,7 +73,7 @@ that can be used.: node_type='array', storage_transformers=()), 'c': ArrayV3Metadata(shape=(3, 3, 3), - data_type=<DataType.float64: 'float64'>, + data_type=Float64(endianness='little'), chunk_grid=RegularChunkGrid(chunk_shape=(3, 3, 3)), chunk_key_encoding=DefaultChunkKeyEncoding(name='default', separator='/'), diff --git a/docs/user-guide/data_types.rst b/docs/user-guide/data_types.rst new file mode 100644 index 0000000000..87c8efc1f5 --- /dev/null +++ b/docs/user-guide/data_types.rst @@ -0,0 +1,172 @@ +Data types +========== + +Zarr's data type model +---------------------- + +Every Zarr array has a "data type", which defines the meaning and physical layout of the +array's elements. As Zarr Python is tightly integrated with `NumPy <https://numpy.org>`_, +it's easy to create arrays with NumPy data types: + +.. code-block:: python + + >>> import zarr + >>> import numpy as np + >>> z = zarr.create_array(store={}, shape=(10,), dtype=np.dtype('uint8')) + >>> z + <Array memory://... shape=(10,) dtype=uint8> + +Unlike NumPy arrays, Zarr arrays are designed to be accessed by Zarr +implementations in different programming languages. This means Zarr data types must be interpreted +correctly when clients read an array. Each Zarr data type defines procedures for +encoding and decoding both the data type itself and scalars of that data type, to and from Zarr +array metadata. These serialization procedures depend on the Zarr format. + +Data types in Zarr version 2 +----------------------------- + +Version 2 of the Zarr format defined its data types relative to +`NumPy's data types <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_, +and added a few non-NumPy data types as well. Thus the JSON identifier for a NumPy-compatible data +type is just the NumPy ``str`` attribute of that data type: + +.. code-block:: python + + >>> import zarr + >>> import numpy as np + >>> import json + >>> + >>> store = {} + >>> np_dtype = np.dtype('int64') + >>> z = zarr.create_array(store=store, shape=(1,), dtype=np_dtype, zarr_format=2) + >>> dtype_meta = json.loads(store['.zarray'].to_bytes())["dtype"] + >>> dtype_meta + '<i8' + >>> assert dtype_meta == np_dtype.str + +.. note:: + The ``<`` character in the data type metadata encodes the + `endianness <https://en.wikipedia.org/wiki/Endianness>`_, + or "byte order", of the data type. Following NumPy's example, + in Zarr version 2 each data type has an endianness where applicable. + However, Zarr version 3 data types do not store endianness information. + +In addition to defining a representation of the data type itself (which in the example above was +just the string ``"<i8"``), Zarr V2 also defines a JSON representation for scalars of each data +type, which is used to encode fill values in array metadata. + +Data types in Zarr version 3 +---------------------------- + +Version 3 of the Zarr format identifies data types with format-specific names rather than NumPy +string codes: simple data types are encoded as strings like ``"int64"``, while parametric data +types are encoded as ``JSON`` objects. For example, a datetime data type with 10-second units is +encoded as a ``JSON`` object in Zarr V3, instead of the NumPy-style string ``">M8[10s]"`` used in +Zarr V2. The Zarr V2 string is more compact, but can be harder to parse. + +For more about data types in Zarr V3, see the +`V3 specification <https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#data-types>`_. + +Data types in Zarr Python +------------------------- + +The two Zarr formats that Zarr Python supports specify data types in two different ways: +data types in Zarr version 2 are encoded as NumPy-compatible strings, while data types in Zarr +version 3 are encoded as either strings or ``JSON`` objects. Unlike Zarr V2 data types, Zarr V3 +data types carry no endianness information.
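+ +The difference is easy to see by inspecting the stored metadata documents directly. The sketch below extends the dict-backed store examples from this page; it assumes only the standard metadata keys, ``.zarray`` for Zarr V2 and ``zarr.json`` for Zarr V3: + +.. code-block:: python + + >>> import zarr + >>> import json + >>> store_v2, store_v3 = {}, {} + >>> z2 = zarr.create_array(store=store_v2, shape=(1,), dtype='int64', zarr_format=2) + >>> z3 = zarr.create_array(store=store_v3, shape=(1,), dtype='int64', zarr_format=3) + >>> json.loads(store_v2['.zarray'].to_bytes())['dtype'] + '<i8' + >>> json.loads(store_v3['zarr.json'].to_bytes())['data_type'] + 'int64'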
+ +To abstract over these syntactical and semantic differences, Zarr Python uses a class called +`ZDType <../api/zarr/dtype/index.html#zarr.dtype.ZDType>`_ to provide Zarr V2 and Zarr V3 compatibility +routines for "native" data types. In this context, a "native" data type is a Python class, +typically defined in another library, that models an array's data type. For example, ``np.uint8`` is a native +data type defined in NumPy, which Zarr Python wraps with a ``ZDType`` instance called +`UInt8 <../api/zarr/dtype/index.html#zarr.dtype.UInt8>`_. + +Each data type supported by Zarr Python is modeled by a ``ZDType`` subclass, which provides an +API for the following operations: + +- Wrapping / unwrapping a native data type +- Encoding / decoding a data type to / from Zarr V2 and Zarr V3 array metadata +- Encoding / decoding a scalar value to / from Zarr V2 and Zarr V3 array metadata + + +Example Usage +~~~~~~~~~~~~~ + +Create a ``ZDType`` from a native data type: + +.. code-block:: python + + >>> from zarr.core.dtype import Int8 + >>> import numpy as np + >>> int8 = Int8.from_native_dtype(np.dtype('int8')) + +Convert back to a native data type: + +.. code-block:: python + + >>> native_dtype = int8.to_native_dtype() + >>> assert native_dtype == np.dtype('int8') + +Get the default scalar value for the data type: + +.. code-block:: python + + >>> default_value = int8.default_scalar() + >>> assert default_value == np.int8(0) + + +Serialize to JSON for Zarr V2 and V3: + +.. code-block:: python + + >>> json_v2 = int8.to_json(zarr_format=2) + >>> json_v2 + {'name': '|i1', 'object_codec_id': None} + >>> json_v3 = int8.to_json(zarr_format=3) + >>> json_v3 + 'int8' + +Serialize a scalar value to JSON: + +.. code-block:: python + + >>> json_value = int8.to_json_scalar(42, zarr_format=3) + >>> json_value + 42 + +Deserialize a scalar value from JSON: + +.. code-block:: python + + >>> scalar_value = int8.from_json_scalar(42, zarr_format=3) + >>> assert scalar_value == np.int8(42) diff --git a/docs/user-guide/groups.rst b/docs/user-guide/groups.rst index 99234bad4e..4237a9df50 100644 --- a/docs/user-guide/groups.rst +++ b/docs/user-guide/groups.rst @@ -128,7 +128,7 @@ property. E.g.:: >>> bar.info_complete() Type : Array Zarr format : 3 - Data type : DataType.int64 + Data type : Int64(endianness='little') Fill value : 0 Shape : (1000000,) Chunk shape : (100000,) @@ -145,7 +145,7 @@ property.
E.g.:: >>> baz.info Type : Array Zarr format : 3 - Data type : DataType.float32 + Data type : Float32(endianness='little') Fill value : 0.0 Shape : (1000, 1000) Chunk shape : (100, 100) diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index c50713332b..ea34ac2561 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -8,6 +8,7 @@ User guide installation arrays + data_types groups attributes storage diff --git a/docs/user-guide/performance.rst b/docs/user-guide/performance.rst index 88329f11b8..7d24c87373 100644 --- a/docs/user-guide/performance.rst +++ b/docs/user-guide/performance.rst @@ -91,7 +91,7 @@ To use sharding, you need to specify the ``shards`` parameter when creating the >>> z6.info Type : Array Zarr format : 3 - Data type : DataType.uint8 + Data type : UInt8() Fill value : 0 Shape : (10000, 10000, 1000) Shard shape : (1000, 1000, 1000) @@ -100,7 +100,7 @@ To use sharding, you need to specify the ``shards`` parameter when creating the Read-only : False Store type : MemoryStore Filters : () - Serializer : BytesCodec(endian=<Endian.little: 'little'>) + Serializer : BytesCodec(endian=None) Compressors : (ZstdCodec(level=0, checksum=False),) No. bytes : 100000000000 (93.1G) @@ -122,7 +122,7 @@ ratios, depending on the correlation structure within the data. E.g.:: >>> c.info_complete() Type : Array Zarr format : 3 - Data type : DataType.int32 + Data type : Int32(endianness='little') Fill value : 0 Shape : (10000, 10000) Chunk shape : (1000, 1000) @@ -142,7 +142,7 @@ ratios, depending on the correlation structure within the data. E.g.:: >>> f.info_complete() Type : Array Zarr format : 3 - Data type : DataType.int32 + Data type : Int32(endianness='little') Fill value : 0 Shape : (10000, 10000) Chunk shape : (1000, 1000) diff --git a/pyproject.toml b/pyproject.toml index 8141374d5e..2680396e7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -285,6 +285,7 @@ extend-exclude = [ "notebooks", # temporary, until we achieve compatibility with ruff ≥ 0.6 "venv", "docs", + "tests/test_regression/scripts/", # these are scripts that use a different version of python "src/zarr/v2/", "tests/v2/", ] @@ -355,7 +356,6 @@ strict = true warn_unreachable = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] - [[tool.mypy.overrides]] module = [ "tests.package_with_entrypoint.*", @@ -385,6 +385,7 @@ module = [ "tests.test_properties", "tests.test_sync", "tests.test_v2", + "tests.test_regression.scripts.*" ] ignore_errors = true diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 16400f5f4b..d9e3520d42 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from zarr.abc.metadata import Metadata from zarr.core.buffer import Buffer, NDBuffer @@ -12,11 +12,10 @@ from collections.abc import Awaitable, Callable, Iterable from typing import Self - import numpy as np - from zarr.abc.store import ByteGetter, ByteSetter from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import ChunkGrid + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType from zarr.core.indexing import SelectorTuple __all__ = [ @@ -93,7 +92,13 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: """ return self - def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None: + def validate( + self, + *, + shape: 
ChunkCoords, + dtype: ZDType[TBaseDType, TBaseScalar], + chunk_grid: ChunkGrid, + ) -> None: """Validates that the codec configuration is compatible with the array metadata. Raises errors when the codec configuration is not compatible. @@ -285,7 +290,9 @@ def supports_partial_decode(self) -> bool: ... def supports_partial_encode(self) -> bool: ... @abstractmethod - def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None: + def validate( + self, *, shape: ChunkCoords, dtype: ZDType[TBaseDType, TBaseScalar], chunk_grid: ChunkGrid + ) -> None: """Validates that all codec configurations are compatible with the array metadata. Raises errors when a codec configuration is not compatible. diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 54bddd80a8..3b53095636 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -14,6 +14,7 @@ Array, AsyncArray, CompressorLike, + _get_default_chunk_encoding_v2, create_array, from_array, get_array_metadata, @@ -30,8 +31,8 @@ _default_zarr_format, _warn_order_kwarg, _warn_write_empty_chunks_kwarg, - parse_dtype, ) +from zarr.core.dtype import ZDTypeLike, get_data_type_from_native_dtype, parse_data_type from zarr.core.group import ( AsyncGroup, ConsolidatedMetadata, @@ -39,7 +40,6 @@ create_hierarchy, ) from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata -from zarr.core.metadata.v2 import _default_compressor, _default_filters from zarr.errors import GroupNotFoundError, NodeTypeValidationError from zarr.storage import StorePath from zarr.storage._common import make_store_path @@ -239,7 +239,6 @@ async def consolidate_metadata( group, metadata=metadata, ) - await group._save_metadata() return group @@ -457,11 +456,12 @@ async def save_array( shape = arr.shape chunks = getattr(arr, "chunks", None) # for array-likes with chunks attribute overwrite = kwargs.pop("overwrite", None) or _infer_overwrite(mode) + zarr_dtype = get_data_type_from_native_dtype(arr.dtype) new = await AsyncArray._create( store_path, zarr_format=zarr_format, shape=shape, - dtype=arr.dtype, + dtype=zarr_dtype, chunks=chunks, overwrite=overwrite, **kwargs, @@ -861,7 +861,7 @@ async def create( shape: ChunkCoords | int, *, # Note: this is a change from v2 chunks: ChunkCoords | int | None = None, # TODO: v2 allowed chunks=True - dtype: npt.DTypeLike | None = None, + dtype: ZDTypeLike | None = None, compressor: CompressorLike = "auto", fill_value: Any | None = 0, # TODO: need type order: MemoryOrder | None = None, @@ -1008,13 +1008,13 @@ async def create( _handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format) or _default_zarr_format() ) - + zdtype = parse_data_type(dtype, zarr_format=zarr_format) if zarr_format == 2: - dtype = parse_dtype(dtype, zarr_format) + default_filters, default_compressor = _get_default_chunk_encoding_v2(zdtype) if not filters: - filters = _default_filters(dtype) + filters = default_filters # type: ignore[assignment] if compressor == "auto": - compressor = _default_compressor(dtype) + compressor = default_compressor if synchronizer is not None: warnings.warn("synchronizer is not yet implemented", RuntimeWarning, stacklevel=2) @@ -1066,7 +1066,7 @@ async def create( store_path, shape=shape, chunks=chunks, - dtype=dtype, + dtype=zdtype, compressor=compressor, fill_value=fill_value, overwrite=overwrite, diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index a7f7cfda35..f2dc8757d6 100644 --- 
a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -38,6 +38,7 @@ ShapeLike, ZarrFormat, ) + from zarr.core.dtype import ZDTypeLike from zarr.storage import StoreLike __all__ = [ @@ -603,9 +604,9 @@ def create( shape: ChunkCoords | int, *, # Note: this is a change from v2 chunks: ChunkCoords | int | bool | None = None, - dtype: npt.DTypeLike | None = None, + dtype: ZDTypeLike | None = None, compressor: CompressorLike = "auto", - fill_value: Any | None = 0, # TODO: need type + fill_value: Any | None = None, # TODO: need type order: MemoryOrder | None = None, store: str | StoreLike | None = None, synchronizer: Any | None = None, @@ -755,7 +756,7 @@ def create_array( *, name: str | None = None, shape: ShapeLike | None = None, - dtype: npt.DTypeLike | None = None, + dtype: ZDTypeLike | None = None, data: np.ndarray[Any, np.dtype[Any]] | None = None, chunks: ChunkCoords | Literal["auto"] = "auto", shards: ShardsLike | None = None, @@ -786,7 +787,7 @@ def create_array( at the root of the store. shape : ChunkCoords, optional Shape of the array. Can be ``None`` if ``data`` is provided. - dtype : npt.DTypeLike, optional + dtype : ZDTypeLike, optional Data type of the array. Can be ``None`` if ``data`` is provided. data : np.ndarray, optional Array-like data to use for initializing the array. If this parameter is provided, the diff --git a/src/zarr/codecs/_v2.py b/src/zarr/codecs/_v2.py index 53edc1f4a1..08853f27f1 100644 --- a/src/zarr/codecs/_v2.py +++ b/src/zarr/codecs/_v2.py @@ -46,9 +46,9 @@ async def _decode_single( chunk = ensure_ndarray_like(chunk) # special case object dtype, because incorrect handling can lead to # segfaults and other bad things happening - if chunk_spec.dtype != object: + if chunk_spec.dtype.dtype_cls is not np.dtypes.ObjectDType: try: - chunk = chunk.view(chunk_spec.dtype) + chunk = chunk.view(chunk_spec.dtype.to_native_dtype()) except TypeError: # this will happen if the dtype of the chunk # does not match the dtype of the array spec i.g. if @@ -56,7 +56,7 @@ async def _decode_single( # is an object array. In this case, we need to convert the object # array to the correct dtype. - chunk = np.array(chunk).astype(chunk_spec.dtype) + chunk = np.array(chunk).astype(chunk_spec.dtype.to_native_dtype()) elif chunk.dtype != object: # If we end up here, someone must have hacked around with the filters. 
@@ -80,7 +80,7 @@ async def _encode_single( chunk = chunk_array.as_ndarray_like() # ensure contiguous and correct order - chunk = chunk.astype(chunk_spec.dtype, order=chunk_spec.order, copy=False) + chunk = chunk.astype(chunk_spec.dtype.to_native_dtype(), order=chunk_spec.order, copy=False) # apply filters if self.filters: diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 9a999e10d7..1c5e52e9a4 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -13,6 +13,7 @@ from zarr.abc.codec import BytesBytesCodec from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_enum, parse_named_configuration +from zarr.core.dtype.common import HasItemSize from zarr.registry import register_codec if TYPE_CHECKING: @@ -137,14 +138,16 @@ def to_dict(self) -> dict[str, JSON]: } def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: - dtype = array_spec.dtype + item_size = 1 + if isinstance(array_spec.dtype, HasItemSize): + item_size = array_spec.dtype.item_size new_codec = self if new_codec.typesize is None: - new_codec = replace(new_codec, typesize=dtype.itemsize) + new_codec = replace(new_codec, typesize=item_size) if new_codec.shuffle is None: new_codec = replace( new_codec, - shuffle=(BloscShuffle.bitshuffle if dtype.itemsize == 1 else BloscShuffle.shuffle), + shuffle=(BloscShuffle.bitshuffle if item_size == 1 else BloscShuffle.shuffle), ) return new_codec diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 750707d36a..d663a3b2cc 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -10,6 +10,7 @@ from zarr.abc.codec import ArrayBytesCodec from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer from zarr.core.common import JSON, parse_enum, parse_named_configuration +from zarr.core.dtype.common import HasEndianness from zarr.registry import register_codec if TYPE_CHECKING: @@ -56,7 +57,7 @@ def to_dict(self) -> dict[str, JSON]: return {"name": "bytes", "configuration": {"endian": self.endian.value}} def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: - if array_spec.dtype.itemsize == 0: + if not isinstance(array_spec.dtype, HasEndianness): if self.endian is not None: return replace(self, endian=None) elif self.endian is None: @@ -71,15 +72,12 @@ async def _decode_single( chunk_spec: ArraySpec, ) -> NDBuffer: assert isinstance(chunk_bytes, Buffer) - if chunk_spec.dtype.itemsize > 0: - if self.endian == Endian.little: - prefix = "<" - else: - prefix = ">" - dtype = np.dtype(f"{prefix}{chunk_spec.dtype.str[1:]}") + # TODO: remove endianness enum in favor of literal union + endian_str = self.endian.value if self.endian is not None else None + if isinstance(chunk_spec.dtype, HasEndianness): + dtype = replace(chunk_spec.dtype, endianness=endian_str).to_native_dtype() # type: ignore[call-arg] else: - dtype = np.dtype(f"|{chunk_spec.dtype.str[1:]}") - + dtype = chunk_spec.dtype.to_native_dtype() as_array_like = chunk_bytes.as_array_like() if isinstance(as_array_like, NDArrayLike): as_nd_array_like = as_array_like diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 4638d973cb..cd8676b4d1 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -43,6 +43,7 @@ parse_shapelike, product, ) +from zarr.core.dtype.npy.int import UInt64 from zarr.core.indexing import ( BasicIndexer, SelectorTuple, @@ -58,6 +59,7 @@ from typing import Self from zarr.core.common import JSON + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, 
ZDType MAX_UINT_64 = 2**64 - 1 ShardMapping = Mapping[ChunkCoords, Buffer] @@ -355,7 +357,11 @@ def __init__( object.__setattr__(self, "index_location", index_location_parsed) # Use instance-local lru_cache to avoid memory leaks - object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec)) + + # numpy void scalars are not hashable, which means an array spec with a fill value that is + # a numpy void scalar will break the lru_cache. This is commented out for now but should be + # fixed. See https://github.com/zarr-developers/zarr-python/issues/3054 + # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec)) object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec)) object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard)) @@ -371,7 +377,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: object.__setattr__(self, "index_location", parse_index_location(config["index_location"])) # Use instance-local lru_cache to avoid memory leaks - object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec)) + # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec)) object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec)) object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard)) @@ -402,7 +408,13 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return replace(self, codecs=evolved_codecs) return self - def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None: + def validate( + self, + *, + shape: ChunkCoords, + dtype: ZDType[TBaseDType, TBaseScalar], + chunk_grid: ChunkGrid, + ) -> None: if len(self.chunk_shape) != len(shape): raise ValueError( "The shard's `chunk_shape` and array's `shape` need to have the same number of dimensions." ) @@ -439,7 +451,10 @@ async def _decode_single( # setup output array out = chunk_spec.prototype.nd_buffer.create( - shape=shard_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0 + shape=shard_shape, + dtype=shard_spec.dtype.to_native_dtype(), + order=shard_spec.order, + fill_value=0, ) shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard) @@ -483,7 +498,10 @@ async def _decode_partial_single( # setup output array out = shard_spec.prototype.nd_buffer.create( - shape=indexer.shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0 + shape=indexer.shape, + dtype=shard_spec.dtype.to_native_dtype(), + order=shard_spec.order, + fill_value=0, ) indexed_chunks = list(indexer) @@ -678,7 +696,7 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int: def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec: return ArraySpec( shape=chunks_per_shard + (2,), - dtype=np.dtype("<u8"), + dtype=UInt64(endianness="little"), fill_value=MAX_UINT_64, diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py ... def parse_transpose_order(data: JSON | Iterable[int]) -> tuple[int, ...]: @@ -45,7 +46,12 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def to_dict(self) -> dict[str, JSON]: return {"name": "transpose", "configuration": {"order": tuple(self.order)}} - def validate(self, shape: tuple[int, ...], dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None: + def validate( + self, + shape: tuple[int, ...], + dtype: ZDType[TBaseDType, TBaseScalar], + chunk_grid: ChunkGrid, + ) -> None: if len(self.order) != len(shape): raise ValueError( f"The `order` tuple needs have as many entries as there are dimensions in the array. Got {self.order}." 
diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index 0ef423793d..b7c0418b2e 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -10,7 +10,6 @@ from zarr.abc.codec import ArrayBytesCodec from zarr.core.buffer import Buffer, NDBuffer from zarr.core.common import JSON, parse_named_configuration -from zarr.core.strings import cast_to_string_dtype from zarr.registry import register_codec if TYPE_CHECKING: @@ -49,6 +48,7 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self + # TODO: expand the tests for this function async def _decode_single( self, chunk_bytes: Buffer, @@ -60,8 +60,7 @@ async def _decode_single( decoded = _vlen_utf8_codec.decode(raw_bytes) assert decoded.dtype == np.object_ decoded.shape = chunk_spec.shape - # coming out of the code, we know this is safe, so don't issue a warning - as_string_dtype = cast_to_string_dtype(decoded, safe=True) + as_string_dtype = decoded.astype(chunk_spec.dtype.to_native_dtype(), copy=False) return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype) async def _encode_single( diff --git a/src/zarr/core/_info.py b/src/zarr/core/_info.py index ee953d4591..d57d17f934 100644 --- a/src/zarr/core/_info.py +++ b/src/zarr/core/_info.py @@ -1,13 +1,15 @@ +from __future__ import annotations + import dataclasses import textwrap -from typing import Any, Literal +from typing import TYPE_CHECKING, Literal -import numcodecs.abc -import numpy as np +if TYPE_CHECKING: + import numcodecs.abc -from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec -from zarr.core.common import ZarrFormat -from zarr.core.metadata.v3 import DataType + from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec + from zarr.core.common import ZarrFormat + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType @dataclasses.dataclass(kw_only=True) @@ -78,7 +80,7 @@ class ArrayInfo: _type: Literal["Array"] = "Array" _zarr_format: ZarrFormat - _data_type: np.dtype[Any] | DataType + _data_type: ZDType[TBaseDType, TBaseScalar] _fill_value: object _shape: tuple[int, ...] _shard_shape: tuple[int, ...] 
| None = None diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index b4e8ac0ff6..cd6b33a28c 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -22,7 +22,6 @@ import numcodecs import numcodecs.abc import numpy as np -import numpy.typing as npt from typing_extensions import deprecated import zarr @@ -30,6 +29,7 @@ from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.abc.store import Store, set_or_delete from zarr.codecs._v2 import V2Codec +from zarr.codecs.bytes import BytesCodec from zarr.core._info import ArrayInfo from zarr.core.array_spec import ArrayConfig, ArrayConfigLike, parse_array_config from zarr.core.attributes import Attributes @@ -61,12 +61,18 @@ _default_zarr_format, _warn_order_kwarg, concurrent_map, - parse_dtype, parse_order, parse_shapelike, product, ) +from zarr.core.config import categorize_data_type from zarr.core.config import config as zarr_config +from zarr.core.dtype import ( + ZDType, + ZDTypeLike, + parse_data_type, +) +from zarr.core.dtype.common import HasEndianness, HasItemSize from zarr.core.indexing import ( BasicIndexer, BasicSelection, @@ -103,12 +109,10 @@ ) from zarr.core.metadata.v2 import ( CompressorLikev2, - _default_compressor, - _default_filters, parse_compressor, parse_filters, ) -from zarr.core.metadata.v3 import DataType, parse_node_type_array +from zarr.core.metadata.v3 import parse_node_type_array from zarr.core.sync import sync from zarr.errors import MetadataValidationError from zarr.registry import ( @@ -124,8 +128,11 @@ from collections.abc import Iterator, Sequence from typing import Self + import numpy.typing as npt + from zarr.abc.codec import CodecPipeline from zarr.codecs.sharding import ShardingCodecIndexLocation + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar from zarr.core.group import AsyncGroup from zarr.storage import StoreLike @@ -287,7 +294,7 @@ async def create( *, # v2 and v3 shape: ShapeLike, - dtype: npt.DTypeLike, + dtype: ZDTypeLike, zarr_format: Literal[2], fill_value: Any | None = None, attributes: dict[str, JSON] | None = None, @@ -311,7 +318,7 @@ async def create( *, # v2 and v3 shape: ShapeLike, - dtype: npt.DTypeLike, + dtype: ZDTypeLike, zarr_format: Literal[3], fill_value: Any | None = None, attributes: dict[str, JSON] | None = None, @@ -339,7 +346,7 @@ async def create( *, # v2 and v3 shape: ShapeLike, - dtype: npt.DTypeLike, + dtype: ZDTypeLike, zarr_format: Literal[3] = 3, fill_value: Any | None = None, attributes: dict[str, JSON] | None = None, @@ -367,7 +374,7 @@ async def create( *, # v2 and v3 shape: ShapeLike, - dtype: npt.DTypeLike, + dtype: ZDTypeLike, zarr_format: ZarrFormat, fill_value: Any | None = None, attributes: dict[str, JSON] | None = None, @@ -402,7 +409,7 @@ async def create( *, # v2 and v3 shape: ShapeLike, - dtype: npt.DTypeLike, + dtype: ZDTypeLike, zarr_format: ZarrFormat = 3, fill_value: Any | None = None, attributes: dict[str, JSON] | None = None, @@ -438,7 +445,7 @@ async def create( The store where the array will be created. shape : ShapeLike The shape of the array. - dtype : npt.DTypeLike + dtype : ZDTypeLike The data type of the array. zarr_format : ZarrFormat, optional The Zarr format version (default is 3). 
@@ -543,7 +550,7 @@ async def _create( *, # v2 and v3 shape: ShapeLike, - dtype: npt.DTypeLike, + dtype: ZDTypeLike | ZDType[TBaseDType, TBaseScalar], zarr_format: ZarrFormat = 3, fill_value: Any | None = None, attributes: dict[str, JSON] | None = None, @@ -572,18 +579,21 @@ async def _create( See :func:`AsyncArray.create` for more details. Deprecated in favor of :func:`zarr.api.asynchronous.create_array`. """ + + dtype_parsed = parse_data_type(dtype, zarr_format=zarr_format) store_path = await make_store_path(store) - dtype_parsed = parse_dtype(dtype, zarr_format) shape = parse_shapelike(shape) if chunks is not None and chunk_shape is not None: raise ValueError("Only one of chunk_shape or chunks can be provided.") - + item_size = 1 + if isinstance(dtype_parsed, HasItemSize): + item_size = dtype_parsed.item_size if chunks: - _chunks = normalize_chunks(chunks, shape, dtype_parsed.itemsize) + _chunks = normalize_chunks(chunks, shape, item_size) else: - _chunks = normalize_chunks(chunk_shape, shape, dtype_parsed.itemsize) + _chunks = normalize_chunks(chunk_shape, shape, item_size) config_parsed = parse_array_config(config) result: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] @@ -661,7 +671,7 @@ async def _create( @staticmethod def _create_metadata_v3( shape: ShapeLike, - dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], chunk_shape: ChunkCoords, fill_value: Any | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, @@ -672,30 +682,36 @@ def _create_metadata_v3( """ Create an instance of ArrayV3Metadata. """ + filters: tuple[ArrayArrayCodec, ...] + compressors: tuple[BytesBytesCodec, ...] shape = parse_shapelike(shape) - codecs = list(codecs) if codecs is not None else _get_default_codecs(np.dtype(dtype)) + if codecs is None: + filters, serializer, compressors = _get_default_chunk_encoding_v3(dtype) + codecs_parsed = (*filters, serializer, *compressors) + else: + codecs_parsed = tuple(codecs) + chunk_key_encoding_parsed: ChunkKeyEncodingLike if chunk_key_encoding is None: chunk_key_encoding_parsed = {"name": "default", "separator": "/"} else: chunk_key_encoding_parsed = chunk_key_encoding - if dtype.kind in "UTS": - warn( - f"The dtype `{dtype}` is currently not part in the Zarr format 3 specification. 
It " - "may not be supported by other zarr implementations and may change in the future.", - category=UserWarning, - stacklevel=2, - ) + if fill_value is None: + # v3 spec will not allow a null fill value + fill_value_parsed = dtype.default_scalar() + else: + fill_value_parsed = fill_value + chunk_grid_parsed = RegularChunkGrid(chunk_shape=chunk_shape) return ArrayV3Metadata( shape=shape, data_type=dtype, chunk_grid=chunk_grid_parsed, chunk_key_encoding=chunk_key_encoding_parsed, - fill_value=fill_value, - codecs=codecs, + fill_value=fill_value_parsed, + codecs=codecs_parsed, # type: ignore[arg-type] dimension_names=tuple(dimension_names) if dimension_names else None, attributes=attributes or {}, ) @@ -706,7 +722,7 @@ async def _create_v3( store_path: StorePath, *, shape: ShapeLike, - dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], chunk_shape: ChunkCoords, config: ArrayConfig, fill_value: Any | None = None, @@ -754,7 +770,7 @@ async def _create_v3( @staticmethod def _create_metadata_v2( shape: ChunkCoords, - dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], chunks: ChunkCoords, order: MemoryOrder, dimension_separator: Literal[".", "/"] | None = None, @@ -765,12 +781,11 @@ def _create_metadata_v2( ) -> ArrayV2Metadata: if dimension_separator is None: dimension_separator = "." - - dtype = parse_dtype(dtype, zarr_format=2) - + if fill_value is None: + fill_value = dtype.default_scalar() # type: ignore[assignment] return ArrayV2Metadata( shape=shape, - dtype=np.dtype(dtype), + dtype=dtype, chunks=chunks, order=order, dimension_separator=dimension_separator, @@ -786,7 +801,7 @@ async def _create_v2( store_path: StorePath, *, shape: ChunkCoords, - dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], chunks: ChunkCoords, order: MemoryOrder, config: ArrayConfig, @@ -807,7 +822,7 @@ async def _create_v2( compressor_parsed: CompressorLikev2 if compressor == "auto": - compressor_parsed = _default_compressor(dtype) + _, compressor_parsed = _get_default_chunk_encoding_v2(dtype) elif isinstance(compressor, BytesBytesCodec): raise ValueError( "Cannot use a BytesBytesCodec as a compressor for zarr v2 arrays. " @@ -1023,7 +1038,17 @@ def compressors(self) -> tuple[numcodecs.abc.Codec, ...] | tuple[BytesBytesCodec ) @property - def dtype(self) -> np.dtype[Any]: + def _zdtype(self) -> ZDType[TBaseDType, TBaseScalar]: + """ + The zarr-specific representation of the array data type + """ + if self.metadata.zarr_format == 2: + return self.metadata.dtype + else: + return self.metadata.data_type + + @property + def dtype(self) -> TBaseDType: """Returns the data type of the array. Returns @@ -1031,7 +1056,7 @@ def dtype(self) -> np.dtype[Any]: np.dtype Data type of the array """ - return self.metadata.dtype + return self._zdtype.to_native_dtype() @property def order(self) -> MemoryOrder: @@ -1392,20 +1417,21 @@ async def _set_selection( # TODO: need to handle array types that don't support __array_function__ # like PyTorch and JAX array_like_ = cast("np._typing._SupportsArrayFunc", array_like) - value = np.asanyarray(value, dtype=self.metadata.dtype, like=array_like_) + value = np.asanyarray(value, dtype=self.dtype, like=array_like_) else: if not hasattr(value, "shape"): - value = np.asarray(value, self.metadata.dtype) + value = np.asarray(value, self.dtype) # assert ( # value.shape == indexer.shape # ), f"shape of value doesn't match indexer shape. 
Expected {indexer.shape}, got {value.shape}" - if not hasattr(value, "dtype") or value.dtype.name != self.metadata.dtype.name: + if not hasattr(value, "dtype") or value.dtype.name != self.dtype.name: if hasattr(value, "astype"): # Handle things that are already NDArrayLike more efficiently - value = value.astype(dtype=self.metadata.dtype, order="A") + value = value.astype(dtype=self.dtype, order="A") else: - value = np.array(value, dtype=self.metadata.dtype, order="A") + value = np.array(value, dtype=self.dtype, order="A") value = cast("NDArrayLike", value) + # We accept any ndarray like object from the user and convert it # to a NDBuffer (or subclass). From this point onwards, we only pass # Buffer and NDBuffer between components. @@ -1685,15 +1711,9 @@ async def info_complete(self) -> Any: def _info( self, count_chunks_initialized: int | None = None, count_bytes_stored: int | None = None ) -> Any: - _data_type: np.dtype[Any] | DataType - if isinstance(self.metadata, ArrayV2Metadata): - _data_type = self.metadata.dtype - else: - _data_type = self.metadata.data_type - return ArrayInfo( _zarr_format=self.metadata.zarr_format, - _data_type=_data_type, + _data_type=self._zdtype, _fill_value=self.metadata.fill_value, _shape=self.shape, _order=self.order, @@ -1728,7 +1748,7 @@ def create( *, # v2 and v3 shape: ChunkCoords, - dtype: npt.DTypeLike, + dtype: ZDTypeLike, zarr_format: ZarrFormat = 3, fill_value: Any | None = None, attributes: dict[str, JSON] | None = None, @@ -1763,7 +1783,7 @@ def create( The array store that has already been initialized. shape : ChunkCoords The shape of the array. - dtype : npt.DTypeLike + dtype : ZDTypeLike The data type of the array. chunk_shape : ChunkCoords, optional The shape of the Array's chunks. @@ -1857,7 +1877,7 @@ def _create( *, # v2 and v3 shape: ChunkCoords, - dtype: npt.DTypeLike, + dtype: ZDTypeLike, zarr_format: ZarrFormat = 3, fill_value: Any | None = None, attributes: dict[str, JSON] | None = None, @@ -3773,13 +3793,6 @@ def _build_parents( return parents -def _get_default_codecs( - np_dtype: np.dtype[Any], -) -> tuple[Codec, ...]: - filters, serializer, compressors = _get_default_chunk_encoding_v3(np_dtype) - return filters + (serializer,) + compressors - - FiltersLike: TypeAlias = ( Iterable[dict[str, JSON] | ArrayArrayCodec | numcodecs.abc.Codec] | ArrayArrayCodec @@ -4079,7 +4092,7 @@ async def init_array( *, store_path: StorePath, shape: ShapeLike, - dtype: npt.DTypeLike, + dtype: ZDTypeLike, chunks: ChunkCoords | Literal["auto"] = "auto", shards: ShardsLike | None = None, filters: FiltersLike = "auto", @@ -4102,7 +4115,7 @@ async def init_array( StorePath instance. The path attribute is the name of the array to initialize. shape : ChunkCoords Shape of the array. - dtype : npt.DTypeLike + dtype : ZDTypeLike Data type of the array. chunks : ChunkCoords, optional Chunk shape of the array. 
@@ -4186,7 +4199,7 @@ async def init_array( from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation - dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format) + zdtype = parse_data_type(dtype, zarr_format=zarr_format) shape_parsed = parse_shapelike(shape) chunk_key_encoding_parsed = _parse_chunk_key_encoding( chunk_key_encoding, zarr_format=zarr_format @@ -4200,8 +4213,15 @@ async def init_array( else: await ensure_no_existing_node(store_path, zarr_format=zarr_format) + item_size = 1 + if isinstance(zdtype, HasItemSize): + item_size = zdtype.item_size + shard_shape_parsed, chunk_shape_parsed = _auto_partition( - array_shape=shape_parsed, shard_shape=shards, chunk_shape=chunks, dtype=dtype_parsed + array_shape=shape_parsed, + shard_shape=shards, + chunk_shape=chunks, + item_size=item_size, ) chunks_out: tuple[int, ...] meta: ArrayV2Metadata | ArrayV3Metadata @@ -4217,9 +4237,8 @@ async def init_array( raise ValueError("Zarr format 2 arrays do not support `serializer`.") filters_parsed, compressor_parsed = _parse_chunk_encoding_v2( - compressor=compressors, filters=filters, dtype=np.dtype(dtype) + compressor=compressors, filters=filters, dtype=zdtype ) - if dimension_names is not None: raise ValueError("Zarr format 2 arrays do not support dimension names.") if order is None: @@ -4229,7 +4248,7 @@ async def init_array( meta = AsyncArray._create_metadata_v2( shape=shape_parsed, - dtype=dtype_parsed, + dtype=zdtype, chunks=chunk_shape_parsed, dimension_separator=chunk_key_encoding_parsed.separator, fill_value=fill_value, @@ -4243,7 +4262,7 @@ async def init_array( compressors=compressors, filters=filters, serializer=serializer, - dtype=dtype_parsed, + dtype=zdtype, ) sub_codecs = cast("tuple[Codec, ...]", (*array_array, array_bytes, *bytes_bytes)) codecs_out: tuple[Codec, ...] @@ -4258,7 +4277,7 @@ async def init_array( ) sharding_codec.validate( shape=chunk_shape_parsed, - dtype=dtype_parsed, + dtype=zdtype, chunk_grid=RegularChunkGrid(chunk_shape=shard_shape_parsed), ) codecs_out = (sharding_codec,) @@ -4274,7 +4293,7 @@ async def init_array( meta = AsyncArray._create_metadata_v3( shape=shape_parsed, - dtype=dtype_parsed, + dtype=zdtype, fill_value=fill_value, chunk_shape=chunks_out, chunk_key_encoding=chunk_key_encoding_parsed, @@ -4293,7 +4312,7 @@ async def create_array( *, name: str | None = None, shape: ShapeLike | None = None, - dtype: npt.DTypeLike | None = None, + dtype: ZDTypeLike | None = None, data: np.ndarray[Any, np.dtype[Any]] | None = None, chunks: ChunkCoords | Literal["auto"] = "auto", shards: ShardsLike | None = None, @@ -4322,7 +4341,7 @@ async def create_array( at the root of the store. shape : ChunkCoords, optional Shape of the array. Can be ``None`` if ``data`` is provided. - dtype : npt.DTypeLike | None + dtype : ZDTypeLike | None Data type of the array. Can be ``None`` if ``data`` is provided. data : Array-like data to use for initializing the array. If this parameter is provided, the ``shape`` and ``dtype`` parameters must be identical to ``data.shape`` and ``data.dtype``, @@ -4582,62 +4601,50 @@ def _parse_chunk_key_encoding( def _get_default_chunk_encoding_v3( - np_dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], ) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]: """ Get the default ArrayArrayCodecs, ArrayBytesCodec, and BytesBytesCodec for a given dtype. 
""" - dtype = DataType.from_numpy(np_dtype) - if dtype == DataType.string: - dtype_key = "string" - elif dtype == DataType.bytes: - dtype_key = "bytes" - else: - dtype_key = "numeric" - default_filters = zarr_config.get("array.v3_default_filters").get(dtype_key) - default_serializer = zarr_config.get("array.v3_default_serializer").get(dtype_key) - default_compressors = zarr_config.get("array.v3_default_compressors").get(dtype_key) + dtype_category = categorize_data_type(dtype) - filters = tuple(_parse_array_array_codec(codec_dict) for codec_dict in default_filters) - serializer = _parse_array_bytes_codec(default_serializer) - compressors = tuple(_parse_bytes_bytes_codec(codec_dict) for codec_dict in default_compressors) + filters = zarr_config.get("array.v3_default_filters").get(dtype_category) + compressors = zarr_config.get("array.v3_default_compressors").get(dtype_category) + serializer = zarr_config.get("array.v3_default_serializer").get(dtype_category) - return filters, serializer, compressors + return ( + tuple(_parse_array_array_codec(f) for f in filters), + _parse_array_bytes_codec(serializer), + tuple(_parse_bytes_bytes_codec(c) for c in compressors), + ) def _get_default_chunk_encoding_v2( - np_dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], ) -> tuple[tuple[numcodecs.abc.Codec, ...] | None, numcodecs.abc.Codec | None]: """ Get the default chunk encoding for Zarr format 2 arrays, given a dtype """ + dtype_category = categorize_data_type(dtype) + filters = zarr_config.get("array.v2_default_filters").get(dtype_category) + compressor = zarr_config.get("array.v2_default_compressor").get(dtype_category) + if filters is not None: + filters = tuple(numcodecs.get_codec(f) for f in filters) - compressor_dict = _default_compressor(np_dtype) - filter_dicts = _default_filters(np_dtype) - - compressor = None - if compressor_dict is not None: - compressor = numcodecs.get_codec(compressor_dict) - - filters = None - if filter_dicts is not None: - filters = tuple(numcodecs.get_codec(f) for f in filter_dicts) - - return filters, compressor + return filters, numcodecs.get_codec(compressor) def _parse_chunk_encoding_v2( *, compressor: CompressorsLike, filters: FiltersLike, - dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], ) -> tuple[tuple[numcodecs.abc.Codec, ...] | None, numcodecs.abc.Codec | None]: """ Generate chunk encoding classes for Zarr format 2 arrays with optional defaults. """ default_filters, default_compressor = _get_default_chunk_encoding_v2(dtype) - _filters: tuple[numcodecs.abc.Codec, ...] | None _compressor: numcodecs.abc.Codec | None @@ -4676,7 +4683,7 @@ def _parse_chunk_encoding_v3( compressors: CompressorsLike, filters: FiltersLike, serializer: SerializerLike, - dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], ) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]: """ Generate chunk encoding classes for v3 arrays with optional defaults. @@ -4700,6 +4707,9 @@ def _parse_chunk_encoding_v3( if serializer == "auto": out_array_bytes = default_array_bytes else: + # TODO: ensure that the serializer is compatible with the ndarray produced by the + # array-array codecs. For example, if a sequence of array-array codecs produces an + # array with a single-byte data type, then the serializer should not specify endiannesss. 
out_array_bytes = _parse_array_bytes_codec(serializer) if compressors is None: @@ -4715,6 +4725,17 @@ out_bytes_bytes = tuple(_parse_bytes_bytes_codec(c) for c in maybe_bytes_bytes) + # specialize codecs as needed given the dtype + + # TODO: refactor so that the config only contains the name of the codec, and we use the dtype + # to create the codec instance, instead of storing a dict representation of a full codec. + + # TODO: ensure that the serializer is compatible with the ndarray produced by the + # array-array codecs. For example, if a sequence of array-array codecs produces an + # array with a single-byte data type, then the serializer should not specify endianness. + if isinstance(out_array_bytes, BytesCodec) and not isinstance(dtype, HasEndianness): + # The default endianness in the BytesCodec might not be None, so we need to replace it + out_array_bytes = replace(out_array_bytes, endian=None) return out_array_array, out_array_bytes, out_bytes_bytes @@ -4744,8 +4765,8 @@ def _parse_data_params( *, data: np.ndarray[Any, np.dtype[Any]] | None, shape: ShapeLike | None, - dtype: npt.DTypeLike | None, -) -> tuple[np.ndarray[Any, np.dtype[Any]] | None, ShapeLike, npt.DTypeLike]: + dtype: ZDTypeLike | None, +) -> tuple[np.ndarray[Any, np.dtype[Any]] | None, ShapeLike, ZDTypeLike]: """ Ensure an array-like ``data`` parameter is consistent with the ``dtype`` and ``shape`` parameters. diff --git a/src/zarr/core/array_spec.py b/src/zarr/core/array_spec.py index 6cd27b30eb..279bf6edf0 100644 --- a/src/zarr/core/array_spec.py +++ b/src/zarr/core/array_spec.py @@ -3,8 +3,6 @@ from dataclasses import dataclass, fields from typing import TYPE_CHECKING, Any, Literal, Self, TypedDict, cast -import numpy as np - from zarr.core.common import ( MemoryOrder, parse_bool, @@ -19,6 +17,7 @@ from zarr.core.buffer import BufferPrototype from zarr.core.common import ChunkCoords + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType class ArrayConfigParams(TypedDict): @@ -90,7 +89,7 @@ def parse_array_config(data: ArrayConfigLike | None) -> ArrayConfig: @dataclass(frozen=True) class ArraySpec: shape: ChunkCoords - dtype: np.dtype[Any] + dtype: ZDType[TBaseDType, TBaseScalar] fill_value: Any config: ArrayConfig prototype: BufferPrototype @@ -98,17 +97,16 @@ class ArraySpec: def __init__( self, shape: ChunkCoords, - dtype: np.dtype[Any], + dtype: ZDType[TBaseDType, TBaseScalar], fill_value: Any, config: ArrayConfig, prototype: BufferPrototype, ) -> None: shape_parsed = parse_shapelike(shape) - dtype_parsed = np.dtype(dtype) fill_value_parsed = parse_fill_value(fill_value) object.__setattr__(self, "shape", shape_parsed) - object.__setattr__(self, "dtype", dtype_parsed) + object.__setattr__(self, "dtype", dtype) object.__setattr__(self, "fill_value", fill_value_parsed) object.__setattr__(self, "config", config) object.__setattr__(self, "prototype", prototype) diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index d0a2d992d2..0e24c5b326 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -495,7 +495,9 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool: return np.array_equal( self._data, other, - equal_nan=equal_nan if self._data.dtype.kind not in "USTOV" else False, + equal_nan=equal_nan + if self._data.dtype.kind not in ("U", "S", "T", "O", "V") + else False, ) def fill(self, value: Any) -> None: diff --git a/src/zarr/core/chunk_grids.py b/src/zarr/core/chunk_grids.py index b5a581b8a4..4bf03c89de 
100644 --- a/src/zarr/core/chunk_grids.py +++ b/src/zarr/core/chunk_grids.py @@ -207,7 +207,7 @@ def _auto_partition( array_shape: tuple[int, ...], chunk_shape: tuple[int, ...] | Literal["auto"], shard_shape: ShardsLike | None, - dtype: np.dtype[Any], + item_size: int, ) -> tuple[tuple[int, ...] | None, tuple[int, ...]]: """ Automatically determine the shard shape and chunk shape for an array, given the shape and dtype of the array. @@ -217,7 +217,6 @@ def _auto_partition( of the array; if the `chunk_shape` is also "auto", then the chunks will be set heuristically as well, given the dtype and shard shape. Otherwise, the chunks will be returned as-is. """ - item_size = dtype.itemsize if shard_shape is None: _shards_out: None | tuple[int, ...] = None if chunk_shape == "auto": diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 628a7e0487..23c27e40c6 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -17,19 +17,17 @@ from zarr.core.common import ChunkCoords, concurrent_map from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar -from zarr.core.metadata.v2 import _default_fill_value from zarr.registry import register_pipeline if TYPE_CHECKING: from collections.abc import Iterable, Iterator from typing import Self - import numpy as np - from zarr.abc.store import ByteGetter, ByteSetter from zarr.core.array_spec import ArraySpec from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer from zarr.core.chunk_grids import ChunkGrid + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType T = TypeVar("T") U = TypeVar("U") @@ -64,7 +62,7 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any: # validated when decoding the metadata, but we support reading # Zarr V2 data and need to support the case where fill_value # is None. 
- return _default_fill_value(dtype=chunk_spec.dtype) + return chunk_spec.dtype.default_scalar() else: return fill_value @@ -134,7 +132,9 @@ def __iter__(self) -> Iterator[Codec]: yield self.array_bytes_codec yield from self.bytes_bytes_codecs - def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None: + def validate( + self, *, shape: ChunkCoords, dtype: ZDType[TBaseDType, TBaseScalar], chunk_grid: ChunkGrid + ) -> None: for codec in self: codec.validate(shape=shape, dtype=dtype, chunk_grid=chunk_grid) @@ -296,7 +296,9 @@ def _merge_chunk_array( is_complete_chunk: bool, drop_axes: tuple[int, ...], ) -> NDBuffer: - if chunk_selection == () or is_scalar(value.as_ndarray_like(), chunk_spec.dtype): + if chunk_selection == () or is_scalar( + value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype() + ): chunk_value = value else: chunk_value = value[out_selection] @@ -317,7 +319,7 @@ def _merge_chunk_array( if existing_chunk_array is None: chunk_array = chunk_spec.prototype.nd_buffer.create( shape=chunk_spec.shape, - dtype=chunk_spec.dtype, + dtype=chunk_spec.dtype.to_native_dtype(), order=chunk_spec.order, fill_value=fill_value_or_default(chunk_spec), ) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index be37dc5109..2ba5914ea5 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -10,16 +10,15 @@ from typing import ( TYPE_CHECKING, Any, + Generic, Literal, + TypedDict, TypeVar, cast, overload, ) -import numpy as np - from zarr.core.config import config as zarr_config -from zarr.core.strings import _STRING_DTYPE if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator @@ -42,6 +41,14 @@ AccessModeLiteral = Literal["r", "r+", "a", "w", "w-"] DimensionNames = Iterable[str | None] | None +TName = TypeVar("TName", bound=str) +TConfig = TypeVar("TConfig", bound=Mapping[str, object]) + + +class NamedConfig(TypedDict, Generic[TName, TConfig]): + name: TName + configuration: TConfig + def product(tup: ChunkCoords) -> int: return functools.reduce(operator.mul, tup, 1) @@ -168,16 +175,6 @@ def parse_bool(data: Any) -> bool: raise ValueError(f"Expected bool, got {data} instead.") -def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]: - if dtype is str or dtype == "str": - if zarr_format == 2: - # special case as object - return np.dtype("object") - else: - return _STRING_DTYPE - return np.dtype(dtype) - - def _warn_write_empty_chunks_kwarg() -> None: # TODO: link to docs page on array configuration in this message msg = ( diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 2a10943d80..74e9bdd8dd 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -36,11 +36,21 @@ if TYPE_CHECKING: from donfig.config_obj import ConfigSet + from zarr.core.dtype.wrapper import ZDType + class BadConfigError(ValueError): _msg = "bad Config: %r" +# These values are used for rough categorization of data types. +# We use this for choosing a default encoding scheme based on the data type. Specifically, +# these categories are keys in a configuration dictionary. +# It is not part of the ZDType class because these categories are more of an implementation detail +# of our config system than a useful attribute of any particular data type. 
+DTypeCategory = Literal["variable-length-string", "default"] + + class Config(DConfig): # type: ignore[misc] """The Config will collect configuration from config files and environment variables @@ -78,31 +88,24 @@ def enable_gpu(self) -> ConfigSet: "order": "C", "write_empty_chunks": False, "v2_default_compressor": { - "numeric": {"id": "zstd", "level": 0, "checksum": False}, - "string": {"id": "zstd", "level": 0, "checksum": False}, - "bytes": {"id": "zstd", "level": 0, "checksum": False}, + "default": {"id": "zstd", "level": 0, "checksum": False}, + "variable-length-string": {"id": "zstd", "level": 0, "checksum": False}, }, "v2_default_filters": { - "numeric": None, - "string": [{"id": "vlen-utf8"}], - "bytes": [{"id": "vlen-bytes"}], - "raw": None, + "default": None, + "variable-length-string": [{"id": "vlen-utf8"}], }, - "v3_default_filters": {"numeric": [], "string": [], "bytes": []}, + "v3_default_filters": {"default": [], "variable-length-string": []}, "v3_default_serializer": { - "numeric": {"name": "bytes", "configuration": {"endian": "little"}}, - "string": {"name": "vlen-utf8"}, - "bytes": {"name": "vlen-bytes"}, + "default": {"name": "bytes", "configuration": {"endian": "little"}}, + "variable-length-string": {"name": "vlen-utf8"}, }, "v3_default_compressors": { - "numeric": [ + "default": [ {"name": "zstd", "configuration": {"level": 0, "checksum": False}}, ], - "string": [ - {"name": "zstd", "configuration": {"level": 0, "checksum": False}}, - ], - "bytes": [ - {"name": "zstd", "configuration": {"level": 0, "checksum": False}}, + "variable-length-string": [ + {"name": "zstd", "configuration": {"level": 0, "checksum": False}} ], }, }, @@ -137,3 +140,17 @@ def parse_indexing_order(data: Any) -> Literal["C", "F"]: return cast("Literal['C', 'F']", data) msg = f"Expected one of ('C', 'F'), got {data} instead." raise ValueError(msg) + + +def categorize_data_type(dtype: ZDType[Any, Any]) -> DTypeCategory: + """ + Classify a ZDType. The return value is a string which belongs to the type ``DTypeCategory``. + + This is used by the config system to determine how to encode arrays with the associated data type + when the user has not specified a particular serialization scheme. 
+ """ + from zarr.core.dtype import VariableLengthUTF8 + + if isinstance(dtype, VariableLengthUTF8): + return "variable-length-string" + return "default" diff --git a/src/zarr/core/dtype/__init__.py b/src/zarr/core/dtype/__init__.py new file mode 100644 index 0000000000..735690d4bc --- /dev/null +++ b/src/zarr/core/dtype/__init__.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Final, TypeAlias + +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeJSON, +) +from zarr.core.dtype.npy.bool import Bool +from zarr.core.dtype.npy.bytes import NullTerminatedBytes, RawBytes, VariableLengthBytes +from zarr.core.dtype.npy.complex import Complex64, Complex128 +from zarr.core.dtype.npy.float import Float16, Float32, Float64 +from zarr.core.dtype.npy.int import Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64 +from zarr.core.dtype.npy.structured import ( + Structured, +) +from zarr.core.dtype.npy.time import DateTime64, TimeDelta64 + +if TYPE_CHECKING: + from zarr.core.common import ZarrFormat + +from collections.abc import Mapping + +import numpy as np +import numpy.typing as npt + +from zarr.core.common import JSON +from zarr.core.dtype.npy.string import ( + FixedLengthUTF32, + VariableLengthUTF8, +) +from zarr.core.dtype.registry import DataTypeRegistry +from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + +__all__ = [ + "Bool", + "Complex64", + "Complex128", + "DataTypeRegistry", + "DataTypeValidationError", + "DateTime64", + "FixedLengthUTF32", + "Float16", + "Float32", + "Float64", + "Int8", + "Int16", + "Int32", + "Int64", + "NullTerminatedBytes", + "RawBytes", + "Structured", + "TBaseDType", + "TBaseScalar", + "TimeDelta64", + "TimeDelta64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "VariableLengthUTF8", + "ZDType", + "data_type_registry", + "parse_data_type", +] + +data_type_registry = DataTypeRegistry() + +IntegerDType = Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 +INTEGER_DTYPE: Final = Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64 + +FloatDType = Float16 | Float32 | Float64 +FLOAT_DTYPE: Final = Float16, Float32, Float64 + +ComplexFloatDType = Complex64 | Complex128 +COMPLEX_FLOAT_DTYPE: Final = Complex64, Complex128 + +StringDType = FixedLengthUTF32 | VariableLengthUTF8 +STRING_DTYPE: Final = FixedLengthUTF32, VariableLengthUTF8 + +TimeDType = DateTime64 | TimeDelta64 +TIME_DTYPE: Final = DateTime64, TimeDelta64 + +BytesDType = RawBytes | NullTerminatedBytes | VariableLengthBytes +BYTES_DTYPE: Final = RawBytes, NullTerminatedBytes, VariableLengthBytes + +AnyDType = ( + Bool + | IntegerDType + | FloatDType + | ComplexFloatDType + | StringDType + | BytesDType + | Structured + | TimeDType + | VariableLengthBytes +) +# mypy has trouble inferring the type of variablelengthstring dtype, because its class definition +# depends on the installed numpy version. That's why the type: ignore statement is needed here. 
+ANY_DTYPE: Final = ( + Bool, + *INTEGER_DTYPE, + *FLOAT_DTYPE, + *COMPLEX_FLOAT_DTYPE, + *STRING_DTYPE, + *BYTES_DTYPE, + Structured, + *TIME_DTYPE, + VariableLengthBytes, +) + +# This type models inputs that can be coerced to a ZDType +ZDTypeLike: TypeAlias = npt.DTypeLike | ZDType[TBaseDType, TBaseScalar] | Mapping[str, JSON] | str + +for dtype in ANY_DTYPE: + # mypy does not know that all the elements of ANY_DTYPE are subclasses of ZDType + data_type_registry.register(dtype._zarr_v3_name, dtype) # type: ignore[arg-type] + + +# TODO: find a better name for this function +def get_data_type_from_native_dtype(dtype: npt.DTypeLike) -> ZDType[TBaseDType, TBaseScalar]: + """ + Get a data type wrapper (an instance of ``ZDType``) from a native data type, e.g. a numpy dtype. + """ + if not isinstance(dtype, np.dtype): + na_dtype: np.dtype[np.generic] + if isinstance(dtype, list): + # this is a valid _VoidDTypeLike check + na_dtype = np.dtype([tuple(d) for d in dtype]) + else: + na_dtype = np.dtype(dtype) + else: + na_dtype = dtype + return data_type_registry.match_dtype(dtype=na_dtype) + + +def get_data_type_from_json( + dtype_spec: DTypeJSON, *, zarr_format: ZarrFormat +) -> ZDType[TBaseDType, TBaseScalar]: + """ + Given a JSON representation of a data type and a Zarr format version, + attempt to create a ZDType instance from the registered ZDType classes. + """ + return data_type_registry.match_json(dtype_spec, zarr_format=zarr_format) + + +def parse_data_type( + dtype_spec: ZDTypeLike, + *, + zarr_format: ZarrFormat, +) -> ZDType[TBaseDType, TBaseScalar]: + """ + Interpret the input as a ZDType instance. + """ + if isinstance(dtype_spec, ZDType): + return dtype_spec + # dict and zarr_format 3 means that we have a JSON object representation of the dtype + if zarr_format == 3 and isinstance(dtype_spec, Mapping): + return get_data_type_from_json(dtype_spec, zarr_format=3) + # otherwise, we have either a numpy dtype string, or a zarr v3 dtype string, and in either case + # we can create a numpy dtype from it, and do the dtype inference from that + return get_data_type_from_native_dtype(dtype_spec) # type: ignore[arg-type] diff --git a/src/zarr/core/dtype/common.py b/src/zarr/core/dtype/common.py new file mode 100644 index 0000000000..6f61b6775e --- /dev/null +++ b/src/zarr/core/dtype/common.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import warnings +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import ( + ClassVar, + Final, + Generic, + Literal, + TypedDict, + TypeGuard, + TypeVar, +) + +from zarr.core.common import NamedConfig + +EndiannessStr = Literal["little", "big"] +ENDIANNESS_STR: Final = "little", "big" + +SpecialFloatStrings = Literal["NaN", "Infinity", "-Infinity"] +SPECIAL_FLOAT_STRINGS: Final = ("NaN", "Infinity", "-Infinity") + +JSONFloatV2 = float | SpecialFloatStrings +JSONFloatV3 = float | SpecialFloatStrings | str + +ObjectCodecID = Literal["vlen-utf8", "vlen-bytes", "vlen-array", "pickle", "json2", "msgpack2"] +# These are the ids of the known object codecs for zarr v2. 
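``parse_data_type`` above is the single funnel through which user-supplied dtype specs reach these wrappers. A sketch of two of its input paths (hypothetical session; reprs assume the classes defined in this PR)::

    >>> from zarr.core.dtype import parse_data_type
    >>> parse_data_type("int32", zarr_format=3)  # numpy-style string
    Int32(endianness='little')
    >>> parse_data_type(
    ...     {"name": "null_terminated_bytes", "configuration": {"length_bytes": 10}},
    ...     zarr_format=3,
    ... )  # zarr v3 JSON form
    NullTerminatedBytes(length=10)

The ``OBJECT_CODEC_IDS`` constant introduced by the comment above follows.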
+OBJECT_CODEC_IDS: Final = ("vlen-utf8", "vlen-bytes", "vlen-array", "pickle", "json2", "msgpack2") + +# This is a wider type than our standard JSON type because we need +# to work with typeddict objects which are assignable to Mapping[str, object] +DTypeJSON = str | int | float | Sequence["DTypeJSON"] | None | Mapping[str, object] + +# The DTypeSpec_V2 type exists because ZDType.from_json takes a single argument, which must contain +# all the information necessary to decode the data type. Zarr v2 supports multiple distinct +# data types that all use the "|O" data type identifier. These data types can only be +# discriminated on the basis of their "object codec", i.e. a special data-type-specific +# compressor or filter. So to figure out what data type a zarr v2 array has, we need the +# data type identifier from metadata, as well as an object codec id if the data type identifier +# is "|O". +# So we pack the name of the dtype alongside the object codec id, if applicable, +# in a single dict, and pass that to the data type inference logic. +# These type variables have a very wide bound because the individual zdtype +# classes can perform a very specific type check. + +# This is the JSON representation of a structured dtype in zarr v2 +StructuredName_V2 = Sequence["str | StructuredName_V2"] + +# This models the type of the name a dtype might have in zarr v2 array metadata +DTypeName_V2 = StructuredName_V2 | str + +TDTypeNameV2_co = TypeVar("TDTypeNameV2_co", bound=DTypeName_V2, covariant=True) +TObjectCodecID_co = TypeVar("TObjectCodecID_co", bound=None | str, covariant=True) + + +class DTypeConfig_V2(TypedDict, Generic[TDTypeNameV2_co, TObjectCodecID_co]): + name: TDTypeNameV2_co + object_codec_id: TObjectCodecID_co + + +DTypeSpec_V2 = DTypeConfig_V2[DTypeName_V2, None | str] + + +def check_structured_dtype_v2_inner(data: object) -> TypeGuard[StructuredName_V2]: + """ + A type guard for the inner elements of a structured dtype. This is a recursive check because + the type is itself recursive. + + This check ensures that the input is a 2-element sequence beginning with a string + and ending with either another string or another sequence of the same form. + """ + if isinstance(data, (str, Mapping)): + return False + if not isinstance(data, Sequence): + return False + if len(data) != 2: + return False + if not (isinstance(data[0], str)): + return False + if isinstance(data[-1], str): + return True + elif isinstance(data[-1], Sequence): + return check_structured_dtype_v2_inner(data[-1]) + return False + + +def check_structured_dtype_name_v2(data: Sequence[object]) -> TypeGuard[StructuredName_V2]: + return all(check_structured_dtype_v2_inner(d) for d in data) + + +def check_dtype_name_v2(data: object) -> TypeGuard[DTypeName_V2]: + """ + Type guard for narrowing the type of a python object to a valid zarr v2 dtype name. + """ + if isinstance(data, str): + return True + elif isinstance(data, Sequence): + return check_structured_dtype_name_v2(data) + return False + + +def check_dtype_spec_v2(data: object) -> TypeGuard[DTypeSpec_V2]: + """ + Type guard for narrowing a python object to an instance of DTypeSpec_V2. + """ + if not isinstance(data, Mapping): + return False + if set(data.keys()) != {"name", "object_codec_id"}: + return False + if not check_dtype_name_v2(data["name"]): + return False + return isinstance(data["object_codec_id"], str | None) + + +# By comparison, the JSON representation of a dtype in zarr v3 is much simpler. +# It's either a string or a structured dict. +DTypeSpec_V3 = str | NamedConfig[str, Mapping[str, object]] + + +def check_dtype_spec_v3(data: object) -> TypeGuard[DTypeSpec_V3]: + """ + Type guard for narrowing the type of a python object to an instance of + DTypeSpec_V3, i.e. either a string or a dict with a "name" field that's a string and a + "configuration" field that's a mapping with string keys. + """ + if isinstance(data, str) or ( # noqa: SIM103 + isinstance(data, Mapping) + and set(data.keys()) == {"name", "configuration"} + and isinstance(data["configuration"], Mapping) + and all(isinstance(k, str) for k in data["configuration"]) + ): + return True + return False + + +def unpack_dtype_json(data: DTypeSpec_V2 | DTypeSpec_V3) -> DTypeJSON: + """ + Return the array metadata form of the dtype JSON representation. For the Zarr V3 form of dtype + metadata, this is a no-op. For the Zarr V2 form of dtype metadata, this unpacks the dtype name. + """ + if isinstance(data, Mapping) and set(data.keys()) == {"name", "object_codec_id"}: + return data["name"] + return data + + +class DataTypeValidationError(ValueError): ... + + +class ScalarTypeValidationError(ValueError): ... + + +@dataclass(frozen=True) +class HasLength: + """ + A mix-in class for data types with a length attribute, such as fixed-size collections + of unicode strings, or bytes. + """ + + length: int + + +@dataclass(frozen=True) +class HasEndianness: + """ + A mix-in class for data types with an endianness attribute. + """ + + endianness: EndiannessStr = "little" + + +@dataclass(frozen=True) +class HasItemSize: + """ + A mix-in class for data types with an item size attribute. + This mix-in bears a property ``item_size``, which denotes the size of each element of the data + type, in bytes. + """ + + @property + def item_size(self) -> int: + raise NotImplementedError + + +@dataclass(frozen=True) +class HasObjectCodec: + """ + A mix-in class for data types that require an object codec id. + This class bears the property ``object_codec_id``, which is the string name of an object + codec that is required to encode and decode the data type. + + In zarr-python 2.x certain data types like variable-length strings or variable-length arrays + used the catch-all numpy "object" data type for their in-memory representation. But these data + types cannot be stored as numpy object data types, because the object data type does not define + a fixed memory layout. So these data types required a special codec, called an "object codec", + that defined a compact serialized representation used to encode + and decode the values. + + Zarr-python 2.x would not allow the creation of arrays with the "object" data type if an object + codec was not specified, and thus the name of the object codec is effectively part of the data + type model.
+ """ + + object_codec_id: ClassVar[str] + + +class UnstableSpecificationWarning(FutureWarning): ... + + +def v3_unstable_dtype_warning(dtype: object) -> None: + """ + Emit this warning when a data type does not have a stable zarr v3 spec + """ + msg = ( + f"The data type ({dtype}) does not have a Zarr V3 specification. " + "That means that the representation of array saved with this data type may change without " + "warning in a future version of Zarr Python. " + "Arrays stored with this data type may be unreadable by other Zarr libraries. " + "Use this data type at your own risk! " + "Check https://github.com/zarr-developers/zarr-extensions/tree/main/data-types for the " + "status of data type specifications for Zarr V3." + ) + warnings.warn(msg, category=UnstableSpecificationWarning, stacklevel=2) diff --git a/src/zarr/core/dtype/npy/__init__.py b/src/zarr/core/dtype/npy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/zarr/core/dtype/npy/bool.py b/src/zarr/core/dtype/npy/bool.py new file mode 100644 index 0000000000..d8d52468bf --- /dev/null +++ b/src/zarr/core/dtype/npy/bool.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Literal, Self, TypeGuard, overload + +import numpy as np + +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeConfig_V2, + DTypeJSON, + HasItemSize, + check_dtype_spec_v2, +) +from zarr.core.dtype.wrapper import TBaseDType, ZDType + +if TYPE_CHECKING: + from zarr.core.common import JSON, ZarrFormat + + +@dataclass(frozen=True, kw_only=True, slots=True) +class Bool(ZDType[np.dtypes.BoolDType, np.bool_], HasItemSize): + """ + Wrapper for numpy boolean dtype. + + Attributes + ---------- + name : str + The name of the dtype. + dtype_cls : ClassVar[type[np.dtypes.BoolDType]] + The numpy dtype class. + """ + + _zarr_v3_name: ClassVar[Literal["bool"]] = "bool" + _zarr_v2_name: ClassVar[Literal["|b1"]] = "|b1" + dtype_cls = np.dtypes.BoolDType + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + """ + Create a Bool from a np.dtype('bool') instance. + """ + if cls._check_native_dtype(dtype): + return cls() + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self: Self) -> np.dtypes.BoolDType: + """ + Create a NumPy boolean dtype instance from this ZDType + """ + return self.dtype_cls() + + @classmethod + def _check_json_v2( + cls, + data: DTypeJSON, + ) -> TypeGuard[DTypeConfig_V2[Literal["|b1"], None]]: + """ + Check that the input is a valid JSON representation of a Bool. + """ + return ( + check_dtype_spec_v2(data) + and data["name"] == cls._zarr_v2_name + and data["object_codec_id"] is None + ) + + @classmethod + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["bool"]]: + return data == cls._zarr_v3_name + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + if cls._check_json_v2(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v2_name!r}" + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self: + if cls._check_json_v3(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}. 
Got {data!r}, expected the string {cls._zarr_v3_name!r}" + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal["|b1"], None]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> Literal["bool"]: ... + + def to_json( + self, zarr_format: ZarrFormat + ) -> DTypeConfig_V2[Literal["|b1"], None] | Literal["bool"]: + if zarr_format == 2: + return {"name": self._zarr_v2_name, "object_codec_id": None} + elif zarr_format == 3: + return self._zarr_v3_name + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def _check_scalar(self, data: object) -> bool: + # Anything can become a bool + return True + + def cast_scalar(self, data: object) -> np.bool_: + if self._check_scalar(data): + return np.bool_(data) + msg = f"Cannot convert object with type {type(data)} to a numpy boolean." + raise TypeError(msg) + + def default_scalar(self) -> np.bool_: + """ + Get the default value for the boolean dtype. + + Returns + ------- + np.bool_ + The default value. + """ + return np.False_ + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> bool: + """ + Convert a scalar to a python bool. + + Parameters + ---------- + data : object + The value to convert. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + bool + The JSON-serializable format. + """ + return bool(data) + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bool_: + """ + Read a JSON-serializable value as a numpy boolean scalar. + + Parameters + ---------- + data : JSON + The JSON-serializable value. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + np.bool_ + The numpy boolean scalar. + """ + if self._check_scalar(data): + return np.bool_(data) + raise TypeError(f"Invalid type: {data}. Expected a boolean.") # pragma: no cover + + @property + def item_size(self) -> int: + return 1 diff --git a/src/zarr/core/dtype/npy/bytes.py b/src/zarr/core/dtype/npy/bytes.py new file mode 100644 index 0000000000..e363c75053 --- /dev/null +++ b/src/zarr/core/dtype/npy/bytes.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import base64 +import re +from dataclasses import dataclass +from typing import Any, ClassVar, Literal, Self, TypedDict, TypeGuard, cast, overload + +import numpy as np + +from zarr.core.common import JSON, NamedConfig, ZarrFormat +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeConfig_V2, + DTypeJSON, + HasItemSize, + HasLength, + HasObjectCodec, + check_dtype_spec_v2, + v3_unstable_dtype_warning, +) +from zarr.core.dtype.npy.common import check_json_str +from zarr.core.dtype.wrapper import TBaseDType, ZDType + +BytesLike = np.bytes_ | str | bytes | int + + +class FixedLengthBytesConfig(TypedDict): + length_bytes: int + + +NullTerminatedBytesJSONV3 = NamedConfig[Literal["null_terminated_bytes"], FixedLengthBytesConfig] +RawBytesJSONV3 = NamedConfig[Literal["raw_bytes"], FixedLengthBytesConfig] + + +@dataclass(frozen=True, kw_only=True) +class NullTerminatedBytes(ZDType[np.dtypes.BytesDType[int], np.bytes_], HasLength, HasItemSize): + dtype_cls = np.dtypes.BytesDType + _zarr_v3_name: ClassVar[Literal["null_terminated_bytes"]] = "null_terminated_bytes" + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + if cls._check_native_dtype(dtype): + return cls(length=dtype.itemsize) + raise DataTypeValidationError( + f"Invalid data type: {dtype}. 
Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self) -> np.dtypes.BytesDType[int]: + return self.dtype_cls(self.length) + + @classmethod + def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]]: + """ + Check that the input is a valid representation of a numpy S dtype. We expect + something like ``{"name": "|S10", "object_codec_id": None}`` + """ + return ( + check_dtype_spec_v2(data) + and isinstance(data["name"], str) + and re.match(r"^\|S\d+$", data["name"]) is not None + and data["object_codec_id"] is None + ) + + @classmethod + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[NullTerminatedBytesJSONV3]: + return ( + isinstance(data, dict) + and set(data.keys()) == {"name", "configuration"} + and data["name"] == cls._zarr_v3_name + and isinstance(data["configuration"], dict) + and "length_bytes" in data["configuration"] + ) + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + if cls._check_json_v2(data): + name = data["name"] + return cls(length=int(name[2:])) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a string like '|S1', '|S2', etc" + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + if cls._check_json_v3(data): + return cls(length=data["configuration"]["length_bytes"]) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}" + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[str, None]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> NullTerminatedBytesJSONV3: ... + + def to_json( + self, zarr_format: ZarrFormat + ) -> DTypeConfig_V2[str, None] | NullTerminatedBytesJSONV3: + if zarr_format == 2: + return {"name": self.to_native_dtype().str, "object_codec_id": None} + elif zarr_format == 3: + v3_unstable_dtype_warning(self) + return { + "name": self._zarr_v3_name, + "configuration": {"length_bytes": self.length}, + } + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def _check_scalar(self, data: object) -> TypeGuard[BytesLike]: + # this is generous for backwards compatibility + return isinstance(data, BytesLike) + + def _cast_scalar_unchecked(self, data: BytesLike) -> np.bytes_: + # We explicitly truncate the result because of the following numpy behavior: + # >>> x = np.dtype('S3').type('hello world') + # >>> x + # np.bytes_(b'hello world') + # >>> x.dtype + # dtype('S11') + + if isinstance(data, int): + return self.to_native_dtype().type(str(data)[: self.length]) + else: + return self.to_native_dtype().type(data[: self.length]) + + def cast_scalar(self, data: object) -> np.bytes_: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = f"Cannot convert object with type {type(data)} to a numpy bytes scalar." + raise TypeError(msg) + + def default_scalar(self) -> np.bytes_: + return np.bytes_(b"") + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str: + as_bytes = self.cast_scalar(data) + return base64.standard_b64encode(as_bytes).decode("ascii") + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.bytes_: + if check_json_str(data): + return self.to_native_dtype().type(base64.standard_b64decode(data.encode("ascii"))) + raise TypeError( + f"Invalid type: {data}. Expected a base64-encoded string." 
+ ) # pragma: no cover + + @property + def item_size(self) -> int: + return self.length + + +@dataclass(frozen=True, kw_only=True) +class RawBytes(ZDType[np.dtypes.VoidDType[int], np.void], HasLength, HasItemSize): + # np.dtypes.VoidDType is specified in an odd way in numpy + # it cannot be used to create instances of the dtype + # so we have to tell mypy to ignore this here + dtype_cls = np.dtypes.VoidDType # type: ignore[assignment] + _zarr_v3_name: ClassVar[Literal["raw_bytes"]] = "raw_bytes" + + @classmethod + def _check_native_dtype( + cls: type[Self], dtype: TBaseDType + ) -> TypeGuard[np.dtypes.VoidDType[Any]]: + """ + Numpy void dtype comes in two forms: + * If the ``fields`` attribute is ``None``, then the dtype represents N raw bytes. + * If the ``fields`` attribute is not ``None``, then the dtype represents a structured dtype. + + In this check we ensure that ``fields`` is ``None``. + + Parameters + ---------- + dtype : TBaseDType + The dtype to check. + + Returns + ------- + Bool + True if the dtype matches, False otherwise. + """ + return cls.dtype_cls is type(dtype) and dtype.fields is None # type: ignore[has-type] + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + if cls._check_native_dtype(dtype): + return cls(length=dtype.itemsize) + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" # type: ignore[has-type] + ) + + def to_native_dtype(self) -> np.dtypes.VoidDType[int]: + # Numpy does not allow creating a void type + # by invoking np.dtypes.VoidDType directly + return cast("np.dtypes.VoidDType[int]", np.dtype(f"V{self.length}")) + + @classmethod + def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]]: + """ + Check that the input is a valid representation of a numpy V dtype. We expect + something like ``{"name": "|V10", "object_codec_id": None}`` + """ + return ( + check_dtype_spec_v2(data) + and isinstance(data["name"], str) + and re.match(r"^\|V\d+$", data["name"]) is not None + and data["object_codec_id"] is None + ) + + @classmethod + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[RawBytesJSONV3]: + return ( + isinstance(data, dict) + and set(data.keys()) == {"name", "configuration"} + and data["name"] == cls._zarr_v3_name + and isinstance(data["configuration"], dict) + and set(data["configuration"].keys()) == {"length_bytes"} + ) + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + if cls._check_json_v2(data): + name = data["name"] + return cls(length=int(name[2:])) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a string like '|V1', '|V2', etc" + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + if cls._check_json_v3(data): + return cls(length=data["configuration"]["length_bytes"]) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}" + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[str, None]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> RawBytesJSONV3: ...
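+    # The implementation below dispatches on zarr_format: Zarr V2 gets, e.g.,
+    # {"name": "|V10", "object_codec_id": None}, while Zarr V3 gets
+    # {"name": "raw_bytes", "configuration": {"length_bytes": 10}} and also triggers the
+    # unstable-spec warning defined in common.py (values shown for a hypothetical 10-byte instance).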
+ + def to_json(self, zarr_format: ZarrFormat) -> DTypeConfig_V2[str, None] | RawBytesJSONV3: + if zarr_format == 2: + return {"name": self.to_native_dtype().str, "object_codec_id": None} + elif zarr_format == 3: + v3_unstable_dtype_warning(self) + return {"name": self._zarr_v3_name, "configuration": {"length_bytes": self.length}} + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def _check_scalar(self, data: object) -> bool: + return isinstance(data, np.bytes_ | str | bytes | np.void) + + def _cast_scalar_unchecked(self, data: object) -> np.void: + native_dtype = self.to_native_dtype() + # Without the second argument, numpy will return a void scalar for dtype V1. + # The second argument ensures that, if native_dtype is something like V10, + # the result will actually be a V10 scalar. + return native_dtype.type(data, native_dtype) + + def cast_scalar(self, data: object) -> np.void: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = f"Cannot convert object with type {type(data)} to a numpy void scalar." + raise TypeError(msg) + + def default_scalar(self) -> np.void: + return self.to_native_dtype().type(("\x00" * self.length).encode("ascii")) + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str: + return base64.standard_b64encode(self.cast_scalar(data).tobytes()).decode("ascii") + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void: + if check_json_str(data): + return self.to_native_dtype().type(base64.standard_b64decode(data)) + raise TypeError(f"Invalid type: {data}. Expected a string.") # pragma: no cover + + @property + def item_size(self) -> int: + return self.length + + +@dataclass(frozen=True, kw_only=True) +class VariableLengthBytes(ZDType[np.dtypes.ObjectDType, bytes], HasObjectCodec): + dtype_cls = np.dtypes.ObjectDType + _zarr_v3_name: ClassVar[Literal["variable_length_bytes"]] = "variable_length_bytes" + object_codec_id: ClassVar[Literal["vlen-bytes"]] = "vlen-bytes" + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + if cls._check_native_dtype(dtype): + return cls() + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self) -> np.dtypes.ObjectDType: + return self.dtype_cls() + + @classmethod + def _check_json_v2( + cls, + data: DTypeJSON, + ) -> TypeGuard[DTypeConfig_V2[Literal["|O"], Literal["vlen-bytes"]]]: + """ + Check that the input is a valid JSON representation of a numpy O dtype, and that the + object codec id is appropriate for variable-length bytes. + """ + return ( + check_dtype_spec_v2(data) + and data["name"] == "|O" + and data["object_codec_id"] == cls.object_codec_id + ) + + @classmethod + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["variable_length_bytes"]]: + return data == cls._zarr_v3_name + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + if cls._check_json_v2(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string '|O' and an object_codec_id of {cls.object_codec_id}" + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + if cls._check_json_v3(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}.
Got {data!r}, expected the string {cls._zarr_v3_name!r}" + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json( + self, zarr_format: Literal[2] + ) -> DTypeConfig_V2[Literal["|O"], Literal["vlen-bytes"]]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> Literal["variable_length_bytes"]: ... + + def to_json( + self, zarr_format: ZarrFormat + ) -> DTypeConfig_V2[Literal["|O"], Literal["vlen-bytes"]] | Literal["variable_length_bytes"]: + if zarr_format == 2: + return {"name": "|O", "object_codec_id": self.object_codec_id} + elif zarr_format == 3: + v3_unstable_dtype_warning(self) + return self._zarr_v3_name + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def default_scalar(self) -> bytes: + return b"" + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str: + return base64.standard_b64encode(data).decode("ascii") # type: ignore[arg-type] + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> bytes: + if check_json_str(data): + return base64.standard_b64decode(data.encode("ascii")) + raise TypeError(f"Invalid type: {data}. Expected a string.") # pragma: no cover + + def _check_scalar(self, data: object) -> TypeGuard[BytesLike]: + return isinstance(data, BytesLike) + + def _cast_scalar_unchecked(self, data: BytesLike) -> bytes: + if isinstance(data, str): + return bytes(data, encoding="utf-8") + return bytes(data) + + def cast_scalar(self, data: object) -> bytes: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = f"Cannot convert object with type {type(data)} to bytes." + raise TypeError(msg) diff --git a/src/zarr/core/dtype/npy/common.py b/src/zarr/core/dtype/npy/common.py new file mode 100644 index 0000000000..264561f25c --- /dev/null +++ b/src/zarr/core/dtype/npy/common.py @@ -0,0 +1,503 @@ +from __future__ import annotations + +import base64 +import struct +import sys +from collections.abc import Sequence +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + SupportsComplex, + SupportsFloat, + SupportsIndex, + SupportsInt, + TypeGuard, + TypeVar, +) + +import numpy as np + +from zarr.core.dtype.common import ( + ENDIANNESS_STR, + SPECIAL_FLOAT_STRINGS, + EndiannessStr, + JSONFloatV2, + JSONFloatV3, +) + +if TYPE_CHECKING: + from zarr.core.common import JSON, ZarrFormat + +IntLike = SupportsInt | SupportsIndex | bytes | str +FloatLike = SupportsIndex | SupportsFloat | bytes | str +ComplexLike = SupportsFloat | SupportsIndex | SupportsComplex | bytes | str | None +DateTimeUnit = Literal[ + "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "μs", "ns", "ps", "fs", "as", "generic" +] +DATETIME_UNIT: Final = ( + "Y", + "M", + "W", + "D", + "h", + "m", + "s", + "ms", + "us", + "μs", + "ns", + "ps", + "fs", + "as", + "generic", +) + +NumpyEndiannessStr = Literal[">", "<", "="] +NUMPY_ENDIANNESS_STR: Final = ">", "<", "=" + +TFloatDType_co = TypeVar( + "TFloatDType_co", + bound=np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType, + covariant=True, +) +TFloatScalar_co = TypeVar( + "TFloatScalar_co", bound=np.float16 | np.float32 | np.float64, covariant=True +) + +TComplexDType_co = TypeVar( + "TComplexDType_co", bound=np.dtypes.Complex64DType | np.dtypes.Complex128DType, covariant=True +) +TComplexScalar_co = TypeVar("TComplexScalar_co", bound=np.complex64 | np.complex128, covariant=True) + + +def endianness_from_numpy_str(endianness: NumpyEndiannessStr) -> EndiannessStr: + """ + Convert a numpy 
endianness string literal to a human-readable literal value. + + Parameters + ---------- + endianness : Literal[">", "<", "="] + The numpy string representation of the endianness. + + Returns + ------- + Endianness + The human-readable representation of the endianness. + + Raises + ------ + ValueError + If the endianness is invalid. + """ + match endianness: + case "=": + # Use the local system endianness + return sys.byteorder + case "<": + return "little" + case ">": + return "big" + raise ValueError(f"Invalid endianness: {endianness!r}. Expected one of {NUMPY_ENDIANNESS_STR}") + + +def endianness_to_numpy_str(endianness: EndiannessStr) -> NumpyEndiannessStr: + """ + Convert an endianness literal to its numpy string representation. + + Parameters + ---------- + endianness : Endianness + The endianness to convert. + + Returns + ------- + Literal[">", "<"] + The numpy string representation of the endianness. + + Raises + ------ + ValueError + If the endianness is invalid. + """ + match endianness: + case "little": + return "<" + case "big": + return ">" + raise ValueError( + f"Invalid endianness: {endianness!r}. Expected one of {ENDIANNESS_STR} or None" + ) + + +def get_endianness_from_numpy_dtype(dtype: np.dtype[np.generic]) -> EndiannessStr: + """ + Gets the endianness from a numpy dtype that has an endianness. This function will + raise a ValueError if the numpy data type does not have a concrete endianness. + """ + endianness = dtype.byteorder + if dtype.byteorder in NUMPY_ENDIANNESS_STR: + return endianness_from_numpy_str(endianness) # type: ignore [arg-type] + raise ValueError(f"The dtype {dtype} has an unsupported endianness: {endianness}") + + +def float_from_json_v2(data: JSONFloatV2) -> float: + """ + Convert a JSON float to a float (Zarr v2). + + Parameters + ---------- + data : JSONFloat + The JSON float to convert. + + Returns + ------- + float + The float value. + """ + match data: + case "NaN": + return float("nan") + case "Infinity": + return float("inf") + case "-Infinity": + return float("-inf") + case _: + return float(data) + + +def float_from_json_v3(data: JSONFloatV3) -> float: + """ + Convert a JSON float to a float (v3). + + Parameters + ---------- + data : JSONFloat + The JSON float to convert. + + Returns + ------- + float + The float value. + + Notes + ----- + Zarr V3 allows floats to be stored as hex strings. To quote the spec: + "...for float32, "NaN" is equivalent to "0x7fc00000". + This representation is the only way to specify a NaN value other than the specific NaN value + denoted by "NaN"." + """ + + if isinstance(data, str): + if data in SPECIAL_FLOAT_STRINGS: + return float_from_json_v2(data) # type: ignore[arg-type] + if not data.startswith("0x"): + msg = ( + f"Invalid float value: {data!r}. Expected a string starting with the hex prefix" + " '0x', or one of 'NaN', 'Infinity', or '-Infinity'." + ) + raise ValueError(msg) + if len(data[2:]) == 4: + dtype_code = ">e" + elif len(data[2:]) == 8: + dtype_code = ">f" + elif len(data[2:]) == 16: + dtype_code = ">d" + else: + msg = ( + f"Invalid hexadecimal float value: {data!r}. " + "Expected the '0x' prefix to be followed by 4, 8, or 16 numeral characters" + ) + raise ValueError(msg) + return float(struct.unpack(dtype_code, bytes.fromhex(data[2:]))[0]) + return float_from_json_v2(data) + + +def bytes_from_json(data: str, *, zarr_format: ZarrFormat) -> bytes: + """ + Convert a JSON string to bytes + + Parameters + ---------- + data : str + The JSON string to convert. 
+ zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + bytes + The bytes. + """ + if zarr_format == 2: + return base64.b64decode(data.encode("ascii")) + # TODO: differentiate these as needed. This is a spec question. + if zarr_format == 3: + return base64.b64decode(data.encode("ascii")) + raise ValueError(f"Invalid zarr format: {zarr_format}. Expected 2 or 3.") # pragma: no cover + + +def bytes_to_json(data: bytes, zarr_format: ZarrFormat) -> str: + """ + Convert bytes to JSON. + + Parameters + ---------- + data : bytes + The bytes to store. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + str + The bytes encoded as ascii using the base64 alphabet. + """ + # TODO: decide if we are going to make this implementation zarr format-specific + return base64.b64encode(data).decode("ascii") + + +def float_to_json_v2(data: float | np.floating[Any]) -> JSONFloatV2: + """ + Convert a float to JSON (v2). + + Parameters + ---------- + data : float or np.floating + The float value to convert. + + Returns + ------- + JSONFloat + The JSON representation of the float. + """ + if np.isnan(data): + return "NaN" + elif np.isinf(data): + return "Infinity" if data > 0 else "-Infinity" + return float(data) + + +def float_to_json_v3(data: float | np.floating[Any]) -> JSONFloatV3: + """ + Convert a float to JSON (v3). + + Parameters + ---------- + data : float or np.floating + The float value to convert. + + Returns + ------- + JSONFloat + The JSON representation of the float. + """ + # v3 can in principle handle distinct NaN values, but numpy does not represent these explicitly + # so we just reuse the v2 routine here + return float_to_json_v2(data) + + +def complex_float_to_json_v3( + data: complex | np.complexfloating[Any, Any], +) -> tuple[JSONFloatV3, JSONFloatV3]: + """ + Convert a complex number to JSON as defined by the Zarr V3 spec. + + Parameters + ---------- + data : complex or np.complexfloating + The complex value to convert. + + Returns + ------- + tuple[JSONFloat, JSONFloat] + The JSON representation of the complex number. + """ + return float_to_json_v3(data.real), float_to_json_v3(data.imag) + + +def complex_float_to_json_v2( + data: complex | np.complexfloating[Any, Any], +) -> tuple[JSONFloatV2, JSONFloatV2]: + """ + Convert a complex number to JSON as defined by the Zarr V2 spec. + + Parameters + ---------- + data : complex | np.complexfloating + The complex value to convert. + + Returns + ------- + tuple[JSONFloat, JSONFloat] + The JSON representation of the complex number. + """ + return float_to_json_v2(data.real), float_to_json_v2(data.imag) + + +def complex_float_from_json_v2(data: tuple[JSONFloatV2, JSONFloatV2]) -> complex: + """ + Convert a JSON complex float to a complex number (v2). + + Parameters + ---------- + data : tuple[JSONFloat, JSONFloat] + The JSON complex float to convert. + + Returns + ------- + np.complexfloating + The complex number. + """ + return complex(float_from_json_v2(data[0]), float_from_json_v2(data[1])) + + +def complex_float_from_json_v3(data: tuple[JSONFloatV3, JSONFloatV3]) -> complex: + """ + Convert a JSON complex float to a complex number (v3). + + Parameters + ---------- + data : tuple[JSONFloat, JSONFloat] + The JSON complex float to convert. + + Returns + ------- + np.complexfloating + The complex number. 
+ """ + return complex(float_from_json_v3(data[0]), float_from_json_v3(data[1])) + + +def check_json_float_v2(data: JSON) -> TypeGuard[JSONFloatV2]: + """ + Check if a JSON value represents a float (v2). + + Parameters + ---------- + data : JSON + The JSON value to check. + + Returns + ------- + Bool + True if the data is a float, False otherwise. + """ + if data == "NaN" or data == "Infinity" or data == "-Infinity": + return True + return isinstance(data, float | int) + + +def check_json_float_v3(data: JSON) -> TypeGuard[JSONFloatV3]: + """ + Check if a JSON value represents a float (v3). + + Parameters + ---------- + data : JSON + The JSON value to check. + + Returns + ------- + Bool + True if the data is a float, False otherwise. + """ + return check_json_float_v2(data) or (isinstance(data, str) and data.startswith("0x")) + + +def check_json_complex_float_v2(data: JSON) -> TypeGuard[tuple[JSONFloatV2, JSONFloatV2]]: + """ + Check if a JSON value represents a complex float, as per the behavior of zarr-python 2.x + + Parameters + ---------- + data : JSON + The JSON value to check. + + Returns + ------- + Bool + True if the data is a complex float, False otherwise. + """ + return ( + not isinstance(data, str) + and isinstance(data, Sequence) + and len(data) == 2 + and check_json_float_v2(data[0]) + and check_json_float_v2(data[1]) + ) + + +def check_json_complex_float_v3(data: JSON) -> TypeGuard[tuple[JSONFloatV3, JSONFloatV3]]: + """ + Check if a JSON value represents a complex float, as per the zarr v3 spec + + Parameters + ---------- + data : JSON + The JSON value to check. + + Returns + ------- + Bool + True if the data is a complex float, False otherwise. + """ + return ( + not isinstance(data, str) + and isinstance(data, Sequence) + and len(data) == 2 + and check_json_float_v3(data[0]) + and check_json_float_v3(data[1]) + ) + + +def check_json_int(data: JSON) -> TypeGuard[int]: + """ + Check if a JSON value is an integer. + + Parameters + ---------- + data : JSON + The JSON value to check. + + Returns + ------- + Bool + True if the data is an integer, False otherwise. + """ + return bool(isinstance(data, int)) + + +def check_json_str(data: JSON) -> TypeGuard[str]: + """ + Check if a JSON value is a string. + + Parameters + ---------- + data : JSON + The JSON value to check. + + Returns + ------- + Bool + True if the data is a string, False otherwise. + """ + return bool(isinstance(data, str)) + + +def check_json_bool(data: JSON) -> TypeGuard[bool]: + """ + Check if a JSON value is a boolean. + + Parameters + ---------- + data : JSON + The JSON value to check. + + Returns + ------- + Bool + True if the data is a boolean, False otherwise. 
+ """ + return isinstance(data, bool) diff --git a/src/zarr/core/dtype/npy/complex.py b/src/zarr/core/dtype/npy/complex.py new file mode 100644 index 0000000000..38e506f1bc --- /dev/null +++ b/src/zarr/core/dtype/npy/complex.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + ClassVar, + Literal, + Self, + TypeGuard, + overload, +) + +import numpy as np + +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeConfig_V2, + DTypeJSON, + HasEndianness, + HasItemSize, + check_dtype_spec_v2, +) +from zarr.core.dtype.npy.common import ( + ComplexLike, + TComplexDType_co, + TComplexScalar_co, + check_json_complex_float_v2, + check_json_complex_float_v3, + complex_float_from_json_v2, + complex_float_from_json_v3, + complex_float_to_json_v2, + complex_float_to_json_v3, + endianness_to_numpy_str, + get_endianness_from_numpy_dtype, +) +from zarr.core.dtype.wrapper import TBaseDType, ZDType + +if TYPE_CHECKING: + from zarr.core.common import JSON, ZarrFormat + + +@dataclass(frozen=True) +class BaseComplex(ZDType[TComplexDType_co, TComplexScalar_co], HasEndianness, HasItemSize): + # This attribute holds the possible zarr v2 JSON names for the data type + _zarr_v2_names: ClassVar[tuple[str, ...]] + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + if cls._check_native_dtype(dtype): + return cls(endianness=get_endianness_from_numpy_dtype(dtype)) + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self) -> TComplexDType_co: + byte_order = endianness_to_numpy_str(self.endianness) + return self.dtype_cls().newbyteorder(byte_order) # type: ignore[return-value] + + @classmethod + def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]]: + """ + Check that the input is a valid JSON representation of this data type. + """ + return ( + check_dtype_spec_v2(data) + and data["name"] in cls._zarr_v2_names + and data["object_codec_id"] is None + ) + + @classmethod + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[str]: + return data == cls._zarr_v3_name + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + if cls._check_json_v2(data): + # Going via numpy ensures that we get the endianness correct without + # annoying string parsing. + name = data["name"] + return cls.from_native_dtype(np.dtype(name)) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected one of the strings {cls._zarr_v2_names}." + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + if cls._check_json_v3(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected {cls._zarr_v3_name}." + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[str, None]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> str: ... + + def to_json(self, zarr_format: ZarrFormat) -> DTypeConfig_V2[str, None] | str: + """ + Convert the wrapped data type to a JSON-serializable form. + + Parameters + ---------- + zarr_format : ZarrFormat + The zarr format version. 
+ + Returns + ------- + str + The JSON-serializable representation of the wrapped data type + """ + if zarr_format == 2: + return {"name": self.to_native_dtype().str, "object_codec_id": None} + elif zarr_format == 3: + return self._zarr_v3_name + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def _check_scalar(self, data: object) -> TypeGuard[ComplexLike]: + return isinstance(data, ComplexLike) + + def _cast_scalar_unchecked(self, data: ComplexLike) -> TComplexScalar_co: + return self.to_native_dtype().type(data) # type: ignore[return-value] + + def cast_scalar(self, data: object) -> TComplexScalar_co: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = f"Cannot convert object with type {type(data)} to a numpy complex scalar." + raise TypeError(msg) + + def default_scalar(self) -> TComplexScalar_co: + """ + Get the default value, which is 0 cast to this dtype + + Returns + ------- + Complex scalar + The default value. + """ + return self._cast_scalar_unchecked(0) + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> TComplexScalar_co: + """ + Read a JSON-serializable value as a numpy complex scalar. + + Parameters + ---------- + data : JSON + The JSON-serializable value. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + TScalar_co + The numpy complex scalar. + """ + if zarr_format == 2: + if check_json_complex_float_v2(data): + return self._cast_scalar_unchecked(complex_float_from_json_v2(data)) + raise TypeError( + f"Invalid type: {data}. Expected a 2-element array of floats or special string encodings of floats." + ) + elif zarr_format == 3: + if check_json_complex_float_v3(data): + return self._cast_scalar_unchecked(complex_float_from_json_v3(data)) + raise TypeError( + f"Invalid type: {data}. Expected a 2-element array of floats or special string encodings of floats." + ) + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> JSON: + """ + Convert an object to a JSON-serializable complex number. + + Parameters + ---------- + data : _BaseScalar + The value to convert. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + JSON + The JSON-serializable form of the complex number, which is a list of two floats, + each of which is encoded according to a zarr-format-specific encoding.
+ """ + if zarr_format == 2: + return complex_float_to_json_v2(self.cast_scalar(data)) + elif zarr_format == 3: + return complex_float_to_json_v3(self.cast_scalar(data)) + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + +@dataclass(frozen=True, kw_only=True) +class Complex64(BaseComplex[np.dtypes.Complex64DType, np.complex64]): + dtype_cls = np.dtypes.Complex64DType + _zarr_v3_name: ClassVar[Literal["complex64"]] = "complex64" + _zarr_v2_names: ClassVar[tuple[str, ...]] = (">c8", " int: + return 8 + + +@dataclass(frozen=True, kw_only=True) +class Complex128(BaseComplex[np.dtypes.Complex128DType, np.complex128], HasEndianness): + dtype_cls = np.dtypes.Complex128DType + _zarr_v3_name: ClassVar[Literal["complex128"]] = "complex128" + _zarr_v2_names: ClassVar[tuple[str, ...]] = (">c16", " int: + return 16 diff --git a/src/zarr/core/dtype/npy/float.py b/src/zarr/core/dtype/npy/float.py new file mode 100644 index 0000000000..7b7243993f --- /dev/null +++ b/src/zarr/core/dtype/npy/float.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Literal, Self, TypeGuard, overload + +import numpy as np + +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeConfig_V2, + DTypeJSON, + HasEndianness, + HasItemSize, + ScalarTypeValidationError, + check_dtype_spec_v2, +) +from zarr.core.dtype.npy.common import ( + FloatLike, + TFloatDType_co, + TFloatScalar_co, + check_json_float_v2, + check_json_float_v3, + endianness_to_numpy_str, + float_from_json_v2, + float_from_json_v3, + float_to_json_v2, + float_to_json_v3, + get_endianness_from_numpy_dtype, +) +from zarr.core.dtype.wrapper import TBaseDType, ZDType + +if TYPE_CHECKING: + from zarr.core.common import JSON, ZarrFormat + + +@dataclass(frozen=True) +class BaseFloat(ZDType[TFloatDType_co, TFloatScalar_co], HasEndianness, HasItemSize): + # This attribute holds the possible zarr v2 JSON names for the data type + _zarr_v2_names: ClassVar[tuple[str, ...]] + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + if cls._check_native_dtype(dtype): + return cls(endianness=get_endianness_from_numpy_dtype(dtype)) + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self) -> TFloatDType_co: + byte_order = endianness_to_numpy_str(self.endianness) + return self.dtype_cls().newbyteorder(byte_order) # type: ignore[return-value] + + @classmethod + def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]]: + """ + Check that the input is a valid JSON representation of this data type. + """ + return ( + check_dtype_spec_v2(data) + and data["name"] in cls._zarr_v2_names + and data["object_codec_id"] is None + ) + + @classmethod + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[str]: + return data == cls._zarr_v3_name + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + if cls._check_json_v2(data): + # Going via numpy ensures that we get the endianness correct without + # annoying string parsing. + name = data["name"] + return cls.from_native_dtype(np.dtype(name)) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected one of the strings {cls._zarr_v2_names}." + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + if cls._check_json_v3(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}. 
Got {data!r}, expected {cls._zarr_v3_name}." + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[str, None]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> str: ... + + def to_json(self, zarr_format: ZarrFormat) -> DTypeConfig_V2[str, None] | str: + """ + Convert the wrapped data type to a JSON-serializable form. + + Parameters + ---------- + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + str + The JSON-serializable representation of the wrapped data type + """ + if zarr_format == 2: + return {"name": self.to_native_dtype().str, "object_codec_id": None} + elif zarr_format == 3: + return self._zarr_v3_name + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def _check_scalar(self, data: object) -> TypeGuard[FloatLike]: + return isinstance(data, FloatLike) + + def _cast_scalar_unchecked(self, data: FloatLike) -> TFloatScalar_co: + return self.to_native_dtype().type(data) # type: ignore[return-value] + + def cast_scalar(self, data: object) -> TFloatScalar_co: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = f"Cannot convert object with type {type(data)} to a numpy float scalar." + raise ScalarTypeValidationError(msg) + + def default_scalar(self) -> TFloatScalar_co: + """ + Get the default value, which is 0 cast to this dtype + + Returns + ------- + Int scalar + The default value. + """ + return self._cast_scalar_unchecked(0) + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> TFloatScalar_co: + """ + Read a JSON-serializable value as a numpy float. + + Parameters + ---------- + data : JSON + The JSON-serializable value. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + TScalar_co + The numpy float. + """ + if zarr_format == 2: + if check_json_float_v2(data): + return self._cast_scalar_unchecked(float_from_json_v2(data)) + else: + raise TypeError( + f"Invalid type: {data}. Expected a float or a special string encoding of a float." + ) + elif zarr_format == 3: + if check_json_float_v3(data): + return self._cast_scalar_unchecked(float_from_json_v3(data)) + else: + raise TypeError( + f"Invalid type: {data}. Expected a float or a special string encoding of a float." + ) + else: + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> float | str: + """ + Convert an object to a JSON-serializable float. + + Parameters + ---------- + data : _BaseScalar + The value to convert. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + JSON + The JSON-serializable form of the float, which is potentially a number or a string. + See the zarr specifications for details on the JSON encoding for floats. 
+ """ + if zarr_format == 2: + return float_to_json_v2(self.cast_scalar(data)) + elif zarr_format == 3: + return float_to_json_v3(self.cast_scalar(data)) + else: + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + +@dataclass(frozen=True, kw_only=True) +class Float16(BaseFloat[np.dtypes.Float16DType, np.float16]): + dtype_cls = np.dtypes.Float16DType + _zarr_v3_name = "float16" + _zarr_v2_names: ClassVar[tuple[Literal[">f2"], Literal["f2", " int: + return 2 + + +@dataclass(frozen=True, kw_only=True) +class Float32(BaseFloat[np.dtypes.Float32DType, np.float32]): + dtype_cls = np.dtypes.Float32DType + _zarr_v3_name = "float32" + _zarr_v2_names: ClassVar[tuple[Literal[">f4"], Literal["f4", " int: + return 4 + + +@dataclass(frozen=True, kw_only=True) +class Float64(BaseFloat[np.dtypes.Float64DType, np.float64]): + dtype_cls = np.dtypes.Float64DType + _zarr_v3_name = "float64" + _zarr_v2_names: ClassVar[tuple[Literal[">f8"], Literal["f8", " int: + return 8 diff --git a/src/zarr/core/dtype/npy/int.py b/src/zarr/core/dtype/npy/int.py new file mode 100644 index 0000000000..79d3ce2d47 --- /dev/null +++ b/src/zarr/core/dtype/npy/int.py @@ -0,0 +1,686 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + ClassVar, + Literal, + Self, + SupportsIndex, + SupportsInt, + TypeGuard, + TypeVar, + overload, +) + +import numpy as np + +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeConfig_V2, + DTypeJSON, + HasEndianness, + HasItemSize, + check_dtype_spec_v2, +) +from zarr.core.dtype.npy.common import ( + check_json_int, + endianness_to_numpy_str, + get_endianness_from_numpy_dtype, +) +from zarr.core.dtype.wrapper import TBaseDType, ZDType + +if TYPE_CHECKING: + from zarr.core.common import JSON, ZarrFormat + +_NumpyIntDType = ( + np.dtypes.Int8DType + | np.dtypes.Int16DType + | np.dtypes.Int32DType + | np.dtypes.Int64DType + | np.dtypes.UInt8DType + | np.dtypes.UInt16DType + | np.dtypes.UInt32DType + | np.dtypes.UInt64DType +) +_NumpyIntScalar = ( + np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 +) +TIntDType_co = TypeVar("TIntDType_co", bound=_NumpyIntDType, covariant=True) +TIntScalar_co = TypeVar("TIntScalar_co", bound=_NumpyIntScalar, covariant=True) +IntLike = SupportsInt | SupportsIndex | bytes | str + + +@dataclass(frozen=True) +class BaseInt(ZDType[TIntDType_co, TIntScalar_co], HasItemSize): + # This attribute holds the possible zarr V2 JSON names for the data type + _zarr_v2_names: ClassVar[tuple[str, ...]] + + @classmethod + def _check_json_v2(cls, data: object) -> TypeGuard[DTypeConfig_V2[str, None]]: + """ + Check that the input is a valid JSON representation of this data type. + """ + return ( + check_dtype_spec_v2(data) + and data["name"] in cls._zarr_v2_names + and data["object_codec_id"] is None + ) + + @classmethod + def _check_json_v3(cls, data: object) -> TypeGuard[str]: + """ + Check that a JSON value is consistent with the zarr v3 spec for this data type. + """ + return data == cls._zarr_v3_name + + def _check_scalar(self, data: object) -> TypeGuard[IntLike]: + """ + Check that a python object is IntLike + """ + return isinstance(data, IntLike) + + def _cast_scalar_unchecked(self, data: IntLike) -> TIntScalar_co: + """ + Create an integer without any type checking of the input. 
+ """ + return self.to_native_dtype().type(data) # type: ignore[return-value] + + def cast_scalar(self, data: object) -> TIntScalar_co: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = f"Cannot convert object with type {type(data)} to a numpy integer." + raise TypeError(msg) + + def default_scalar(self) -> TIntScalar_co: + """ + Get the default value, which is 0 cast to this dtype + + Returns + ------- + Int scalar + The default value. + """ + return self._cast_scalar_unchecked(0) + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> TIntScalar_co: + """ + Read a JSON-serializable value as a numpy int scalar. + + Parameters + ---------- + data : JSON + The JSON-serializable value. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + TScalar_co + The numpy scalar. + """ + if check_json_int(data): + return self._cast_scalar_unchecked(data) + raise TypeError(f"Invalid type: {data}. Expected an integer.") + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> int: + """ + Convert an object to JSON-serializable scalar. + + Parameters + ---------- + data : _BaseScalar + The value to convert. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + int + The JSON-serializable form of the scalar. + """ + return int(self.cast_scalar(data)) + + +@dataclass(frozen=True, kw_only=True) +class Int8(BaseInt[np.dtypes.Int8DType, np.int8]): + dtype_cls = np.dtypes.Int8DType + _zarr_v3_name: ClassVar[Literal["int8"]] = "int8" + _zarr_v2_names: ClassVar[tuple[Literal["|i1"]]] = ("|i1",) + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + """ + Create a Int8 from a np.dtype('int8') instance. + """ + if cls._check_native_dtype(dtype): + return cls() + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self: Self) -> np.dtypes.Int8DType: + return self.dtype_cls() + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + if cls._check_json_v2(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v2_names[0]!r}" + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + if cls._check_json_v3(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}" + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal["|i1"], None]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> Literal["int8"]: ... + + def to_json( + self, zarr_format: ZarrFormat + ) -> DTypeConfig_V2[Literal["|i1"], None] | Literal["int8"]: + """ + Convert the wrapped data type to a JSON-serializable form. + + Parameters + ---------- + zarr_format : ZarrFormat + The zarr format version. 
+
+        Returns
+        -------
+        str
+            The JSON-serializable representation of the wrapped data type
+        """
+        if zarr_format == 2:
+            return {"name": self._zarr_v2_names[0], "object_codec_id": None}
+        elif zarr_format == 3:
+            return self._zarr_v3_name
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    @property
+    def item_size(self) -> int:
+        return 1
+
+
+@dataclass(frozen=True, kw_only=True)
+class UInt8(BaseInt[np.dtypes.UInt8DType, np.uint8]):
+    dtype_cls = np.dtypes.UInt8DType
+    _zarr_v3_name: ClassVar[Literal["uint8"]] = "uint8"
+    _zarr_v2_names: ClassVar[tuple[Literal["|u1"]]] = ("|u1",)
+
+    @classmethod
+    def from_native_dtype(cls, dtype: TBaseDType) -> Self:
+        """
+        Create a UInt8 from a np.dtype('uint8') instance.
+        """
+        if cls._check_native_dtype(dtype):
+            return cls()
+        raise DataTypeValidationError(
+            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
+        )
+
+    def to_native_dtype(self: Self) -> np.dtypes.UInt8DType:
+        return self.dtype_cls()
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            return cls()
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v2_names[0]!r}"
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            return cls()
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
+        raise DataTypeValidationError(msg)
+
+    @overload  # type: ignore[override]
+    def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal["|u1"], None]: ...
+
+    @overload
+    def to_json(self, zarr_format: Literal[3]) -> Literal["uint8"]: ...
+
+    def to_json(
+        self, zarr_format: ZarrFormat
+    ) -> DTypeConfig_V2[Literal["|u1"], None] | Literal["uint8"]:
+        """
+        Convert the wrapped data type to a JSON-serializable form.
+
+        Parameters
+        ----------
+        zarr_format : ZarrFormat
+            The zarr format version.
+
+        Returns
+        -------
+        str
+            The JSON-serializable representation of the wrapped data type
+        """
+        if zarr_format == 2:
+            return {"name": self._zarr_v2_names[0], "object_codec_id": None}
+        elif zarr_format == 3:
+            return self._zarr_v3_name
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    @property
+    def item_size(self) -> int:
+        return 1
+
+
+@dataclass(frozen=True, kw_only=True)
+class Int16(BaseInt[np.dtypes.Int16DType, np.int16], HasEndianness):
+    dtype_cls = np.dtypes.Int16DType
+    _zarr_v3_name: ClassVar[Literal["int16"]] = "int16"
+    _zarr_v2_names: ClassVar[tuple[Literal[">i2"], Literal["<i2"]]] = (">i2", "<i2")
+
+    @classmethod
+    def from_native_dtype(cls, dtype: TBaseDType) -> Self:
+        if cls._check_native_dtype(dtype):
+            return cls(endianness=get_endianness_from_numpy_dtype(dtype))
+        raise DataTypeValidationError(
+            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
+        )
+
+    def to_native_dtype(self) -> np.dtypes.Int16DType:
+        byte_order = endianness_to_numpy_str(self.endianness)
+        return self.dtype_cls().newbyteorder(byte_order)
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            # Going via numpy ensures that we get the endianness correct without
+            # annoying string parsing.
+            name = data["name"]
+            return cls.from_native_dtype(np.dtype(name))
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected one of the strings {cls._zarr_v2_names!r}."
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            return cls()
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
+        raise DataTypeValidationError(msg)
+
+    @overload  # type: ignore[override]
+    def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal[">i2", "<i2"], None]: ...
+
+    @overload
+    def to_json(self, zarr_format: Literal[3]) -> Literal["int16"]: ...
+
+    def to_json(
+        self, zarr_format: ZarrFormat
+    ) -> DTypeConfig_V2[Literal[">i2", "<i2"], None] | Literal["int16"]:
+        """
+        Convert the wrapped data type to a JSON-serializable form.
+
+        Parameters
+        ----------
+        zarr_format : ZarrFormat
+            The zarr format version.
+
+        Returns
+        -------
+        str
+            The JSON-serializable representation of the wrapped data type
+        """
+        if zarr_format == 2:
+            return {"name": self.to_native_dtype().str, "object_codec_id": None}
+        elif zarr_format == 3:
+            return self._zarr_v3_name
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    @property
+    def item_size(self) -> int:
+        return 2
+
+
+@dataclass(frozen=True, kw_only=True)
+class UInt16(BaseInt[np.dtypes.UInt16DType, np.uint16], HasEndianness):
+    dtype_cls = np.dtypes.UInt16DType
+    _zarr_v3_name: ClassVar[Literal["uint16"]] = "uint16"
+    _zarr_v2_names: ClassVar[tuple[Literal[">u2"], Literal["<u2"]]] = (">u2", "<u2")
+
+    @classmethod
+    def from_native_dtype(cls, dtype: TBaseDType) -> Self:
+        if cls._check_native_dtype(dtype):
+            return cls(endianness=get_endianness_from_numpy_dtype(dtype))
+        raise DataTypeValidationError(
+            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
+        )
+
+    def to_native_dtype(self) -> np.dtypes.UInt16DType:
+        byte_order = endianness_to_numpy_str(self.endianness)
+        return self.dtype_cls().newbyteorder(byte_order)
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            # Going via numpy ensures that we get the endianness correct without
+            # annoying string parsing.
+            name = data["name"]
+            return cls.from_native_dtype(np.dtype(name))
+        msg = f"Invalid JSON representation of UInt16. Got {data!r}, expected one of the strings {cls._zarr_v2_names}."
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            return cls()
+        msg = f"Invalid JSON representation of UInt16. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
+        raise DataTypeValidationError(msg)
+
+    @overload  # type: ignore[override]
+    def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal[">u2", "<u2"], None]: ...
+
+    @overload
+    def to_json(self, zarr_format: Literal[3]) -> Literal["uint16"]: ...
+
+    def to_json(
+        self, zarr_format: ZarrFormat
+    ) -> DTypeConfig_V2[Literal[">u2", "<u2"], None] | Literal["uint16"]:
+        """
+        Convert the wrapped data type to a JSON-serializable form.
+
+        Parameters
+        ----------
+        zarr_format : ZarrFormat
+            The zarr format version.
+
+        Returns
+        -------
+        str
+            The JSON-serializable representation of the wrapped data type
+        """
+        if zarr_format == 2:
+            return {"name": self.to_native_dtype().str, "object_codec_id": None}
+        elif zarr_format == 3:
+            return self._zarr_v3_name
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    @property
+    def item_size(self) -> int:
+        return 2
+
+
+@dataclass(frozen=True, kw_only=True)
+class Int32(BaseInt[np.dtypes.Int32DType, np.int32], HasEndianness):
+    dtype_cls = np.dtypes.Int32DType
+    _zarr_v3_name: ClassVar[Literal["int32"]] = "int32"
+    _zarr_v2_names: ClassVar[tuple[Literal[">i4"], Literal["<i4"]]] = (">i4", "<i4")
+
+    @classmethod
+    def from_native_dtype(cls, dtype: TBaseDType) -> Self:
+        if cls._check_native_dtype(dtype):
+            return cls(endianness=get_endianness_from_numpy_dtype(dtype))
+        raise DataTypeValidationError(
+            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
+        )
+
+    def to_native_dtype(self) -> np.dtypes.Int32DType:
+        byte_order = endianness_to_numpy_str(self.endianness)
+        return self.dtype_cls().newbyteorder(byte_order)
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            # Going via numpy ensures that we get the endianness correct without
+            # annoying string parsing.
+            name = data["name"]
+            return cls.from_native_dtype(np.dtype(name))
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected one of the strings {cls._zarr_v2_names}."
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            return cls()
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
+        raise DataTypeValidationError(msg)
+
+    @overload  # type: ignore[override]
+    def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal[">i4", "<i4"], None]: ...
+
+    @overload
+    def to_json(self, zarr_format: Literal[3]) -> Literal["int32"]: ...
+
+    def to_json(
+        self, zarr_format: ZarrFormat
+    ) -> DTypeConfig_V2[Literal[">i4", "<i4"], None] | Literal["int32"]:
+        """
+        Convert the wrapped data type to a JSON-serializable form.
+
+        Parameters
+        ----------
+        zarr_format : ZarrFormat
+            The zarr format version.
+
+        Returns
+        -------
+        str
+            The JSON-serializable representation of the wrapped data type
+        """
+        if zarr_format == 2:
+            return {"name": self.to_native_dtype().str, "object_codec_id": None}
+        elif zarr_format == 3:
+            return self._zarr_v3_name
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    @property
+    def item_size(self) -> int:
+        return 4
+
+
+@dataclass(frozen=True, kw_only=True)
+class UInt32(BaseInt[np.dtypes.UInt32DType, np.uint32], HasEndianness):
+    dtype_cls = np.dtypes.UInt32DType
+    _zarr_v3_name: ClassVar[Literal["uint32"]] = "uint32"
+    _zarr_v2_names: ClassVar[tuple[Literal[">u4"], Literal["<u4"]]] = (">u4", "<u4")
+
+    @classmethod
+    def from_native_dtype(cls, dtype: TBaseDType) -> Self:
+        if cls._check_native_dtype(dtype):
+            return cls(endianness=get_endianness_from_numpy_dtype(dtype))
+        raise DataTypeValidationError(
+            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
+        )
+
+    def to_native_dtype(self) -> np.dtypes.UInt32DType:
+        byte_order = endianness_to_numpy_str(self.endianness)
+        return self.dtype_cls().newbyteorder(byte_order)
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            # Going via numpy ensures that we get the endianness correct without
+            # annoying string parsing.
+            name = data["name"]
+            return cls.from_native_dtype(np.dtype(name))
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected one of the strings {cls._zarr_v2_names}."
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            return cls()
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
+        raise DataTypeValidationError(msg)
+
+    @overload  # type: ignore[override]
+    def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal[">u4", "<u4"], None]: ...
+
+    @overload
+    def to_json(self, zarr_format: Literal[3]) -> Literal["uint32"]: ...
+
+    def to_json(
+        self, zarr_format: ZarrFormat
+    ) -> DTypeConfig_V2[Literal[">u4", "<u4"], None] | Literal["uint32"]:
+        """
+        Convert the wrapped data type to a JSON-serializable form.
+
+        Parameters
+        ----------
+        zarr_format : ZarrFormat
+            The zarr format version.
+
+        Returns
+        -------
+        str
+            The JSON-serializable representation of the wrapped data type
+        """
+        if zarr_format == 2:
+            return {"name": self.to_native_dtype().str, "object_codec_id": None}
+        elif zarr_format == 3:
+            return self._zarr_v3_name
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    @property
+    def item_size(self) -> int:
+        return 4
+
+
+@dataclass(frozen=True, kw_only=True)
+class Int64(BaseInt[np.dtypes.Int64DType, np.int64], HasEndianness):
+    dtype_cls = np.dtypes.Int64DType
+    _zarr_v3_name: ClassVar[Literal["int64"]] = "int64"
+    _zarr_v2_names: ClassVar[tuple[Literal[">i8"], Literal["<i8"]]] = (">i8", "<i8")
+
+    @classmethod
+    def from_native_dtype(cls, dtype: TBaseDType) -> Self:
+        if cls._check_native_dtype(dtype):
+            return cls(endianness=get_endianness_from_numpy_dtype(dtype))
+        raise DataTypeValidationError(
+            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
+        )
+
+    def to_native_dtype(self) -> np.dtypes.Int64DType:
+        byte_order = endianness_to_numpy_str(self.endianness)
+        return self.dtype_cls().newbyteorder(byte_order)
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            # Going via numpy ensures that we get the endianness correct without
+            # annoying string parsing.
+            name = data["name"]
+            return cls.from_native_dtype(np.dtype(name))
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected one of the strings {cls._zarr_v2_names}."
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            return cls()
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
+        raise DataTypeValidationError(msg)
+
+    @overload  # type: ignore[override]
+    def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal[">i8", "<i8"], None]: ...
+
+    @overload
+    def to_json(self, zarr_format: Literal[3]) -> Literal["int64"]: ...
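+    # Illustrative sketch of the serialized forms produced by the implementation
+    # below, assuming the default little-endian byte order:
+    #   Int64().to_json(zarr_format=2) -> {"name": "<i8", "object_codec_id": None}
+    #   Int64().to_json(zarr_format=3) -> "int64"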
+
+    def to_json(
+        self, zarr_format: ZarrFormat
+    ) -> DTypeConfig_V2[Literal[">i8", "<i8"], None] | Literal["int64"]:
+        """
+        Convert the wrapped data type to a JSON-serializable form.
+
+        Parameters
+        ----------
+        zarr_format : ZarrFormat
+            The zarr format version.
+
+        Returns
+        -------
+        str
+            The JSON-serializable representation of the wrapped data type
+        """
+        if zarr_format == 2:
+            return {"name": self.to_native_dtype().str, "object_codec_id": None}
+        elif zarr_format == 3:
+            return self._zarr_v3_name
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    @property
+    def item_size(self) -> int:
+        return 8
+
+
+@dataclass(frozen=True, kw_only=True)
+class UInt64(BaseInt[np.dtypes.UInt64DType, np.uint64], HasEndianness):
+    dtype_cls = np.dtypes.UInt64DType
+    _zarr_v3_name: ClassVar[Literal["uint64"]] = "uint64"
+    _zarr_v2_names: ClassVar[tuple[Literal[">u8"], Literal["<u8"]]] = (">u8", "<u8")
+
+    def to_native_dtype(self) -> np.dtypes.UInt64DType:
+        byte_order = endianness_to_numpy_str(self.endianness)
+        return self.dtype_cls().newbyteorder(byte_order)
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            # Going via numpy ensures that we get the endianness correct without
+            # annoying string parsing.
+            name = data["name"]
+            return cls.from_native_dtype(np.dtype(name))
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected one of the strings {cls._zarr_v2_names}."
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            return cls()
+        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
+        raise DataTypeValidationError(msg)
+
+    @overload  # type: ignore[override]
+    def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal[">u8", "<u8"], None]: ...
+
+    @overload
+    def to_json(self, zarr_format: Literal[3]) -> Literal["uint64"]: ...
+
+    def to_json(
+        self, zarr_format: ZarrFormat
+    ) -> DTypeConfig_V2[Literal[">u8", "<u8"], None] | Literal["uint64"]:
+        """
+        Convert the wrapped data type to a JSON-serializable form.
+
+        Parameters
+        ----------
+        zarr_format : ZarrFormat
+            The zarr format version.
+
+        Returns
+        -------
+        str
+            The JSON-serializable representation of the wrapped data type
+        """
+        if zarr_format == 2:
+            return {"name": self.to_native_dtype().str, "object_codec_id": None}
+        elif zarr_format == 3:
+            return self._zarr_v3_name
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    @classmethod
+    def from_native_dtype(cls, dtype: TBaseDType) -> Self:
+        if cls._check_native_dtype(dtype):
+            return cls(endianness=get_endianness_from_numpy_dtype(dtype))
+        raise DataTypeValidationError(
+            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
+        )
+
+    @property
+    def item_size(self) -> int:
+        return 8
diff --git a/src/zarr/core/dtype/npy/string.py b/src/zarr/core/dtype/npy/string.py
new file mode 100644
index 0000000000..4a1114617a
--- /dev/null
+++ b/src/zarr/core/dtype/npy/string.py
@@ -0,0 +1,302 @@
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    ClassVar,
+    Literal,
+    Protocol,
+    Self,
+    TypedDict,
+    TypeGuard,
+    overload,
+    runtime_checkable,
+)
+
+import numpy as np
+
+from zarr.core.common import NamedConfig
+from zarr.core.dtype.common import (
+    DataTypeValidationError,
+    DTypeConfig_V2,
+    DTypeJSON,
+    HasEndianness,
+    HasItemSize,
+    HasLength,
+    HasObjectCodec,
+    check_dtype_spec_v2,
+    v3_unstable_dtype_warning,
+)
+from zarr.core.dtype.npy.common import (
+    check_json_str,
+    endianness_to_numpy_str,
+    get_endianness_from_numpy_dtype,
+)
+from zarr.core.dtype.wrapper import TDType_co, ZDType
+
+if TYPE_CHECKING:
+    from zarr.core.common import JSON, ZarrFormat
+    from zarr.core.dtype.wrapper import TBaseDType
+
+_NUMPY_SUPPORTS_VLEN_STRING = hasattr(np.dtypes, "StringDType")
+
+
+@runtime_checkable
+class SupportsStr(Protocol):
+    def __str__(self) -> str: ...
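+
+# Illustrative note: SupportsStr is runtime-checkable, so the variable-length UTF-8
+# dtypes below can validate scalars with a plain isinstance check. Nearly every
+# Python object implements __str__ (e.g. isinstance(3.5, SupportsStr) is True), so
+# cast_scalar accepts arbitrary objects and stringifies them.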
+ + +class LengthBytesConfig(TypedDict): + length_bytes: int + + +# TODO: Fix this terrible name +FixedLengthUTF32JSONV3 = NamedConfig[Literal["fixed_length_utf32"], LengthBytesConfig] + + +@dataclass(frozen=True, kw_only=True) +class FixedLengthUTF32( + ZDType[np.dtypes.StrDType[int], np.str_], HasEndianness, HasLength, HasItemSize +): + dtype_cls = np.dtypes.StrDType + _zarr_v3_name: ClassVar[Literal["fixed_length_utf32"]] = "fixed_length_utf32" + code_point_bytes: ClassVar[int] = 4 # utf32 is 4 bytes per code point + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + if cls._check_native_dtype(dtype): + endianness = get_endianness_from_numpy_dtype(dtype) + return cls( + length=dtype.itemsize // (cls.code_point_bytes), + endianness=endianness, + ) + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self) -> np.dtypes.StrDType[int]: + byte_order = endianness_to_numpy_str(self.endianness) + return self.dtype_cls(self.length).newbyteorder(byte_order) + + @classmethod + def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]]: + """ + Check that the input is a valid JSON representation of a numpy U dtype. + """ + return ( + check_dtype_spec_v2(data) + and isinstance(data["name"], str) + and re.match(r"^[><]U\d+$", data["name"]) is not None + and data["object_codec_id"] is None + ) + + @classmethod + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[FixedLengthUTF32JSONV3]: + return ( + isinstance(data, dict) + and set(data.keys()) == {"name", "configuration"} + and data["name"] == cls._zarr_v3_name + and "configuration" in data + and isinstance(data["configuration"], dict) + and set(data["configuration"].keys()) == {"length_bytes"} + and isinstance(data["configuration"]["length_bytes"], int) + ) + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[str, None]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> FixedLengthUTF32JSONV3: ... + + def to_json( + self, zarr_format: ZarrFormat + ) -> DTypeConfig_V2[str, None] | FixedLengthUTF32JSONV3: + if zarr_format == 2: + return {"name": self.to_native_dtype().str, "object_codec_id": None} + elif zarr_format == 3: + v3_unstable_dtype_warning(self) + return { + "name": self._zarr_v3_name, + "configuration": {"length_bytes": self.length * self.code_point_bytes}, + } + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + if cls._check_json_v2(data): + # Construct the numpy dtype instead of string parsing. + name = data["name"] + return cls.from_native_dtype(np.dtype(name)) + raise DataTypeValidationError( + f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a string representation of a numpy U dtype." + ) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + if cls._check_json_v3(data): + return cls(length=data["configuration"]["length_bytes"] // cls.code_point_bytes) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected {cls._zarr_v3_name}." 
+        raise DataTypeValidationError(msg)
+
+    def default_scalar(self) -> np.str_:
+        return np.str_("")
+
+    def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str:
+        return str(data)
+
+    def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.str_:
+        if check_json_str(data):
+            return self.to_native_dtype().type(data)
+        raise TypeError(f"Invalid type: {data}. Expected a string.")  # pragma: no cover
+
+    def _check_scalar(self, data: object) -> TypeGuard[str | np.str_ | bytes | int]:
+        # this is generous for backwards compatibility
+        return isinstance(data, str | np.str_ | bytes | int)
+
+    def cast_scalar(self, data: object) -> np.str_:
+        if self._check_scalar(data):
+            # We explicitly truncate before casting because of the following numpy behavior:
+            # >>> x = np.dtype('U3').type('hello world')
+            # >>> x
+            # np.str_('hello world')
+            # >>> x.dtype
+            # dtype('U11')
+
+            if isinstance(data, int):
+                return self.to_native_dtype().type(str(data)[: self.length])
+            else:
+                return self.to_native_dtype().type(data[: self.length])
+        raise TypeError(
+            f"Cannot convert object with type {type(data)} to a numpy unicode string scalar."
+        )
+
+    @property
+    def item_size(self) -> int:
+        return self.length * self.code_point_bytes
+
+
+def check_vlen_string_json_scalar(data: object) -> TypeGuard[int | str | float]:
+    """
+    This function checks the type of JSON-encoded variable-length strings. It is generous for
+    backwards compatibility, as zarr-python v2 would use ints for variable-length string
+    fill values
+    """
+    return isinstance(data, int | str | float)
+
+
+# VariableLengthUTF8 is defined in two places, conditioned on the version of numpy.
+# If numpy 2 is installed, then VariableLengthUTF8 is defined with the numpy variable length
+# string dtype as the native dtype. Otherwise, VariableLengthUTF8 is defined with the numpy object
+# dtype as the native dtype.
+class UTF8Base(ZDType[TDType_co, str], HasObjectCodec):
+    """
+    A base class for the variable length UTF-8 string data type. This class should not be used
+    as a data type, but as a base class for other variable length string data types.
+    """
+
+    _zarr_v3_name: ClassVar[Literal["variable_length_utf8"]] = "variable_length_utf8"
+    object_codec_id: ClassVar[Literal["vlen-utf8"]] = "vlen-utf8"
+
+    @classmethod
+    def from_native_dtype(cls, dtype: TBaseDType) -> Self:
+        if cls._check_native_dtype(dtype):
+            return cls()
+        raise DataTypeValidationError(
+            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
+        )
+
+    @classmethod
+    def _check_json_v2(
+        cls,
+        data: DTypeJSON,
+    ) -> TypeGuard[DTypeConfig_V2[Literal["|O"], Literal["vlen-utf8"]]]:
+        """
+        Check that the input is a valid JSON representation of a numpy O dtype, and that the
+        object codec id is appropriate for variable-length UTF-8 strings.
+        """
+        return (
+            check_dtype_spec_v2(data)
+            and data["name"] == "|O"
+            and data["object_codec_id"] == cls.object_codec_id
+        )
+
+    @classmethod
+    def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["variable_length_utf8"]]:
+        return data == cls._zarr_v3_name
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            return cls()
+        msg = (
+            f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string '|O'"
+        )
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            return cls()
+        msg = f"Invalid JSON representation of {cls.__name__}. 
Got {data!r}, expected {cls._zarr_v3_name}." + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json( + self, zarr_format: Literal[2] + ) -> DTypeConfig_V2[Literal["|O"], Literal["vlen-utf8"]]: ... + @overload + def to_json(self, zarr_format: Literal[3]) -> Literal["variable_length_utf8"]: ... + + def to_json( + self, zarr_format: ZarrFormat + ) -> DTypeConfig_V2[Literal["|O"], Literal["vlen-utf8"]] | Literal["variable_length_utf8"]: + if zarr_format == 2: + return {"name": "|O", "object_codec_id": self.object_codec_id} + elif zarr_format == 3: + v3_unstable_dtype_warning(self) + return self._zarr_v3_name + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def default_scalar(self) -> str: + return "" + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + raise TypeError(f"Invalid type: {data}. Expected a string.") + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> str: + if not check_vlen_string_json_scalar(data): + raise TypeError(f"Invalid type: {data}. Expected a string or number.") + return str(data) + + def _check_scalar(self, data: object) -> TypeGuard[SupportsStr]: + return isinstance(data, SupportsStr) + + def _cast_scalar_unchecked(self, data: SupportsStr) -> str: + return str(data) + + def cast_scalar(self, data: object) -> str: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + raise TypeError(f"Cannot convert object with type {type(data)} to a python string.") + + +if _NUMPY_SUPPORTS_VLEN_STRING: + + @dataclass(frozen=True, kw_only=True) + class VariableLengthUTF8(UTF8Base[np.dtypes.StringDType]): # type: ignore[type-var] + dtype_cls = np.dtypes.StringDType + + def to_native_dtype(self) -> np.dtypes.StringDType: + return self.dtype_cls() + +else: + # Numpy pre-2 does not have a variable length string dtype, so we use the Object dtype instead. + @dataclass(frozen=True, kw_only=True) + class VariableLengthUTF8(UTF8Base[np.dtypes.ObjectDType]): # type: ignore[no-redef] + dtype_cls = np.dtypes.ObjectDType + + def to_native_dtype(self) -> np.dtypes.ObjectDType: + return self.dtype_cls() diff --git a/src/zarr/core/dtype/npy/structured.py b/src/zarr/core/dtype/npy/structured.py new file mode 100644 index 0000000000..d9e1ff55ae --- /dev/null +++ b/src/zarr/core/dtype/npy/structured.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal, Self, TypeGuard, cast, overload + +import numpy as np + +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeConfig_V2, + DTypeJSON, + DTypeSpec_V3, + HasItemSize, + StructuredName_V2, + check_dtype_spec_v2, + check_structured_dtype_name_v2, + v3_unstable_dtype_warning, +) +from zarr.core.dtype.npy.common import ( + bytes_from_json, + bytes_to_json, + check_json_str, +) +from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + +if TYPE_CHECKING: + from collections.abc import Sequence + + from zarr.core.common import JSON, NamedConfig, ZarrFormat + +StructuredScalarLike = list[object] | tuple[object, ...] | bytes | int + + +@dataclass(frozen=True, kw_only=True) +class Structured(ZDType[np.dtypes.VoidDType[int], np.void], HasItemSize): + dtype_cls = np.dtypes.VoidDType # type: ignore[assignment] + _zarr_v3_name = "structured" + fields: tuple[tuple[str, ZDType[TBaseDType, TBaseScalar]], ...] 
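+    # Illustrative sketch: a native dtype like np.dtype([("x", "<i4"), ("y", "<f8")])
+    # round-trips through this class as
+    #   Structured(fields=(("x", Int32(endianness="little")),
+    #                      ("y", Float64(endianness="little"))))
+    # where each field dtype is itself a ZDType resolved from the registry.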
+ + @classmethod + def _check_native_dtype(cls, dtype: TBaseDType) -> TypeGuard[np.dtypes.VoidDType[int]]: + """ + Check that this dtype is a numpy structured dtype + + Parameters + ---------- + dtype : np.dtypes.DTypeLike + The dtype to check. + + Returns + ------- + TypeGuard[np.dtypes.VoidDType] + True if the dtype matches, False otherwise. + """ + return isinstance(dtype, cls.dtype_cls) and dtype.fields is not None # type: ignore[has-type] + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + from zarr.core.dtype import get_data_type_from_native_dtype + + fields: list[tuple[str, ZDType[TBaseDType, TBaseScalar]]] = [] + if cls._check_native_dtype(dtype): + # fields of a structured numpy dtype are either 2-tuples or 3-tuples. we only + # care about the first element in either case. + for key, (dtype_instance, *_) in dtype.fields.items(): # type: ignore[union-attr] + dtype_wrapped = get_data_type_from_native_dtype(dtype_instance) + fields.append((key, dtype_wrapped)) + + return cls(fields=tuple(fields)) + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" # type: ignore[has-type] + ) + + def to_native_dtype(self) -> np.dtypes.VoidDType[int]: + return cast( + "np.dtypes.VoidDType[int]", + np.dtype([(key, dtype.to_native_dtype()) for (key, dtype) in self.fields]), + ) + + @classmethod + def _check_json_v2( + cls, + data: DTypeJSON, + ) -> TypeGuard[DTypeConfig_V2[StructuredName_V2, None]]: + return ( + check_dtype_spec_v2(data) + and not isinstance(data["name"], str) + and check_structured_dtype_name_v2(data["name"]) + and data["object_codec_id"] is None + ) + + @classmethod + def _check_json_v3( + cls, data: DTypeJSON + ) -> TypeGuard[NamedConfig[Literal["structured"], dict[str, Sequence[tuple[str, DTypeJSON]]]]]: + return ( + isinstance(data, dict) + and set(data.keys()) == {"name", "configuration"} + and data["name"] == cls._zarr_v3_name + and isinstance(data["configuration"], dict) + and set(data["configuration"].keys()) == {"fields"} + ) + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + # avoid circular import + from zarr.core.dtype import get_data_type_from_json + + if cls._check_json_v2(data): + # structured dtypes are constructed directly from a list of lists + # note that we do not handle the object codec here! this will prevent structured + # dtypes from containing object dtypes. + return cls( + fields=tuple( # type: ignore[misc] + ( # type: ignore[misc] + f_name, + get_data_type_from_json( + {"name": f_dtype, "object_codec_id": None}, zarr_format=2 + ), + ) + for f_name, f_dtype in data["name"] + ) + ) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a JSON array of arrays" + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + # avoid circular import + from zarr.core.dtype import get_data_type_from_json + + if cls._check_json_v3(data): + config = data["configuration"] + meta_fields = config["fields"] + return cls( + fields=tuple( + (f_name, get_data_type_from_json(f_dtype, zarr_format=3)) + for f_name, f_dtype in meta_fields + ) + ) + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a JSON object with the key {cls._zarr_v3_name!r}" + raise DataTypeValidationError(msg) + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[StructuredName_V2, None]: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> DTypeSpec_V3: ... 
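+    # Sketch of the two serialized forms produced by the implementation below, for
+    # the example fields (("x", Int32()), ("y", Float64())):
+    #   zarr v2: {"name": [["x", "<i4"], ["y", "<f8"]], "object_codec_id": None}
+    #   zarr v3: {"name": "structured",
+    #            "configuration": {"fields": [["x", "int32"], ["y", "float64"]]}}
+    # Field dtypes serialize recursively via their own to_json methods.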
+ + def to_json( + self, zarr_format: ZarrFormat + ) -> DTypeConfig_V2[StructuredName_V2, None] | DTypeSpec_V3: + if zarr_format == 2: + fields = [ + [f_name, f_dtype.to_json(zarr_format=zarr_format)["name"]] + for f_name, f_dtype in self.fields + ] + return {"name": fields, "object_codec_id": None} + elif zarr_format == 3: + v3_unstable_dtype_warning(self) + fields = [ + [f_name, f_dtype.to_json(zarr_format=zarr_format)] # type: ignore[list-item] + for f_name, f_dtype in self.fields + ] + base_dict = {"name": self._zarr_v3_name} + base_dict["configuration"] = {"fields": fields} # type: ignore[assignment] + return cast("DTypeSpec_V3", base_dict) + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def _check_scalar(self, data: object) -> TypeGuard[StructuredScalarLike]: + # TODO: implement something more precise here! + return isinstance(data, (bytes, list, tuple, int, np.void)) + + def _cast_scalar_unchecked(self, data: StructuredScalarLike) -> np.void: + na_dtype = self.to_native_dtype() + if isinstance(data, bytes): + res = np.frombuffer(data, dtype=na_dtype)[0] + elif isinstance(data, list | tuple): + res = np.array([tuple(data)], dtype=na_dtype)[0] + else: + res = np.array([data], dtype=na_dtype)[0] + return cast("np.void", res) + + def cast_scalar(self, data: object) -> np.void: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = f"Cannot convert object with type {type(data)} to a numpy structured scalar." + raise TypeError(msg) + + def default_scalar(self) -> np.void: + return self._cast_scalar_unchecked(0) + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void: + if check_json_str(data): + as_bytes = bytes_from_json(data, zarr_format=zarr_format) + dtype = self.to_native_dtype() + return cast("np.void", np.array([as_bytes]).view(dtype)[0]) + raise TypeError(f"Invalid type: {data}. Expected a string.") + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str: + return bytes_to_json(self.cast_scalar(data).tobytes(), zarr_format) + + @property + def item_size(self) -> int: + # Lets have numpy do the arithmetic here + return self.to_native_dtype().itemsize diff --git a/src/zarr/core/dtype/npy/time.py b/src/zarr/core/dtype/npy/time.py new file mode 100644 index 0000000000..1f9080475c --- /dev/null +++ b/src/zarr/core/dtype/npy/time.py @@ -0,0 +1,359 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import ( + TYPE_CHECKING, + ClassVar, + Literal, + Self, + TypedDict, + TypeGuard, + TypeVar, + cast, + get_args, + overload, +) + +import numpy as np + +from zarr.core.common import NamedConfig +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeConfig_V2, + DTypeJSON, + HasEndianness, + HasItemSize, + check_dtype_spec_v2, +) +from zarr.core.dtype.npy.common import ( + DATETIME_UNIT, + DateTimeUnit, + check_json_int, + endianness_to_numpy_str, + get_endianness_from_numpy_dtype, +) +from zarr.core.dtype.wrapper import TBaseDType, ZDType + +if TYPE_CHECKING: + from zarr.core.common import JSON, ZarrFormat + +_DTypeName = Literal["datetime64", "timedelta64"] +TimeDeltaLike = str | int | bytes | np.timedelta64 | timedelta | None +DateTimeLike = str | int | bytes | np.datetime64 | datetime | None + + +def datetime_from_int(data: int, *, unit: DateTimeUnit, scale_factor: int) -> np.datetime64: + """ + Convert an integer to a datetime64. 
+ + Parameters + ---------- + data : int + The integer to convert. + unit : DateTimeUnit + The unit of the datetime64. + scale_factor : int + The scale factor of the datetime64. + + Returns + ------- + np.datetime64 + The datetime64 value. + """ + dtype_name = f"datetime64[{scale_factor}{unit}]" + return cast("np.datetime64", np.int64(data).view(dtype_name)) + + +def datetimelike_to_int(data: np.datetime64 | np.timedelta64) -> int: + """ + Convert a datetime64 or a timedelta64 to an integer. + + Parameters + ---------- + data : np.datetime64 | np.timedelta64 + The value to convert. + + Returns + ------- + int + An integer representation of the scalar. + """ + return data.view(np.int64).item() + + +def check_json_time(data: JSON) -> TypeGuard[Literal["NaT"] | int]: + """ + Type guard to check if the input JSON data is the literal string "NaT" + or an integer. + """ + return check_json_int(data) or data == "NaT" + + +BaseTimeDType_co = TypeVar( + "BaseTimeDType_co", + bound=np.dtypes.TimeDelta64DType | np.dtypes.DateTime64DType, + covariant=True, +) +BaseTimeScalar_co = TypeVar( + "BaseTimeScalar_co", bound=np.timedelta64 | np.datetime64, covariant=True +) + + +class TimeConfig(TypedDict): + unit: DateTimeUnit + scale_factor: int + + +DateTime64JSONV3 = NamedConfig[Literal["numpy.datetime64"], TimeConfig] +TimeDelta64JSONV3 = NamedConfig[Literal["numpy.timedelta64"], TimeConfig] + + +@dataclass(frozen=True, kw_only=True, slots=True) +class TimeDTypeBase(ZDType[BaseTimeDType_co, BaseTimeScalar_co], HasEndianness, HasItemSize): + _zarr_v2_names: ClassVar[tuple[str, ...]] + # this attribute exists so that we can programmatically create a numpy dtype instance + # because the particular numpy dtype we are wrapping does not allow direct construction via + # cls.dtype_cls() + _numpy_name: ClassVar[_DTypeName] + scale_factor: int + unit: DateTimeUnit + + def __post_init__(self) -> None: + if self.scale_factor < 1: + raise ValueError(f"scale_factor must be > 0, got {self.scale_factor}.") + if self.scale_factor >= 2**31: + raise ValueError(f"scale_factor must be < 2147483648, got {self.scale_factor}.") + if self.unit not in get_args(DateTimeUnit): + raise ValueError(f"unit must be one of {get_args(DateTimeUnit)}, got {self.unit!r}.") + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + if cls._check_native_dtype(dtype): + unit, scale_factor = np.datetime_data(dtype.name) + unit = cast("DateTimeUnit", unit) + return cls( + unit=unit, + scale_factor=scale_factor, + endianness=get_endianness_from_numpy_dtype(dtype), + ) + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self) -> BaseTimeDType_co: + # Numpy does not allow creating datetime64 or timedelta64 via + # np.dtypes.{dtype_name}() + # so we use np.dtype with a formatted string. + dtype_string = f"{self._numpy_name}[{self.scale_factor}{self.unit}]" + return np.dtype(dtype_string).newbyteorder(endianness_to_numpy_str(self.endianness)) # type: ignore[return-value] + + @overload # type: ignore[override] + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[str, None]: ... + @overload + def to_json(self, zarr_format: Literal[3]) -> DateTime64JSONV3 | TimeDelta64JSONV3: ... 
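+    # For example, np.dtype(">M8[10s]") wraps as scale_factor=10, unit="s", and the
+    # implementation below serializes it as (sketch):
+    #   zarr v2: {"name": ">M8[10s]", "object_codec_id": None}
+    #   zarr v3: {"name": "numpy.datetime64",
+    #            "configuration": {"unit": "s", "scale_factor": 10}}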
+
+    def to_json(
+        self, zarr_format: ZarrFormat
+    ) -> DTypeConfig_V2[str, None] | DateTime64JSONV3 | TimeDelta64JSONV3:
+        if zarr_format == 2:
+            name = self.to_native_dtype().str
+            return {"name": name, "object_codec_id": None}
+        elif zarr_format == 3:
+            return cast(
+                "DateTime64JSONV3 | TimeDelta64JSONV3",
+                {
+                    "name": self._zarr_v3_name,
+                    "configuration": {"unit": self.unit, "scale_factor": self.scale_factor},
+                },
+            )
+        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover
+
+    def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> int:
+        return datetimelike_to_int(data)  # type: ignore[arg-type]
+
+    @property
+    def item_size(self) -> int:
+        return 8
+
+
+@dataclass(frozen=True, kw_only=True, slots=True)
+class TimeDelta64(TimeDTypeBase[np.dtypes.TimeDelta64DType, np.timedelta64], HasEndianness):
+    """
+    A wrapper for the ``TimeDelta64`` data type defined in numpy.
+    Scalars of this type can be created by performing arithmetic with ``DateTime64`` scalars.
+    Like ``DateTime64``, ``TimeDelta64`` is parametrized by a unit, but unlike ``DateTime64``, the
+    unit for ``TimeDelta64`` is optional.
+    """
+
+    # mypy infers the type of np.dtypes.TimeDelta64DType to be
+    # "Callable[[Literal['Y', 'M', 'W', 'D'] | Literal['h', 'm', 's', 'ms', 'us', 'ns', 'ps', 'fs', 'as']], Never]"
+    dtype_cls = np.dtypes.TimeDelta64DType  # type: ignore[assignment]
+    _zarr_v3_name: ClassVar[Literal["numpy.timedelta64"]] = "numpy.timedelta64"
+    _zarr_v2_names = (">m8", "<m8")
+
+    @classmethod
+    def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]]:
+        if not check_dtype_spec_v2(data):
+            return False
+        name = data["name"]
+        # match strings like ">m8[M]" or "<m8[10s]"
+        # consider making this a standalone function
+        if not isinstance(name, str):
+            return False
+        if not name.startswith(cls._zarr_v2_names):
+            return False
+        if len(name) == 3:
+            # no unit, and
+            # we already checked that this string is either <m8 or >m8
+            return True
+        else:
+            return name[4:-1].endswith(DATETIME_UNIT) and name[-1] == "]"
+
+    @classmethod
+    def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[TimeDelta64JSONV3]:
+        return (
+            isinstance(data, dict)
+            and set(data.keys()) == {"name", "configuration"}
+            and data["name"] == cls._zarr_v3_name
+            and isinstance(data["configuration"], dict)
+            and set(data["configuration"].keys()) == {"unit", "scale_factor"}
+        )
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            name = data["name"]
+            return cls.from_native_dtype(np.dtype(name))
+        msg = (
+            f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a string "
+            f"representation of an instance of {cls.dtype_cls}"  # type: ignore[has-type]
+        )
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            unit = data["configuration"]["unit"]
+            scale_factor = data["configuration"]["scale_factor"]
+            return cls(unit=unit, scale_factor=scale_factor)
+        msg = (
+            f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a dict "
+            f"with a 'name' key with the value 'numpy.timedelta64', "
+            "and a 'configuration' key with a value of a dict with a 'unit' key and a "
+            "'scale_factor' key"
+        )
+        raise DataTypeValidationError(msg)
+
+    def _check_scalar(self, data: object) -> TypeGuard[TimeDeltaLike]:
+        if data is None:
+            return True
+        return isinstance(data, str | int | bytes | np.timedelta64 | timedelta)
+
+    def _cast_scalar_unchecked(self, data: TimeDeltaLike) -> np.timedelta64:
+        return self.to_native_dtype().type(data, f"{self.scale_factor}{self.unit}")
+
+    def cast_scalar(self, data: object) -> np.timedelta64:
+        if self._check_scalar(data):
+            return self._cast_scalar_unchecked(data)
+        msg = f"Cannot convert object with type {type(data)} to a numpy timedelta64 scalar."
+        raise TypeError(msg)
+
+    def default_scalar(self) -> np.timedelta64:
+        return np.timedelta64("NaT")
+
+    def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.timedelta64:
+        if check_json_time(data):
+            return self.to_native_dtype().type(data, f"{self.scale_factor}{self.unit}")
+        raise TypeError(f"Invalid type: {data}. Expected an integer.")  # pragma: no cover
+
+
+@dataclass(frozen=True, kw_only=True, slots=True)
+class DateTime64(TimeDTypeBase[np.dtypes.DateTime64DType, np.datetime64], HasEndianness):
+    dtype_cls = np.dtypes.DateTime64DType  # type: ignore[assignment]
+    _zarr_v3_name: ClassVar[Literal["numpy.datetime64"]] = "numpy.datetime64"
+    _zarr_v2_names = (">M8", "<M8")
+
+    @classmethod
+    def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]]:
+        """
+        Check that JSON input is a string representation of a NumPy datetime64 data type,
+        like ">M8[10s]". This function can be used as a type guard to narrow the type of
+        unknown JSON input.
+        """
+        if not check_dtype_spec_v2(data):
+            return False
+        name = data["name"]
+        if not isinstance(name, str):
+            return False
+        if not name.startswith(cls._zarr_v2_names):
+            return False
+        if len(name) == 3:
+            # no unit, and
+            # we already checked that this string is either <M8 or >M8
+            return True
+        else:
+            return name[4:-1].endswith(DATETIME_UNIT) and name[-1] == "]"
+
+    @classmethod
+    def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[DateTime64JSONV3]:
+        return (
+            isinstance(data, dict)
+            and set(data.keys()) == {"name", "configuration"}
+            and data["name"] == cls._zarr_v3_name
+            and isinstance(data["configuration"], dict)
+            and set(data["configuration"].keys()) == {"unit", "scale_factor"}
+        )
+
+    @classmethod
+    def _from_json_v2(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v2(data):
+            name = data["name"]
+            return cls.from_native_dtype(np.dtype(name))
+        msg = (
+            f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a string "
+            f"representation of an instance of {cls.dtype_cls}"  # type: ignore[has-type]
+        )
+        raise DataTypeValidationError(msg)
+
+    @classmethod
+    def _from_json_v3(cls, data: DTypeJSON) -> Self:
+        if cls._check_json_v3(data):
+            unit = data["configuration"]["unit"]
+            scale_factor = data["configuration"]["scale_factor"]
+            return cls(unit=unit, scale_factor=scale_factor)
+        msg = (
+            f"Invalid JSON representation of {cls.__name__}. 
Got {data!r}, expected a dict " + f"with a 'name' key with the value 'numpy.datetime64', " + "and a 'configuration' key with a value of a dict with a 'unit' key and a " + "'scale_factor' key" + ) + raise DataTypeValidationError(msg) + + def _check_scalar(self, data: object) -> TypeGuard[DateTimeLike]: + if data is None: + return True + return isinstance(data, str | int | bytes | np.datetime64 | datetime) + + def _cast_scalar_unchecked(self, data: DateTimeLike) -> np.datetime64: + return self.to_native_dtype().type(data, f"{self.scale_factor}{self.unit}") + + def cast_scalar(self, data: object) -> np.datetime64: + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = f"Cannot convert object with type {type(data)} to a numpy datetime scalar." + raise TypeError(msg) + + def default_scalar(self) -> np.datetime64: + return np.datetime64("NaT") + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.datetime64: + if check_json_time(data): + return self._cast_scalar_unchecked(data) + raise TypeError(f"Invalid type: {data}. Expected an integer.") # pragma: no cover diff --git a/src/zarr/core/dtype/registry.py b/src/zarr/core/dtype/registry.py new file mode 100644 index 0000000000..1d2a97a90a --- /dev/null +++ b/src/zarr/core/dtype/registry.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Self + +import numpy as np + +from zarr.core.dtype.common import ( + DataTypeValidationError, + DTypeJSON, +) + +if TYPE_CHECKING: + from importlib.metadata import EntryPoint + + from zarr.core.common import ZarrFormat + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + + +# This class is different from the other registry classes, which inherit from +# dict. IMO it's simpler to just do a dataclass. But long-term we should +# have just 1 registry class in use. +@dataclass(frozen=True, kw_only=True) +class DataTypeRegistry: + contents: dict[str, type[ZDType[TBaseDType, TBaseScalar]]] = field( + default_factory=dict, init=False + ) + + lazy_load_list: list[EntryPoint] = field(default_factory=list, init=False) + + def lazy_load(self) -> None: + for e in self.lazy_load_list: + self.register(e.load()._zarr_v3_name, e.load()) + + self.lazy_load_list.clear() + + def register(self: Self, key: str, cls: type[ZDType[TBaseDType, TBaseScalar]]) -> None: + # don't register the same dtype twice + if key not in self.contents or self.contents[key] != cls: + self.contents[key] = cls + + def unregister(self, key: str) -> None: + """Unregister a data type by its key.""" + if key in self.contents: + del self.contents[key] + else: + raise KeyError(f"Data type '{key}' not found in registry.") + + def get(self, key: str) -> type[ZDType[TBaseDType, TBaseScalar]]: + return self.contents[key] + + def match_dtype(self, dtype: TBaseDType) -> ZDType[TBaseDType, TBaseScalar]: + if dtype == np.dtype("O"): + msg = ( + f"Zarr data type resolution from {dtype} failed. " + 'Attempted to resolve a zarr data type from a numpy "Object" data type, which is ' + 'ambiguous, as multiple zarr data types can be represented by the numpy "Object" ' + "data type. " + "In this case you should construct your array by providing a specific Zarr data " + 'type. 
For a list of Zarr data types that are compatible with the numpy "Object"'
+                " data type, see https://github.com/zarr-developers/zarr-python/issues/3117"
+            )
+            raise ValueError(msg)
+        matched: list[ZDType[TBaseDType, TBaseScalar]] = []
+        for val in self.contents.values():
+            with contextlib.suppress(DataTypeValidationError):
+                matched.append(val.from_native_dtype(dtype))
+        if len(matched) == 1:
+            return matched[0]
+        elif len(matched) > 1:
+            msg = (
+                f"Zarr data type resolution from {dtype} failed. "
+                f"Multiple data type wrappers found that match dtype '{dtype}': {matched}. "
+                "You should unregister one of these data types, or avoid Zarr data type inference "
+                "entirely by providing a specific Zarr data type when creating your array. "
+                "For more information, see https://github.com/zarr-developers/zarr-python/issues/3117"
+            )
+            raise ValueError(msg)
+        raise ValueError(f"No Zarr data type found that matches dtype '{dtype!r}'")
+
+    def match_json(
+        self, data: DTypeJSON, *, zarr_format: ZarrFormat
+    ) -> ZDType[TBaseDType, TBaseScalar]:
+        for val in self.contents.values():
+            try:
+                return val.from_json(data, zarr_format=zarr_format)
+            except DataTypeValidationError:
+                pass
+        raise ValueError(f"No Zarr data type found that matches {data!r}")
diff --git a/src/zarr/core/dtype/wrapper.py b/src/zarr/core/dtype/wrapper.py
new file mode 100644
index 0000000000..e974712e38
--- /dev/null
+++ b/src/zarr/core/dtype/wrapper.py
@@ -0,0 +1,297 @@
+"""
+Wrapper for native array data types.
+
+The ``ZDType`` class is an abstract base class for wrapping native array data types, e.g. NumPy dtypes.
+``ZDType`` provides a common interface for working with data types in a way that is independent of the
+underlying data type system.
+
+The wrapper class encapsulates a native data type. Instances of the class can be created from a
+native data type instance, and a native data type instance can be created from an instance of the
+wrapper class.
+
+The wrapper class is responsible for:
+- Serializing and deserializing a native data type to Zarr V2 or Zarr V3 metadata.
+  This ensures that the data type can be properly stored and retrieved from array metadata.
+- Serializing and deserializing scalar values to Zarr V2 or Zarr V3 metadata. This is important for
+  storing a fill value for an array in a manner that is valid for the data type.
+
+You can add support for a new data type in Zarr by subclassing the ``ZDType`` wrapper class and
+adapting its methods to support your native data type. The wrapper class must be added to a data
+type registry (defined elsewhere) before array creation routines or array reading routines can use
+your new data type.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    ClassVar,
+    Generic,
+    Literal,
+    Self,
+    TypeGuard,
+    TypeVar,
+    overload,
+)
+
+import numpy as np
+
+if TYPE_CHECKING:
+    from zarr.core.common import JSON, ZarrFormat
+    from zarr.core.dtype.common import DTypeJSON, DTypeSpec_V2, DTypeSpec_V3
+
+# This is the upper bound for the scalar types we support. It's numpy scalars, plus str and bytes,
+# because the new variable-length string dtype in numpy does not have a corresponding scalar type
+TBaseScalar = np.generic | str | bytes
+# This is the bound for the dtypes that we support. If we support non-numpy dtypes,
+# then this bound will need to be widened.
+TBaseDType = np.dtype[np.generic] + +# These two type parameters are covariant because we want +# x : ZDType[BaseDType, BaseScalar] = ZDType[SubDType, SubScalar] +# to type check +TScalar_co = TypeVar("TScalar_co", bound=TBaseScalar, covariant=True) +TDType_co = TypeVar("TDType_co", bound=TBaseDType, covariant=True) + + +@dataclass(frozen=True, kw_only=True, slots=True) +class ZDType(Generic[TDType_co, TScalar_co], ABC): + """ + Abstract base class for wrapping native array data types, e.g. numpy dtypes + + Attributes + ---------- + dtype_cls : ClassVar[type[TDType]] + The wrapped dtype class. This is a class variable. + _zarr_v3_name : ClassVar[str] + The name given to the data type by a Zarr v3 data type specification. This is a + class variable, and it should generally be unique across different data types. + """ + + # this class will create a native data type + # mypy currently disallows class variables to contain type parameters + # but it seems OK for us to use it here: + # https://github.com/python/typing/discussions/1424#discussioncomment-7989934 + dtype_cls: ClassVar[type[TDType_co]] # type: ignore[misc] + _zarr_v3_name: ClassVar[str] + + @classmethod + def _check_native_dtype(cls: type[Self], dtype: TBaseDType) -> TypeGuard[TDType_co]: + """ + Check that a native data type matches the dtype_cls class attribute. Used as a type guard. + + Parameters + ---------- + dtype : TDType + The dtype to check. + + Returns + ------- + Bool + True if the dtype matches, False otherwise. + """ + return type(dtype) is cls.dtype_cls + + @classmethod + @abstractmethod + def from_native_dtype(cls: type[Self], dtype: TBaseDType) -> Self: + """ + Create a ZDType instance from a native data type. The default implementation first performs + a type check via ``cls._check_native_dtype``. If that type check succeeds, the ZDType class + instance is created. + + This method is used when taking a user-provided native data type, like a NumPy data type, + and creating the corresponding ZDType instance from them. + + Parameters + ---------- + dtype : TDType + The native data type object to wrap. + + Returns + ------- + Self + The ZDType that wraps the native data type. + + Raises + ------ + TypeError + If the native data type is not consistent with the wrapped data type. + """ + ... + + @abstractmethod + def to_native_dtype(self: Self) -> TDType_co: + """ + Return an instance of the wrapped data type. This operation inverts ``from_native_dtype``. + + Returns + ------- + TDType + The native data type wrapped by this ZDType. + """ + ... + + @classmethod + @abstractmethod + def _from_json_v2(cls: type[Self], data: DTypeJSON) -> Self: ... + + @classmethod + @abstractmethod + def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self: ... + + @classmethod + def from_json(cls: type[Self], data: DTypeJSON, *, zarr_format: ZarrFormat) -> Self: + """ + Create an instance of this ZDType from JSON data. + + Parameters + ---------- + data : DTypeJSON + The JSON representation of the data type. The type annotation includes + Mapping[str, object] to accommodate typed dictionaries. + + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + Self + The wrapped data type. + """ + if zarr_format == 2: + return cls._from_json_v2(data) + if zarr_format == 3: + return cls._from_json_v3(data) + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + @overload + def to_json(self, zarr_format: Literal[2]) -> DTypeSpec_V2: ... 
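+    # Round-trip sketch with a concrete subclass from this change:
+    #   Int8().to_json(zarr_format=3) returns "int8", and
+    #   Int8.from_json("int8", zarr_format=3) reconstructs Int8();
+    # invalid input raises DataTypeValidationError via the _from_json_* hooks.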
+
+    @overload
+    def to_json(self, zarr_format: Literal[3]) -> DTypeSpec_V3: ...
+
+    @abstractmethod
+    def to_json(self, zarr_format: ZarrFormat) -> DTypeSpec_V2 | DTypeSpec_V3:
+        """
+        Serialize this ZDType to JSON.
+
+        Parameters
+        ----------
+        zarr_format : ZarrFormat
+            The zarr format version.
+
+        Returns
+        -------
+        DTypeSpec_V2 | DTypeSpec_V3
+            The JSON-serializable representation of the wrapped data type
+        """
+        ...
+
+    @abstractmethod
+    def _check_scalar(self, data: object) -> bool:
+        """
+        Check that a Python object is a valid scalar value for the wrapped data type.
+
+        Parameters
+        ----------
+        data : object
+            A value to check.
+
+        Returns
+        -------
+        Bool
+            True if the object is valid, False otherwise.
+        """
+        ...
+
+    @abstractmethod
+    def cast_scalar(self, data: object) -> TScalar_co:
+        """
+        Cast a python object to the wrapped scalar type.
+        The type of the provided scalar is first checked for compatibility.
+        If it's incompatible with the associated scalar type, a ``TypeError`` will be raised.
+
+        Parameters
+        ----------
+        data : object
+            The python object to cast.
+
+        Returns
+        -------
+        TScalar
+            The cast value.
+        """
+
+    @abstractmethod
+    def default_scalar(self) -> TScalar_co:
+        """
+        Get the default scalar value for the wrapped data type. This is a method, rather than an
+        attribute, because the default value for some data types depends on parameters that are
+        not known until a concrete data type is wrapped. For example, data types parametrized by a
+        length like fixed-length strings or bytes will generate scalars consistent with that length.
+
+        Returns
+        -------
+        TScalar
+            The default value for this data type.
+        """
+        ...
+
+    @abstractmethod
+    def from_json_scalar(self: Self, data: JSON, *, zarr_format: ZarrFormat) -> TScalar_co:
+        """
+        Read a JSON-serializable value as a scalar.
+
+        Parameters
+        ----------
+        data : JSON
+            A JSON representation of a scalar value.
+        zarr_format : ZarrFormat
+            The zarr format version. This is specified because the JSON serialization of scalars
+            differs between Zarr V2 and Zarr V3.
+
+        Returns
+        -------
+        TScalar
+            The deserialized scalar value.
+        """
+        ...
+
+    @abstractmethod
+    def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> JSON:
+        """
+        Serialize a python object to the JSON representation of a scalar. The value will first be
+        cast to the scalar type associated with this ZDType, then serialized to JSON.
+
+        Parameters
+        ----------
+        data : object
+            The value to convert.
+        zarr_format : ZarrFormat
+            The zarr format version. This is specified because the JSON serialization of scalars
+            differs between Zarr V2 and Zarr V3.
+
+        Returns
+        -------
+        JSON
+            The JSON-serialized scalar.
+        """
+        ...
+
+
+def scalar_failed_type_check_msg(
+    cls_instance: ZDType[TBaseDType, TBaseScalar], bad_scalar: object
+) -> str:
+    """
+    Generate an error message reporting that a particular value failed a type check when attempting
+    to cast that value to a scalar.
+    """
+    return (
+        f"The value {bad_scalar!r} failed a type check. "
+        f"It cannot be safely cast to a scalar compatible with {cls_instance}. "
+        f"Consult the documentation for {cls_instance} to determine the possible values that can "
+        "be cast to scalars of the wrapped data type."
+ ) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 3ce46ec97b..b50bce3aef 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import base64 import itertools import json import logging @@ -50,7 +49,6 @@ ) from zarr.core.config import config from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata -from zarr.core.metadata.v3 import V3JsonEncoder, _replace_special_floats from zarr.core.sync import SyncMixin, sync from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError from zarr.storage import StoreLike, StorePath @@ -337,7 +335,7 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: if self.zarr_format == 3: return { ZARR_JSON: prototype.buffer.from_bytes( - json.dumps(_replace_special_floats(self.to_dict()), cls=V3JsonEncoder).encode() + json.dumps(self.to_dict(), indent=json_indent, allow_nan=False).encode() ) } else: @@ -346,7 +344,7 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: json.dumps({"zarr_format": self.zarr_format}, indent=json_indent).encode() ), ZATTRS_JSON: prototype.buffer.from_bytes( - json.dumps(self.attributes, indent=json_indent).encode() + json.dumps(self.attributes, indent=json_indent, allow_nan=False).encode() ), } if self.consolidated_metadata: @@ -357,16 +355,10 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: consolidated_metadata = self.consolidated_metadata.to_dict()["metadata"] assert isinstance(consolidated_metadata, dict) for k, v in consolidated_metadata.items(): - attrs = v.pop("attributes", None) - d[f"{k}/{ZATTRS_JSON}"] = _replace_special_floats(attrs) + attrs = v.pop("attributes", {}) + d[f"{k}/{ZATTRS_JSON}"] = attrs if "shape" in v: # it's an array - if isinstance(v.get("fill_value", None), np.void): - v["fill_value"] = base64.standard_b64encode( - cast("bytes", v["fill_value"]) - ).decode("ascii") - else: - v = _replace_special_floats(v) d[f"{k}/{ZARRAY_JSON}"] = v else: d[f"{k}/{ZGROUP_JSON}"] = { @@ -380,8 +372,7 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: items[ZMETADATA_V2_JSON] = prototype.buffer.from_bytes( json.dumps( - {"metadata": d, "zarr_consolidated_format": 1}, - cls=V3JsonEncoder, + {"metadata": d, "zarr_consolidated_format": 1}, allow_nan=False ).encode() ) @@ -631,6 +622,7 @@ def _from_bytes_v2( consolidated_metadata[path].update(v) else: raise ValueError(f"Invalid file type '{kind}' at path '{path}") + group_metadata["consolidated_metadata"] = { "metadata": dict(consolidated_metadata), "kind": "inline", diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index a8f4f4abb4..3ac75e0418 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -1,15 +1,16 @@ from __future__ import annotations -import base64 import warnings from collections.abc import Iterable, Sequence -from enum import Enum from functools import cached_property from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast import numcodecs.abc from zarr.abc.metadata import Metadata +from zarr.core.chunk_grids import RegularChunkGrid +from zarr.core.dtype import get_data_type_from_json +from zarr.core.dtype.common import OBJECT_CODEC_IDS, DTypeSpec_V2 if TYPE_CHECKING: from typing import Literal, Self @@ -18,18 +19,29 @@ from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.common import ChunkCoords + from zarr.core.dtype.wrapper import ( + TBaseDType, + 
TBaseScalar, + TDType_co, + TScalar_co, + ZDType, + ) import json -import numbers from dataclasses import dataclass, field, fields, replace import numcodecs import numpy as np from zarr.core.array_spec import ArrayConfig, ArraySpec -from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.chunk_key_encodings import parse_separator -from zarr.core.common import JSON, ZARRAY_JSON, ZATTRS_JSON, MemoryOrder, parse_shapelike +from zarr.core.common import ( + JSON, + ZARRAY_JSON, + ZATTRS_JSON, + MemoryOrder, + parse_shapelike, +) from zarr.core.config import config, parse_indexing_order from zarr.core.metadata.common import parse_attributes @@ -51,8 +63,8 @@ class ArrayV2MetadataDict(TypedDict): class ArrayV2Metadata(Metadata): shape: ChunkCoords chunks: ChunkCoords - dtype: np.dtype[Any] - fill_value: int | float | str | bytes | None = 0 + dtype: ZDType[TBaseDType, TBaseScalar] + fill_value: int | float | str | bytes | None = None order: MemoryOrder = "C" filters: tuple[numcodecs.abc.Codec, ...] | None = None dimension_separator: Literal[".", "/"] = "." @@ -64,7 +76,7 @@ def __init__( self, *, shape: ChunkCoords, - dtype: npt.DTypeLike, + dtype: ZDType[TDType_co, TScalar_co], chunks: ChunkCoords, fill_value: Any, order: MemoryOrder, @@ -77,18 +89,20 @@ def __init__( Metadata for a Zarr format 2 array. """ shape_parsed = parse_shapelike(shape) - dtype_parsed = parse_dtype(dtype) chunks_parsed = parse_shapelike(chunks) - compressor_parsed = parse_compressor(compressor) order_parsed = parse_indexing_order(order) dimension_separator_parsed = parse_separator(dimension_separator) filters_parsed = parse_filters(filters) - fill_value_parsed = parse_fill_value(fill_value, dtype=dtype_parsed) + fill_value_parsed: TBaseScalar | None + if fill_value is not None: + fill_value_parsed = dtype.cast_scalar(fill_value) + else: + fill_value_parsed = fill_value attributes_parsed = parse_attributes(attributes) object.__setattr__(self, "shape", shape_parsed) - object.__setattr__(self, "dtype", dtype_parsed) + object.__setattr__(self, "dtype", dtype) object.__setattr__(self, "chunks", chunks_parsed) object.__setattr__(self, "compressor", compressor_parsed) object.__setattr__(self, "order", order_parsed) @@ -113,52 +127,12 @@ def shards(self) -> ChunkCoords | None: return None def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: - def _json_convert( - o: Any, - ) -> Any: - if isinstance(o, np.dtype): - if o.fields is None: - return o.str - else: - return o.descr - if isinstance(o, numcodecs.abc.Codec): - codec_config = o.get_config() - - # Hotfix for https://github.com/zarr-developers/zarr-python/issues/2647 - if codec_config["id"] == "zstd" and not codec_config.get("checksum", False): - codec_config.pop("checksum", None) - - return codec_config - if np.isscalar(o): - out: Any - if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): - # https://github.com/zarr-developers/zarr-python/issues/2119 - # `.item()` on a datetime type might or might not return an - # integer, depending on the value. 
-                # Explicitly cast to an int first, and then grab .item()
-                out = o.view("i8").item()
-            else:
-                # convert numpy scalar to python type, and pass
-                # python types through
-                out = getattr(o, "item", lambda: o)()
-                if isinstance(out, complex):
-                    # python complex types are not JSON serializable, so we use the
-                    # serialization defined in the zarr v3 spec
-                    return [out.real, out.imag]
-                return out
-            if isinstance(o, Enum):
-                return o.name
-            raise TypeError
-
         zarray_dict = self.to_dict()
-        zarray_dict["fill_value"] = _serialize_fill_value(self.fill_value, self.dtype)
         zattrs_dict = zarray_dict.pop("attributes", {})
         json_indent = config.get("json_indent")
         return {
             ZARRAY_JSON: prototype.buffer.from_bytes(
-                json.dumps(
-                    zarray_dict, default=_json_convert, indent=json_indent, allow_nan=False
-                ).encode()
+                json.dumps(zarray_dict, indent=json_indent, allow_nan=False).encode()
             ),
             ZATTRS_JSON: prototype.buffer.from_bytes(
                 json.dumps(zattrs_dict, indent=json_indent, allow_nan=False).encode()
@@ -172,8 +146,33 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
         # Check that the zarr_format attribute is correct.
         _ = parse_zarr_format(_data.pop("zarr_format"))

-        # zarr v2 allowed arbitrary keys in the metadata.
-        # Filter the keys to only those expected by the constructor.
+        # To resolve a numpy object dtype array, we need to search for an object codec,
+        # which could be in filters or as a compressor.
+        # We will reference a hard-coded collection of object codec ids for this search.
+
+        _filters, _compressor = (data.get("filters"), data.get("compressor"))
+        if _filters is not None:
+            _filters = cast("tuple[dict[str, JSON], ...]", _filters)
+            object_codec_id = get_object_codec_id(tuple(_filters) + (_compressor,))
+        else:
+            object_codec_id = get_object_codec_id((_compressor,))
+        # We add a layer of indirection here around the dtype attribute of the array metadata
+        # because we also need to know the object codec id, if any, to resolve the data type.
+        dtype_spec: DTypeSpec_V2 = {
+            "name": data["dtype"],
+            "object_codec_id": object_codec_id,
+        }
+        dtype = get_data_type_from_json(dtype_spec, zarr_format=2)
+
+        _data["dtype"] = dtype
+        fill_value_encoded = _data.get("fill_value")
+        if fill_value_encoded is not None:
+            fill_value = dtype.from_json_scalar(fill_value_encoded, zarr_format=2)
+            _data["fill_value"] = fill_value
+
+        # zarr v2 allowed arbitrary keys here.
+        # We don't want the ArrayV2Metadata constructor to fail just because someone put an
+        # extra key in the metadata.
         expected = {x.name for x in fields(cls)}
         expected |= {"dtype", "chunks"}

@@ -198,16 +197,34 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
     def to_dict(self) -> dict[str, JSON]:
         zarray_dict = super().to_dict()

+        if isinstance(zarray_dict["compressor"], numcodecs.abc.Codec):
+            codec_config = zarray_dict["compressor"].get_config()
+            # Hotfix for https://github.com/zarr-developers/zarr-python/issues/2647
+            if codec_config["id"] == "zstd" and not codec_config.get("checksum", False):
+                codec_config.pop("checksum")
+            zarray_dict["compressor"] = codec_config
+
+        if zarray_dict["filters"] is not None:
+            raw_filters = zarray_dict["filters"]
+            # TODO: remove this when we can statically type the output JSON data structure
+            # entirely
+            if not isinstance(raw_filters, list | tuple):
+                raise TypeError("Invalid type for filters.
Expected a list or tuple.") + new_filters = [] + for f in raw_filters: + if isinstance(f, numcodecs.abc.Codec): + new_filters.append(f.get_config()) + else: + new_filters.append(f) + zarray_dict["filters"] = new_filters - _ = zarray_dict.pop("dtype") - dtype_json: JSON - # In the case of zarr v2, the simplest i.e., '|VXX' dtype is represented as a string - dtype_descr = self.dtype.descr - if self.dtype.kind == "V" and dtype_descr[0][0] != "" and len(dtype_descr) != 0: - dtype_json = tuple(self.dtype.descr) - else: - dtype_json = self.dtype.str - zarray_dict["dtype"] = dtype_json + # serialize the fill value after dtype-specific JSON encoding + if self.fill_value is not None: + fill_value = self.dtype.to_json_scalar(self.fill_value, zarr_format=2) + zarray_dict["fill_value"] = fill_value + + # pull the "name" attribute out of the dtype spec returned by self.dtype.to_json + zarray_dict["dtype"] = self.dtype.to_json(zarr_format=2)["name"] return zarray_dict @@ -296,178 +313,19 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: return data -def _parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: - """Handle structured dtype/fill value pairs""" - try: - if isinstance(fill_value, list): - return np.array([tuple(fill_value)], dtype=dtype)[0] - elif isinstance(fill_value, tuple): - return np.array([fill_value], dtype=dtype)[0] - elif isinstance(fill_value, bytes): - return np.frombuffer(fill_value, dtype=dtype)[0] - elif isinstance(fill_value, str): - decoded = base64.standard_b64decode(fill_value) - return np.frombuffer(decoded, dtype=dtype)[0] - else: - return np.array(fill_value, dtype=dtype)[()] - except Exception as e: - raise ValueError(f"Fill_value {fill_value} is not valid for dtype {dtype}.") from e - - -def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: +def get_object_codec_id(maybe_object_codecs: Sequence[JSON]) -> str | None: """ - Parse a potential fill value into a value that is compatible with the provided dtype. - - Parameters - ---------- - fill_value : Any - A potential fill value. - dtype : np.dtype[Any] - A numpy dtype. - - Returns - ------- - An instance of `dtype`, or `None`, or any python object (in the case of an object dtype) + Inspect a sequence of codecs / filters for an "object codec", i.e. a codec + that can serialize object arrays to contiguous bytes. Zarr python + maintains a hard-coded set of object codec ids. If any element from the input + has an id that matches one of the hard-coded object codec ids, that id + is returned immediately. """ - - if fill_value is None or dtype.hasobject: - pass - elif dtype.fields is not None: - # the dtype is structured (has multiple fields), so the fill_value might be a - # compound value (e.g., a tuple or dict) that needs field-wise processing. - # We use parse_structured_fill_value to correctly convert each component. 
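# --- editorial sketch, not lines from this patch ---
# The object-codec indirection added to ``ArrayV2Metadata.from_dict`` above is
# easier to see in isolation. ``get_object_codec_id`` (defined later in this
# diff) scans the filters and the compressor for a codec whose id is in the
# hard-coded OBJECT_CODEC_IDS set; this sketch assumes that set includes
# "vlen-utf8", the object codec used for v2 variable-length strings.

from zarr.core.metadata.v2 import get_object_codec_id

# a v2 metadata fragment for an object-dtype array written with vlen-utf8
filters = ({"id": "vlen-utf8"},)
compressor = {"id": "zstd", "level": 0, "checksum": False}
assert get_object_codec_id((*filters, compressor)) == "vlen-utf8"
assert get_object_codec_id((compressor,)) is None  # a plain compressor is not an object codec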
- fill_value = _parse_structured_fill_value(fill_value, dtype) - elif not isinstance(fill_value, np.void) and fill_value == 0: - # this should be compatible across numpy versions for any array type, including - # structured arrays - fill_value = np.zeros((), dtype=dtype)[()] - elif dtype.kind == "U": - # special case unicode because of encoding issues on Windows if passed through numpy - # https://github.com/alimanfoo/zarr/pull/172#issuecomment-343782713 - - if not isinstance(fill_value, str): - raise ValueError( - f"fill_value {fill_value!r} is not valid for dtype {dtype}; must be a unicode string" - ) - elif dtype.kind in "SV" and isinstance(fill_value, str): - fill_value = base64.standard_b64decode(fill_value) - elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2: - complex_val = complex(float(fill_value[0]), float(fill_value[1])) - fill_value = np.array(complex_val, dtype=dtype)[()] - else: - try: - if isinstance(fill_value, bytes) and dtype.kind == "V": - # special case for numpy 1.14 compatibility - fill_value = np.array(fill_value, dtype=dtype.str).view(dtype)[()] - else: - fill_value = np.array(fill_value, dtype=dtype)[()] - - except Exception as e: - msg = f"Fill_value {fill_value} is not valid for dtype {dtype}." - raise ValueError(msg) from e - - return fill_value - - -def _serialize_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> JSON: - serialized: JSON - - if fill_value is None: - serialized = None - elif dtype.kind in "SV": - # There's a relationship between dtype and fill_value - # that mypy isn't aware of. The fact that we have S or V dtype here - # means we should have a bytes-type fill_value. - serialized = base64.standard_b64encode(cast("bytes", fill_value)).decode("ascii") - elif isinstance(fill_value, np.datetime64): - serialized = np.datetime_as_string(fill_value) - elif isinstance(fill_value, numbers.Integral): - serialized = int(fill_value) - elif isinstance(fill_value, numbers.Real): - float_fv = float(fill_value) - if np.isnan(float_fv): - serialized = "NaN" - elif np.isinf(float_fv): - serialized = "Infinity" if float_fv > 0 else "-Infinity" - else: - serialized = float_fv - elif isinstance(fill_value, numbers.Complex): - serialized = [ - _serialize_fill_value(fill_value.real, dtype), - _serialize_fill_value(fill_value.imag, dtype), - ] - else: - serialized = fill_value - - return serialized - - -def _default_fill_value(dtype: np.dtype[Any]) -> Any: - """ - Get the default fill value for a type. - - Notes - ----- - This differs from :func:`parse_fill_value`, which parses a fill value - stored in the Array metadata into an in-memory value. This only gives - the default fill value for some type. - - This is useful for reading Zarr format 2 arrays, which allow the fill - value to be unspecified. - """ - if dtype.kind == "S": - return b"" - elif dtype.kind in "UO": - return "" - elif dtype.kind in "Mm": - return dtype.type("nat") - elif dtype.kind == "V": - if dtype.fields is not None: - default = tuple(_default_fill_value(field[0]) for field in dtype.fields.values()) - return np.array([default], dtype=dtype) - else: - return np.zeros(1, dtype=dtype) - else: - return dtype.type(0) - - -def _default_compressor( - dtype: np.dtype[Any], -) -> dict[str, JSON] | None: - """Get the default filters and compressor for a dtype. 
- - https://numpy.org/doc/2.1/reference/generated/numpy.dtype.kind.html - """ - default_compressor = config.get("array.v2_default_compressor") - if dtype.kind in "biufcmM": - dtype_key = "numeric" - elif dtype.kind in "U": - dtype_key = "string" - elif dtype.kind in "OSV": - dtype_key = "bytes" - else: - raise ValueError(f"Unsupported dtype kind {dtype.kind}") - - return cast("dict[str, JSON] | None", default_compressor.get(dtype_key, None)) - - -def _default_filters( - dtype: np.dtype[Any], -) -> list[dict[str, JSON]] | None: - """Get the default filters and compressor for a dtype. - - https://numpy.org/doc/2.1/reference/generated/numpy.dtype.kind.html - """ - default_filters = config.get("array.v2_default_filters") - if dtype.kind in "biufcmM": - dtype_key = "numeric" - elif dtype.kind in "U": - dtype_key = "string" - elif dtype.kind in "OS": - dtype_key = "bytes" - elif dtype.kind == "V": - dtype_key = "raw" - else: - raise ValueError(f"Unsupported dtype kind {dtype.kind}") - - return cast("list[dict[str, JSON]] | None", default_filters.get(dtype_key, None)) + object_codec_id = None + for maybe_object_codec in maybe_object_codecs: + if ( + isinstance(maybe_object_codec, dict) + and maybe_object_codec.get("id") in OBJECT_CODEC_IDS + ): + return cast("str", maybe_object_codec["id"]) + return object_codec_id diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index dcbf44f89b..84872d3dbd 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -1,28 +1,25 @@ from __future__ import annotations -import warnings -from typing import TYPE_CHECKING, TypedDict, overload +from typing import TYPE_CHECKING, TypedDict from zarr.abc.metadata import Metadata from zarr.core.buffer.core import default_buffer_prototype +from zarr.core.dtype import VariableLengthUTF8, ZDType, get_data_type_from_json +from zarr.core.dtype.common import check_dtype_spec_v3 if TYPE_CHECKING: - from collections.abc import Callable from typing import Self from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.chunk_grids import ChunkGrid from zarr.core.common import JSON, ChunkCoords + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar + import json -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from dataclasses import dataclass, field, replace -from enum import Enum -from typing import Any, Literal, cast - -import numcodecs.abc -import numpy as np -import numpy.typing as npt +from typing import Any, Literal from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.core.array_spec import ArrayConfig, ArraySpec @@ -38,20 +35,9 @@ ) from zarr.core.config import config from zarr.core.metadata.common import parse_attributes -from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING -from zarr.core.strings import _STRING_DTYPE as STRING_NP_DTYPE from zarr.errors import MetadataValidationError, NodeTypeValidationError from zarr.registry import get_codec_class -DEFAULT_DTYPE = "float64" - -# Keep in sync with _replace_special_floats -SPECIAL_FLOATS_ENCODED = { - "Infinity": np.inf, - "-Infinity": -np.inf, - "NaN": np.nan, -} - def parse_zarr_format(data: object) -> Literal[3]: if data == 3: @@ -94,7 +80,7 @@ def validate_array_bytes_codec(codecs: tuple[Codec, ...]) -> ArrayBytesCodec: return abcs[0] -def validate_codecs(codecs: tuple[Codec, ...], dtype: DataType) -> None: +def validate_codecs(codecs: tuple[Codec, ...], dtype: ZDType[TBaseDType, TBaseScalar]) -> None: """Check that the codecs are 
valid for the given dtype""" from zarr.codecs.sharding import ShardingCodec @@ -107,14 +93,11 @@ def validate_codecs(codecs: tuple[Codec, ...], dtype: DataType) -> None: # we need to have special codecs if we are decoding vlen strings or bytestrings # TODO: use codec ID instead of class name codec_class_name = abc.__class__.__name__ - if dtype == DataType.string and not codec_class_name == "VLenUTF8Codec": + # TODO: Fix typing here + if isinstance(dtype, VariableLengthUTF8) and not codec_class_name == "VLenUTF8Codec": # type: ignore[unreachable] raise ValueError( f"For string dtype, ArrayBytesCodec must be `VLenUTF8Codec`, got `{codec_class_name}`." ) - if dtype == DataType.bytes and not codec_class_name == "VLenBytesCodec": - raise ValueError( - f"For bytes dtype, ArrayBytesCodec must be `VLenBytesCodec`, got `{codec_class_name}`." - ) def parse_dimension_names(data: object) -> tuple[str | None, ...] | None: @@ -144,87 +127,6 @@ def parse_storage_transformers(data: object) -> tuple[dict[str, JSON], ...]: ) -class V3JsonEncoder(json.JSONEncoder): - def __init__( - self, - *, - skipkeys: bool = False, - ensure_ascii: bool = True, - check_circular: bool = True, - allow_nan: bool = True, - sort_keys: bool = False, - indent: int | None = None, - separators: tuple[str, str] | None = None, - default: Callable[[object], object] | None = None, - ) -> None: - if indent is None: - indent = config.get("json_indent") - super().__init__( - skipkeys=skipkeys, - ensure_ascii=ensure_ascii, - check_circular=check_circular, - allow_nan=allow_nan, - sort_keys=sort_keys, - indent=indent, - separators=separators, - default=default, - ) - - def default(self, o: object) -> Any: - if isinstance(o, np.dtype): - return str(o) - if np.isscalar(o): - out: Any - if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): - # https://github.com/zarr-developers/zarr-python/issues/2119 - # `.item()` on a datetime type might or might not return an - # integer, depending on the value. - # Explicitly cast to an int first, and then grab .item() - out = o.view("i8").item() - else: - # convert numpy scalar to python type, and pass - # python types through - out = getattr(o, "item", lambda: o)() - if isinstance(out, complex): - # python complex types are not JSON serializable, so we use the - # serialization defined in the zarr v3 spec - return _replace_special_floats([out.real, out.imag]) - elif np.isnan(out): - return "NaN" - elif np.isinf(out): - return "Infinity" if out > 0 else "-Infinity" - return out - elif isinstance(o, Enum): - return o.name - # this serializes numcodecs compressors - # todo: implement to_dict for codecs - elif isinstance(o, numcodecs.abc.Codec): - config: dict[str, Any] = o.get_config() - return config - else: - return super().default(o) - - -def _replace_special_floats(obj: object) -> Any: - """Helper function to replace NaN/Inf/-Inf values with special strings - - Note: this cannot be done in the V3JsonEncoder because Python's `json.dumps` optimistically - converts NaN/Inf values to special types outside of the encoding step. 
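# --- editorial sketch, not lines from this patch ---
# With V3JsonEncoder and _replace_special_floats removed, non-finite floats
# must be turned into their spec-defined string forms by the dtype itself
# before ``json.dumps(..., allow_nan=False)`` runs. A minimal illustration,
# assuming Float64.to_json_scalar emits the "Infinity"/"-Infinity"/"NaN"
# strings required by the metadata documents:

import json

from zarr.core.dtype.npy.float import Float64

encoded = Float64().to_json_scalar(float("inf"), zarr_format=3)
assert encoded == "Infinity"  # spec-defined string form
json.dumps({"fill_value": encoded}, allow_nan=False)  # safe: no raw non-finite floats remain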
- """ - if isinstance(obj, float): - if np.isnan(obj): - return "NaN" - elif np.isinf(obj): - return "Infinity" if obj > 0 else "-Infinity" - elif isinstance(obj, dict): - # Recursively replace in dictionaries - return {k: _replace_special_floats(v) for k, v in obj.items()} - elif isinstance(obj, list): - # Recursively replace in lists - return [_replace_special_floats(item) for item in obj] - return obj - - class ArrayV3MetadataDict(TypedDict): """ A typed dictionary model for zarr v3 metadata. @@ -237,7 +139,7 @@ class ArrayV3MetadataDict(TypedDict): @dataclass(frozen=True, kw_only=True) class ArrayV3Metadata(Metadata): shape: ChunkCoords - data_type: DataType + data_type: ZDType[TBaseDType, TBaseScalar] chunk_grid: ChunkGrid chunk_key_encoding: ChunkKeyEncoding fill_value: Any @@ -252,10 +154,10 @@ def __init__( self, *, shape: Iterable[int], - data_type: npt.DTypeLike | DataType, + data_type: ZDType[TBaseDType, TBaseScalar], chunk_grid: dict[str, JSON] | ChunkGrid, chunk_key_encoding: ChunkKeyEncodingLike, - fill_value: Any, + fill_value: object, codecs: Iterable[Codec | dict[str, JSON]], attributes: dict[str, JSON] | None, dimension_names: DimensionNames, @@ -264,33 +166,29 @@ def __init__( """ Because the class is a frozen dataclass, we set attributes using object.__setattr__ """ + shape_parsed = parse_shapelike(shape) - data_type_parsed = DataType.parse(data_type) chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) - if fill_value is None: - fill_value = default_fill_value(data_type_parsed) - # we pass a string here rather than an enum to make mypy happy - fill_value_parsed = parse_fill_value( - fill_value, dtype=cast("ALL_DTYPES", data_type_parsed.value) - ) + # Note: relying on a type method is numpy-specific + fill_value_parsed = data_type.cast_scalar(fill_value) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) array_spec = ArraySpec( shape=shape_parsed, - dtype=data_type_parsed.to_numpy(), + dtype=data_type, fill_value=fill_value_parsed, config=ArrayConfig.from_dict({}), # TODO: config is not needed here. prototype=default_buffer_prototype(), # TODO: prototype is not needed here. 
) codecs_parsed = tuple(c.evolve_from_array_spec(array_spec) for c in codecs_parsed_partial) - validate_codecs(codecs_parsed_partial, data_type_parsed) + validate_codecs(codecs_parsed_partial, data_type) object.__setattr__(self, "shape", shape_parsed) - object.__setattr__(self, "data_type", data_type_parsed) + object.__setattr__(self, "data_type", data_type) object.__setattr__(self, "chunk_grid", chunk_grid_parsed) object.__setattr__(self, "chunk_key_encoding", chunk_key_encoding_parsed) object.__setattr__(self, "codecs", codecs_parsed) @@ -315,19 +213,16 @@ def _validate_metadata(self) -> None: if self.fill_value is None: raise ValueError("`fill_value` is required.") for codec in self.codecs: - codec.validate( - shape=self.shape, dtype=self.data_type.to_numpy(), chunk_grid=self.chunk_grid - ) - - @property - def dtype(self) -> np.dtype[Any]: - """Interpret Zarr dtype as NumPy dtype""" - return self.data_type.to_numpy() + codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid) @property def ndim(self) -> int: return len(self.shape) + @property + def dtype(self) -> ZDType[TBaseDType, TBaseScalar]: + return self.data_type + @property def chunks(self) -> ChunkCoords: if isinstance(self.chunk_grid, RegularChunkGrid): @@ -389,8 +284,13 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: return self.chunk_key_encoding.encode_chunk_key(chunk_coords) def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: - d = _replace_special_floats(self.to_dict()) - return {ZARR_JSON: prototype.buffer.from_bytes(json.dumps(d, cls=V3JsonEncoder).encode())} + json_indent = config.get("json_indent") + d = self.to_dict() + return { + ZARR_JSON: prototype.buffer.from_bytes( + json.dumps(d, allow_nan=False, indent=json_indent).encode() + ) + } @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: @@ -402,18 +302,31 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: # check that the node_type attribute is correct _ = parse_node_type_array(_data.pop("node_type")) - # check that the data_type attribute is valid - data_type = DataType.parse(_data.pop("data_type")) + data_type_json = _data.pop("data_type") + if not check_dtype_spec_v3(data_type_json): + raise ValueError(f"Invalid data_type: {data_type_json!r}") + data_type = get_data_type_from_json(data_type_json, zarr_format=3) + + # check that the fill value is consistent with the data type + try: + fill = _data.pop("fill_value") + fill_value_parsed = data_type.from_json_scalar(fill, zarr_format=3) + except ValueError as e: + raise TypeError(f"Invalid fill_value: {fill!r}") from e # dimension_names key is optional, normalize missing to `None` _data["dimension_names"] = _data.pop("dimension_names", None) + # attributes key is optional, normalize missing to `None` _data["attributes"] = _data.pop("attributes", None) - return cls(**_data, data_type=data_type) # type: ignore[arg-type] + + return cls(**_data, fill_value=fill_value_parsed, data_type=data_type) # type: ignore[arg-type] def to_dict(self) -> dict[str, JSON]: out_dict = super().to_dict() - + out_dict["fill_value"] = self.data_type.to_json_scalar( + self.fill_value, zarr_format=self.zarr_format + ) if not isinstance(out_dict, dict): raise TypeError(f"Expected dict. 
Got {type(out_dict)}.") @@ -421,6 +334,15 @@ def to_dict(self) -> dict[str, JSON]: # the metadata document if out_dict["dimension_names"] is None: out_dict.pop("dimension_names") + + # TODO: replace the `to_dict` / `from_dict` on the `Metadata`` class with + # to_json, from_json, and have ZDType inherit from `Metadata` + # until then, we have this hack here, which relies on the fact that to_dict will pass through + # any non-`Metadata` fields as-is. + dtype_meta = out_dict["data_type"] + if isinstance(dtype_meta, ZDType): + out_dict["data_type"] = dtype_meta.to_json(zarr_format=3) # type: ignore[unreachable] + return out_dict def update_shape(self, shape: ChunkCoords) -> Self: @@ -428,299 +350,3 @@ def update_shape(self, shape: ChunkCoords) -> Self: def update_attributes(self, attributes: dict[str, JSON]) -> Self: return replace(self, attributes=attributes) - - -# enum Literals can't be used in typing, so we have to restate all of the V3 dtypes as types -# https://github.com/python/typing/issues/781 - -BOOL_DTYPE = Literal["bool"] -BOOL = np.bool_ -INTEGER_DTYPE = Literal["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] -INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 -FLOAT_DTYPE = Literal["float16", "float32", "float64"] -FLOAT = np.float16 | np.float32 | np.float64 -COMPLEX_DTYPE = Literal["complex64", "complex128"] -COMPLEX = np.complex64 | np.complex128 -STRING_DTYPE = Literal["string"] -STRING = np.str_ -BYTES_DTYPE = Literal["bytes"] -BYTES = np.bytes_ - -ALL_DTYPES = BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | STRING_DTYPE | BYTES_DTYPE - - -@overload -def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype: BOOL_DTYPE, -) -> BOOL: ... - - -@overload -def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype: INTEGER_DTYPE, -) -> INTEGER: ... - - -@overload -def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype: FLOAT_DTYPE, -) -> FLOAT: ... - - -@overload -def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype: COMPLEX_DTYPE, -) -> COMPLEX: ... - - -@overload -def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype: STRING_DTYPE, -) -> STRING: ... - - -@overload -def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype: BYTES_DTYPE, -) -> BYTES: ... - - -def parse_fill_value( - fill_value: Any, - dtype: ALL_DTYPES, -) -> Any: - """ - Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type. - If `fill_value` is `None`, then this function will return the result of casting the value 0 - to the provided data type. Otherwise, `fill_value` will be cast to the provided data type. - - Note that some numpy dtypes use very permissive casting rules. For example, - `np.bool_({'not remotely a bool'})` returns `True`. Thus this function should not be used for - validating that the provided fill value is a valid instance of the data type. - - Parameters - ---------- - fill_value : Any - A potential fill value. - dtype : str - A valid Zarr format 3 DataType. 
- - Returns - ------- - A scalar instance of `dtype` - """ - data_type = DataType(dtype) - if fill_value is None: - raise ValueError("Fill value cannot be None") - if data_type == DataType.string: - return np.str_(fill_value) - if data_type == DataType.bytes: - return np.bytes_(fill_value) - - # the rest are numeric types - np_dtype = cast("np.dtype[Any]", data_type.to_numpy()) - - if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): - if data_type in (DataType.complex64, DataType.complex128): - if len(fill_value) == 2: - decoded_fill_value = tuple( - SPECIAL_FLOATS_ENCODED.get(value, value) for value in fill_value - ) - # complex datatypes serialize to JSON arrays with two elements - return np_dtype.type(complex(*decoded_fill_value)) - else: - msg = ( - f"Got an invalid fill value for complex data type {data_type.value}." - f"Expected a sequence with 2 elements, but {fill_value!r} has " - f"length {len(fill_value)}." - ) - raise ValueError(msg) - msg = f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {data_type.value}." - raise TypeError(msg) - - # Cast the fill_value to the given dtype - try: - # This warning filter can be removed after Zarr supports numpy>=2.0 - # The warning is saying that the future behavior of out of bounds casting will be to raise - # an OverflowError. In the meantime, we allow overflow and catch cases where - # fill_value != casted_value below. - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - casted_value = np.dtype(np_dtype).type(fill_value) - except (ValueError, OverflowError, TypeError) as e: - raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}") from e - # Check if the value is still representable by the dtype - if (fill_value == "NaN" and np.isnan(casted_value)) or ( - fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value) - ): - pass - elif np_dtype.kind == "f": - # float comparison is not exact, especially when dtype str | bytes | np.generic: - if dtype == DataType.string: - return "" - elif dtype == DataType.bytes: - return b"" - else: - np_dtype = dtype.to_numpy() - np_dtype = cast("np.dtype[Any]", np_dtype) - return np_dtype.type(0) # type: ignore[misc] - - -# For type checking -_bool = bool - - -class DataType(Enum): - bool = "bool" - int8 = "int8" - int16 = "int16" - int32 = "int32" - int64 = "int64" - uint8 = "uint8" - uint16 = "uint16" - uint32 = "uint32" - uint64 = "uint64" - float16 = "float16" - float32 = "float32" - float64 = "float64" - complex64 = "complex64" - complex128 = "complex128" - string = "string" - bytes = "bytes" - - @property - def byte_count(self) -> int | None: - data_type_byte_counts = { - DataType.bool: 1, - DataType.int8: 1, - DataType.int16: 2, - DataType.int32: 4, - DataType.int64: 8, - DataType.uint8: 1, - DataType.uint16: 2, - DataType.uint32: 4, - DataType.uint64: 8, - DataType.float16: 2, - DataType.float32: 4, - DataType.float64: 8, - DataType.complex64: 8, - DataType.complex128: 16, - } - try: - return data_type_byte_counts[self] - except KeyError: - # string and bytes have variable length - return None - - @property - def has_endianness(self) -> _bool: - return self.byte_count is not None and self.byte_count != 1 - - def to_numpy_shortname(self) -> str: - data_type_to_numpy = { - DataType.bool: "bool", - DataType.int8: "i1", - DataType.int16: "i2", - DataType.int32: "i4", - DataType.int64: "i8", - DataType.uint8: "u1", - DataType.uint16: "u2", - DataType.uint32: "u4", - 
DataType.uint64: "u8", - DataType.float16: "f2", - DataType.float32: "f4", - DataType.float64: "f8", - DataType.complex64: "c8", - DataType.complex128: "c16", - } - return data_type_to_numpy[self] - - def to_numpy(self) -> np.dtypes.StringDType | np.dtypes.ObjectDType | np.dtype[Any]: - # note: it is not possible to round trip DataType <-> np.dtype - # due to the fact that DataType.string and DataType.bytes both - # generally return np.dtype("O") from this function, even though - # they can originate as fixed-length types (e.g. " DataType: - if dtype.kind in "UT": - return DataType.string - elif dtype.kind == "S": - return DataType.bytes - elif not _NUMPY_SUPPORTS_VLEN_STRING and dtype.kind == "O": - # numpy < 2.0 does not support vlen string dtype - # so we fall back on object array of strings - return DataType.string - dtype_to_data_type = { - "|b1": "bool", - "bool": "bool", - "|i1": "int8", - " DataType: - if dtype is None: - return DataType[DEFAULT_DTYPE] - if isinstance(dtype, DataType): - return dtype - try: - return DataType(dtype) - except ValueError: - pass - try: - dtype = np.dtype(dtype) - except (ValueError, TypeError) as e: - raise ValueError(f"Invalid Zarr format 3 data_type: {dtype}") from e - # check that this is a valid v3 data_type - try: - data_type = DataType.from_numpy(dtype) - except KeyError as e: - raise ValueError(f"Invalid Zarr format 3 data_type: {dtype}") from e - return data_type diff --git a/src/zarr/core/strings.py b/src/zarr/core/strings.py deleted file mode 100644 index 15c5fddfee..0000000000 --- a/src/zarr/core/strings.py +++ /dev/null @@ -1,86 +0,0 @@ -"""This module contains utilities for working with string arrays across -different versions of Numpy. -""" - -from typing import Any, Union, cast -from warnings import warn - -import numpy as np - -# _STRING_DTYPE is the in-memory datatype that will be used for V3 string arrays -# when reading data back from Zarr. -# Any valid string-like datatype should be fine for *setting* data. - -_STRING_DTYPE: Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"] -_NUMPY_SUPPORTS_VLEN_STRING: bool - - -def cast_array( - data: np.ndarray[Any, np.dtype[Any]], -) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: - raise NotImplementedError - - -try: - # this new vlen string dtype was added in NumPy 2.0 - _STRING_DTYPE = np.dtypes.StringDType() - _NUMPY_SUPPORTS_VLEN_STRING = True - - def cast_array( - data: np.ndarray[Any, np.dtype[Any]], - ) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]: - out = data.astype(_STRING_DTYPE, copy=False) - return cast("np.ndarray[Any, np.dtypes.StringDType]", out) - -except AttributeError: - # if not available, we fall back on an object array of strings, as in Zarr < 3 - _STRING_DTYPE = np.dtypes.ObjectDType() - _NUMPY_SUPPORTS_VLEN_STRING = False - - def cast_array( - data: np.ndarray[Any, np.dtype[Any]], - ) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: - out = data.astype(_STRING_DTYPE, copy=False) - return cast("np.ndarray[Any, np.dtypes.ObjectDType]", out) - - -def cast_to_string_dtype( - data: np.ndarray[Any, np.dtype[Any]], safe: bool = False -) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: - """Take any data and attempt to cast to to our preferred string dtype. - - data : np.ndarray - The data to cast - - safe : bool - If True, do not issue a warning if the data is cast from object to string dtype. 
- - """ - if np.issubdtype(data.dtype, np.str_): - # legacy fixed-width string type (e.g. "= 2.", - stacklevel=2, - ) - return cast_array(data) - raise ValueError(f"Cannot cast dtype {data.dtype} to string dtype") diff --git a/src/zarr/dtype.py b/src/zarr/dtype.py new file mode 100644 index 0000000000..6e3789543b --- /dev/null +++ b/src/zarr/dtype.py @@ -0,0 +1,3 @@ +from zarr.core.dtype import ZDType, data_type_registry + +__all__ = ["ZDType", "data_type_registry"] diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 704db3f704..d1fe1d181c 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from zarr.core.config import BadConfigError, config +from zarr.core.dtype import data_type_registry if TYPE_CHECKING: from importlib.metadata import EntryPoint @@ -43,6 +44,7 @@ def __init__(self) -> None: def lazy_load(self) -> None: for e in self.lazy_load_list: self.register(e.load()) + self.lazy_load_list.clear() def register(self, cls: type[T]) -> None: @@ -58,12 +60,14 @@ def register(self, cls: type[T]) -> None: The registry module is responsible for managing implementations of codecs, pipelines, buffers and ndbuffers and collecting them from entrypoints. The implementation used is determined by the config. + +The registry module is also responsible for managing dtypes. """ def _collect_entrypoints() -> list[Registry[Any]]: """ - Collects codecs, pipelines, buffers and ndbuffers from entrypoints. + Collects codecs, pipelines, dtypes, buffers and ndbuffers from entrypoints. Entry points can either be single items or groups of items. Allowed syntax for entry_points.txt is e.g. @@ -86,6 +90,10 @@ def _collect_entrypoints() -> list[Registry[Any]]: __buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="buffer")) __ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.ndbuffer")) __ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="ndbuffer")) + + data_type_registry.lazy_load_list.extend(entry_points.select(group="zarr.data_type")) + data_type_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="data_type")) + __pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline")) __pipeline_registry.lazy_load_list.extend( entry_points.select(group="zarr", name="codec_pipeline") @@ -148,7 +156,8 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: if len(codec_classes) == 1: return next(iter(codec_classes.values())) warnings.warn( - f"Codec '{key}' not configured in config. Selecting any implementation.", stacklevel=2 + f"Codec '{key}' not configured in config. 
Selecting any implementation.", + stacklevel=2, ) return list(codec_classes.values())[-1] selected_codec_cls = codec_classes[config_entry] diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 0cb992a4f2..5e070b5387 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -17,6 +17,7 @@ from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding from zarr.core.common import JSON, ZarrFormat +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.sync import sync from zarr.storage import MemoryStore, StoreLike @@ -49,10 +50,10 @@ def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]: | npst.unsigned_integer_dtypes(endianness="=") | npst.floating_dtypes(endianness="=") | npst.complex_number_dtypes(endianness="=") - # | npst.byte_string_dtypes(endianness="=") - # | npst.unicode_string_dtypes() - # | npst.datetime64_dtypes() - # | npst.timedelta64_dtypes() + | npst.byte_string_dtypes(endianness="=") + | npst.unicode_string_dtypes(endianness="=") + | npst.datetime64_dtypes(endianness="=") + | npst.timedelta64_dtypes(endianness="=") ) @@ -66,7 +67,7 @@ def v2_dtypes() -> st.SearchStrategy[np.dtype[Any]]: | npst.byte_string_dtypes(endianness="=") | npst.unicode_string_dtypes(endianness="=") | npst.datetime64_dtypes(endianness="=") - # | npst.timedelta64_dtypes() + | npst.timedelta64_dtypes(endianness="=") ) @@ -119,7 +120,9 @@ def clear_store(x: Store) -> Store: compressors = st.sampled_from([None, "default"]) zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([3, 2]) # We de-prioritize arrays having dim sizes 0, 1, 2 -array_shapes = npst.array_shapes(max_dims=4, min_side=3) | npst.array_shapes(max_dims=4, min_side=0) +array_shapes = npst.array_shapes(max_dims=4, min_side=3, max_side=5) | npst.array_shapes( + max_dims=4, min_side=0 +) @st.composite @@ -141,8 +144,9 @@ def array_metadata( shape = draw(array_shapes()) ndim = len(shape) chunk_shape = draw(array_shapes(min_dims=ndim, max_dims=ndim)) - dtype = draw(v3_dtypes()) - fill_value = draw(npst.from_dtype(dtype)) + np_dtype = draw(v3_dtypes()) + dtype = get_data_type_from_native_dtype(np_dtype) + fill_value = draw(npst.from_dtype(np_dtype)) if zarr_format == 2: return ArrayV2Metadata( shape=shape, diff --git a/tests/conftest.py b/tests/conftest.py index 30d7eec4d4..4d300a1fd4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,8 +19,12 @@ _parse_chunk_key_encoding, ) from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition -from zarr.core.common import JSON, DimensionNames, parse_dtype, parse_shapelike +from zarr.core.common import JSON, DimensionNames, parse_shapelike from zarr.core.config import config as zarr_config +from zarr.core.dtype import ( + get_data_type_from_native_dtype, +) +from zarr.core.dtype.common import HasItemSize from zarr.core.metadata.v2 import ArrayV2Metadata from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import sync @@ -36,6 +40,7 @@ from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike from zarr.core.common import ChunkCoords, MemoryOrder, ShapeLike, ZarrFormat + from zarr.core.dtype.wrapper import ZDType async def parse_store( @@ -265,7 +270,7 @@ def create_array_metadata( filters: FiltersLike = "auto", compressors: CompressorsLike = "auto", serializer: 
SerializerLike = "auto", - fill_value: Any | None = None, + fill_value: Any = 0, order: MemoryOrder | None = None, zarr_format: ZarrFormat, attributes: dict[str, JSON] | None = None, @@ -275,14 +280,19 @@ def create_array_metadata( """ Create array metadata """ - dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format) + dtype_parsed = get_data_type_from_native_dtype(dtype) shape_parsed = parse_shapelike(shape) chunk_key_encoding_parsed = _parse_chunk_key_encoding( chunk_key_encoding, zarr_format=zarr_format ) - + item_size = 1 + if isinstance(dtype_parsed, HasItemSize): + item_size = dtype_parsed.item_size shard_shape_parsed, chunk_shape_parsed = _auto_partition( - array_shape=shape_parsed, shard_shape=shards, chunk_shape=chunks, dtype=dtype_parsed + array_shape=shape_parsed, + shard_shape=shards, + chunk_shape=chunks, + item_size=item_size, ) if order is None: @@ -293,11 +303,11 @@ def create_array_metadata( if zarr_format == 2: filters_parsed, compressor_parsed = _parse_chunk_encoding_v2( - compressor=compressors, filters=filters, dtype=np.dtype(dtype) + compressor=compressors, filters=filters, dtype=dtype_parsed ) return ArrayV2Metadata( shape=shape_parsed, - dtype=np.dtype(dtype), + dtype=dtype_parsed, chunks=chunk_shape_parsed, order=order_parsed, dimension_separator=chunk_key_encoding_parsed.separator, @@ -398,7 +408,7 @@ def meta_from_array( filters: FiltersLike = "auto", compressors: CompressorsLike = "auto", serializer: SerializerLike = "auto", - fill_value: Any | None = None, + fill_value: Any = 0, order: MemoryOrder | None = None, zarr_format: ZarrFormat = 3, attributes: dict[str, JSON] | None = None, @@ -423,3 +433,12 @@ def meta_from_array( chunk_key_encoding=chunk_key_encoding, dimension_names=dimension_names, ) + + +def skip_object_dtype(dtype: ZDType[Any, Any]) -> None: + if dtype.dtype_cls is type(np.dtype("O")): + msg = ( + f"{dtype} uses the numpy object data type, which is not a valid target for data " + "type resolution" + ) + pytest.skip(msg) diff --git a/tests/package_with_entrypoint-0.1.dist-info/entry_points.txt b/tests/package_with_entrypoint-0.1.dist-info/entry_points.txt index eee724c912..7eb0eb7c86 100644 --- a/tests/package_with_entrypoint-0.1.dist-info/entry_points.txt +++ b/tests/package_with_entrypoint-0.1.dist-info/entry_points.txt @@ -12,3 +12,5 @@ another_buffer = package_with_entrypoint:TestEntrypointGroup.Buffer another_ndbuffer = package_with_entrypoint:TestEntrypointGroup.NDBuffer [zarr.codec_pipeline] another_pipeline = package_with_entrypoint:TestEntrypointGroup.Pipeline +[zarr.data_type] +new_data_type = package_with_entrypoint:TestDataType \ No newline at end of file diff --git a/tests/package_with_entrypoint/__init__.py b/tests/package_with_entrypoint/__init__.py index cfbd4f23a9..e0d8a52c4d 100644 --- a/tests/package_with_entrypoint/__init__.py +++ b/tests/package_with_entrypoint/__init__.py @@ -1,5 +1,6 @@ -from collections.abc import Iterable -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Self import numpy as np import numpy.typing as npt @@ -7,8 +8,16 @@ import zarr.core.buffer from zarr.abc.codec import ArrayBytesCodec, CodecInput, CodecPipeline from zarr.codecs import BytesCodec -from zarr.core.array_spec import ArraySpec from zarr.core.buffer import Buffer, NDBuffer +from zarr.core.dtype.common import DataTypeValidationError, DTypeJSON, DTypeSpec_V2 +from zarr.core.dtype.npy.bool import Bool + +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import 
ClassVar, Literal + + from zarr.core.array_spec import ArraySpec + from zarr.core.common import ZarrFormat class TestEntrypointCodec(ArrayBytesCodec): @@ -65,3 +74,28 @@ class NDBuffer(zarr.core.buffer.NDBuffer): class Pipeline(CodecPipeline): pass + + +class TestDataType(Bool): + """ + This is a "data type" that serializes to "test" + """ + + _zarr_v3_name: ClassVar[Literal["test"]] = "test" # type: ignore[assignment] + + @classmethod + def from_json(cls, data: DTypeJSON, *, zarr_format: Literal[2, 3]) -> Self: + if zarr_format == 2 and data == {"name": cls._zarr_v3_name, "object_codec_id": None}: + return cls() + if zarr_format == 3 and data == cls._zarr_v3_name: + return cls() + raise DataTypeValidationError( + f"Invalid JSON representation of {cls.__name__}. Got {data!r}" + ) + + def to_json(self, zarr_format: ZarrFormat) -> str | DTypeSpec_V2: # type: ignore[override] + if zarr_format == 2: + return {"name": self._zarr_v3_name, "object_codec_id": None} + if zarr_format == 3: + return self._zarr_v3_name + raise ValueError("zarr_format must be 2 or 3") diff --git a/tests/test_array.py b/tests/test_array.py index 3fc7b3938c..28ea812967 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1,4 +1,5 @@ import dataclasses +import inspect import json import math import multiprocessing as mp @@ -17,22 +18,19 @@ import zarr.api.asynchronous import zarr.api.synchronous as sync_api +from tests.conftest import skip_object_dtype from zarr import Array, AsyncArray, Group from zarr.abc.store import Store from zarr.codecs import ( BytesCodec, GzipCodec, TransposeCodec, - VLenBytesCodec, - VLenUTF8Codec, ZstdCodec, ) from zarr.core._info import ArrayInfo from zarr.core.array import ( CompressorsLike, FiltersLike, - _get_default_chunk_encoding_v2, - _get_default_chunk_encoding_v3, _parse_chunk_encoding_v2, _parse_chunk_encoding_v3, chunks_initialized, @@ -43,16 +41,29 @@ from zarr.core.chunk_grids import _auto_partition from zarr.core.chunk_key_encodings import ChunkKeyEncodingParams from zarr.core.common import JSON, MemoryOrder, ZarrFormat +from zarr.core.dtype import get_data_type_from_native_dtype +from zarr.core.dtype.common import ENDIANNESS_STR, EndiannessStr +from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str +from zarr.core.dtype.npy.float import Float32, Float64 +from zarr.core.dtype.npy.int import Int16, UInt8 +from zarr.core.dtype.npy.string import VariableLengthUTF8 +from zarr.core.dtype.npy.structured import ( + Structured, +) +from zarr.core.dtype.npy.time import DateTime64, TimeDelta64 +from zarr.core.dtype.wrapper import ZDType from zarr.core.group import AsyncGroup from zarr.core.indexing import BasicIndexer, ceildiv -from zarr.core.metadata.v3 import ArrayV3Metadata, DataType +from zarr.core.metadata.v2 import ArrayV2Metadata from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.storage import LocalStore, MemoryStore, StorePath +from .test_dtype.conftest import zdtype_examples + if TYPE_CHECKING: from zarr.core.array_spec import ArrayConfigLike -from zarr.core.metadata.v2 import ArrayV2Metadata + from zarr.core.metadata.v3 import ArrayV3Metadata @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) @@ -152,7 +163,7 @@ def test_array_name_properties_no_group( store: LocalStore | MemoryStore, zarr_format: ZarrFormat ) -> None: arr = zarr.create_array( - store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4" + store=store, shape=(100,), 
chunks=(10,), zarr_format=zarr_format, dtype=">i4" ) assert arr.path == "" assert arr.name == "/" @@ -178,34 +189,45 @@ def test_array_name_properties_with_group( assert spam.basename == "spam" +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("specifiy_fill_value", [True, False]) -@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "complex64"]) -def test_array_v3_fill_value_default( - store: MemoryStore, specifiy_fill_value: bool, dtype_str: str +@pytest.mark.parametrize( + "zdtype", zdtype_examples, ids=tuple(str(type(v)) for v in zdtype_examples) +) +def test_array_fill_value_default( + store: MemoryStore, specifiy_fill_value: bool, zdtype: ZDType[Any, Any] ) -> None: """ Test that creating an array with the fill_value parameter set to None, or unspecified, - results in the expected fill_value attribute of the array, i.e. 0 cast to the array's dtype. + results in the expected fill_value attribute of the array, i.e. the default value of the dtype """ shape = (10,) - default_fill_value = 0 if specifiy_fill_value: arr = zarr.create_array( store=store, shape=shape, - dtype=dtype_str, + dtype=zdtype, zarr_format=3, chunks=shape, fill_value=None, ) else: - arr = zarr.create_array( - store=store, shape=shape, dtype=dtype_str, zarr_format=3, chunks=shape - ) + arr = zarr.create_array(store=store, shape=shape, dtype=zdtype, zarr_format=3, chunks=shape) + expected_fill_value = zdtype.default_scalar() + if isinstance(expected_fill_value, np.datetime64 | np.timedelta64): + if np.isnat(expected_fill_value): + assert np.isnat(arr.fill_value) + elif isinstance(expected_fill_value, np.floating | np.complexfloating): + if np.isnan(expected_fill_value): + assert np.isnan(arr.fill_value) + else: + assert arr.fill_value == expected_fill_value + # A simpler check would be to ensure that arr.fill_value.dtype == arr.dtype + # But for some numpy data types (namely, U), scalars might not have length. An empty string + # scalar from a `>U4` array would have dtype `>U`, and arr.fill_value.dtype == arr.dtype will fail. - assert arr.fill_value == np.dtype(dtype_str).type(default_fill_value) - assert arr.fill_value.dtype == arr.dtype + assert type(arr.fill_value) is type(np.array([arr.fill_value], dtype=arr.dtype)[0]) @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -348,7 +370,7 @@ def test_storage_transformers(store: MemoryStore, zarr_format: ZarrFormat | str) "zarr_format": zarr_format, "shape": (10,), "chunks": (1,), - "dtype": "uint8", + "dtype": "|u1", "dimension_separator": ".", "codecs": (BytesCodec().to_dict(),), "fill_value": 0, @@ -458,48 +480,6 @@ async def test_nbytes_stored_async() -> None: assert result == 902 # the size with all chunks filled. 
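# --- editorial sketch, not lines from this patch ---
# The ZDType scalar API that the rewritten fill-value tests below exercise,
# shown with the concrete Float64 wrapper imported elsewhere in this file.
# The v3 spec name "float64" and the equality of the JSON round-trip are
# assumptions based on the zarr v3 data type spec this change targets.

import numpy as np

from zarr.core.dtype.npy.float import Float64

zdtype = Float64()
assert isinstance(zdtype.default_scalar(), np.float64)  # dtype-aware default fill value
assert zdtype.cast_scalar(1) == np.float64(1.0)  # Python objects are cast to the scalar type
assert zdtype.to_json(zarr_format=3) == "float64"  # v3 data type name
assert Float64.from_json("float64", zarr_format=3) == zdtype  # round-trip, default endianness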
-def test_default_fill_values() -> None: - a = zarr.Array.create(MemoryStore(), shape=5, chunk_shape=5, dtype=" None: - with pytest.raises(ValueError, match="At least one ArrayBytesCodec is required."): - Array.create(MemoryStore(), shape=5, chunks=5, dtype=" None: # regression test for https://github.com/zarr-developers/zarr-python/issues/2328 @@ -521,7 +501,7 @@ def test_info_v2(self, chunks: tuple[int, int], shards: tuple[int, int] | None) result = arr.info expected = ArrayInfo( _zarr_format=2, - _data_type=np.dtype("float64"), + _data_type=arr._async_array._zdtype, _fill_value=arr.fill_value, _shape=(8, 8), _chunk_shape=chunks, @@ -539,7 +519,7 @@ def test_info_v3(self, chunks: tuple[int, int], shards: tuple[int, int] | None) result = arr.info expected = ArrayInfo( _zarr_format=3, - _data_type=DataType.parse("float64"), + _data_type=arr._async_array._zdtype, _fill_value=arr.fill_value, _shape=(8, 8), _chunk_shape=chunks, @@ -565,7 +545,7 @@ def test_info_complete(self, chunks: tuple[int, int], shards: tuple[int, int] | result = arr.info_complete() expected = ArrayInfo( _zarr_format=3, - _data_type=DataType.parse("float64"), + _data_type=arr._async_array._zdtype, _fill_value=arr.fill_value, _shape=(8, 8), _chunk_shape=chunks, @@ -601,7 +581,7 @@ async def test_info_v2_async( result = arr.info expected = ArrayInfo( _zarr_format=2, - _data_type=np.dtype("float64"), + _data_type=Float64(), _fill_value=arr.metadata.fill_value, _shape=(8, 8), _chunk_shape=(2, 2), @@ -627,7 +607,7 @@ async def test_info_v3_async( result = arr.info expected = ArrayInfo( _zarr_format=3, - _data_type=DataType.parse("float64"), + _data_type=arr._zdtype, _fill_value=arr.metadata.fill_value, _shape=(8, 8), _chunk_shape=chunks, @@ -655,7 +635,7 @@ async def test_info_complete_async( result = await arr.info_complete() expected = ArrayInfo( _zarr_format=3, - _data_type=DataType.parse("float64"), + _data_type=arr._zdtype, _fill_value=arr.metadata.fill_value, _shape=(8, 8), _chunk_shape=chunks, @@ -982,7 +962,10 @@ def test_auto_partition_auto_shards( expected_shards += (cs,) auto_shards, _ = _auto_partition( - array_shape=array_shape, chunk_shape=chunk_shape, shard_shape="auto", dtype=dtype + array_shape=array_shape, + chunk_shape=chunk_shape, + shard_shape="auto", + item_size=dtype.itemsize, ) assert auto_shards == expected_shards @@ -1017,53 +1000,81 @@ def test_chunks_and_shards(store: Store) -> None: assert arr_v2.shards is None @staticmethod - @pytest.mark.parametrize( - ("dtype", "fill_value_expected"), [(" None: + @pytest.mark.parametrize("dtype", zdtype_examples) + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") + def test_default_fill_value(dtype: ZDType[Any, Any], store: Store) -> None: + """ + Test that the fill value of an array is set to the default value for the dtype object + """ a = zarr.create_array(store, shape=(5,), chunks=(5,), dtype=dtype) - assert a.fill_value == fill_value_expected + if isinstance(dtype, DateTime64 | TimeDelta64) and np.isnat(a.fill_value): + assert np.isnat(dtype.default_scalar()) + else: + assert a.fill_value == dtype.default_scalar() @staticmethod - @pytest.mark.parametrize("dtype", ["uint8", "float32", "str"]) - @pytest.mark.parametrize("empty_value", [None, ()]) - async def test_no_filters_compressors( - store: MemoryStore, dtype: str, empty_value: object, zarr_format: ZarrFormat - ) -> None: + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") + @pytest.mark.parametrize("dtype", 
zdtype_examples) + def test_dtype_forms(dtype: ZDType[Any, Any], store: Store, zarr_format: ZarrFormat) -> None: """ - Test that the default ``filters`` and ``compressors`` are removed when ``create_array`` is invoked. + Test that the same array is produced from a ZDType instance, a numpy dtype, or a numpy string """ + skip_object_dtype(dtype) + a = zarr.create_array( + store, name="a", shape=(5,), chunks=(5,), dtype=dtype, zarr_format=zarr_format + ) - arr = await create_array( - store=store, - dtype=dtype, - shape=(10,), + b = zarr.create_array( + store, + name="b", + shape=(5,), + chunks=(5,), + dtype=dtype.to_native_dtype(), zarr_format=zarr_format, - compressors=empty_value, - filters=empty_value, ) - # Test metadata explicitly - if zarr_format == 2: - assert arr.metadata.zarr_format == 2 # guard for mypy - # v2 spec requires that filters be either a collection with at least one filter, or None - assert arr.metadata.filters is None - # Compressor is a single element in v2 metadata; the absence of a compressor is encoded - # as None - assert arr.metadata.compressor is None - - assert arr.filters == () - assert arr.compressors == () - else: - assert arr.metadata.zarr_format == 3 # guard for mypy - if dtype == "str": - assert arr.metadata.codecs == (VLenUTF8Codec(),) - assert arr.serializer == VLenUTF8Codec() + assert a.dtype == b.dtype + + # Structured dtypes do not have a numpy string representation that uniquely identifies them + if not isinstance(dtype, Structured): + if isinstance(dtype, VariableLengthUTF8): + # in numpy 2.3, StringDType().str becomes the string 'StringDType()' which numpy + # does not accept as a string representation of the dtype. + c = zarr.create_array( + store, + name="c", + shape=(5,), + chunks=(5,), + dtype=dtype.to_native_dtype().char, + zarr_format=zarr_format, + ) else: - assert arr.metadata.codecs == (BytesCodec(),) - assert arr.serializer == BytesCodec() + c = zarr.create_array( + store, + name="c", + shape=(5,), + chunks=(5,), + dtype=dtype.to_native_dtype().str, + zarr_format=zarr_format, + ) + assert a.dtype == c.dtype + + @staticmethod + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") + @pytest.mark.parametrize("dtype", zdtype_examples) + def test_dtype_roundtrip( + dtype: ZDType[Any, Any], store: Store, zarr_format: ZarrFormat + ) -> None: + """ + Test that creating an array, then opening it, gets the same array. 
+ """ + skip_object_dtype(dtype) + a = zarr.create_array(store, shape=(5,), chunks=(5,), dtype=dtype, zarr_format=zarr_format) + b = zarr.open_array(store) + assert a.dtype == b.dtype @staticmethod - @pytest.mark.parametrize("dtype", ["uint8", "float32", "str"]) + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") + @pytest.mark.parametrize("dtype", ["uint8", "float32", "U3", "S4", "V1"]) @pytest.mark.parametrize( "compressors", [ @@ -1134,7 +1145,10 @@ async def test_v3_chunk_encoding( compressors=compressors, ) filters_expected, _, compressors_expected = _parse_chunk_encoding_v3( - filters=filters, compressors=compressors, serializer="auto", dtype=np.dtype(dtype) + filters=filters, + compressors=compressors, + serializer="auto", + dtype=arr._zdtype, ) assert arr.filters == filters_expected assert arr.compressors == compressors_expected @@ -1271,7 +1285,7 @@ async def test_v2_chunk_encoding( filters=filters, ) filters_expected, compressor_expected = _parse_chunk_encoding_v2( - filters=filters, compressor=compressors, dtype=np.dtype(dtype) + filters=filters, compressor=compressors, dtype=get_data_type_from_native_dtype(dtype) ) assert arr.metadata.zarr_format == 2 # guard for mypy assert arr.metadata.compressor == compressor_expected @@ -1285,27 +1299,37 @@ async def test_v2_chunk_encoding( assert arr.filters == filters_expected @staticmethod - @pytest.mark.parametrize("dtype", ["uint8", "float32", "str"]) + @pytest.mark.parametrize("dtype", [UInt8(), Float32(), VariableLengthUTF8()]) + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") async def test_default_filters_compressors( - store: MemoryStore, dtype: str, zarr_format: ZarrFormat + store: MemoryStore, dtype: UInt8 | Float32 | VariableLengthUTF8, zarr_format: ZarrFormat ) -> None: """ Test that the default ``filters`` and ``compressors`` are used when ``create_array`` is invoked with ``filters`` and ``compressors`` unspecified. 
""" + arr = await create_array( store=store, - dtype=dtype, + dtype=dtype, # type: ignore[arg-type] shape=(10,), zarr_format=zarr_format, ) + + sig = inspect.signature(create_array) + if zarr_format == 3: - expected_filters, expected_serializer, expected_compressors = ( - _get_default_chunk_encoding_v3(np_dtype=np.dtype(dtype)) + expected_filters, expected_serializer, expected_compressors = _parse_chunk_encoding_v3( + compressors=sig.parameters["compressors"].default, + filters=sig.parameters["filters"].default, + serializer=sig.parameters["serializer"].default, + dtype=dtype, # type: ignore[arg-type] ) elif zarr_format == 2: - default_filters, default_compressors = _get_default_chunk_encoding_v2( - np_dtype=np.dtype(dtype) + default_filters, default_compressors = _parse_chunk_encoding_v2( + compressor=sig.parameters["compressors"].default, + filters=sig.parameters["filters"].default, + dtype=dtype, # type: ignore[arg-type] ) if default_filters is None: expected_filters = () @@ -1482,9 +1506,24 @@ async def test_name(store: Store, zarr_format: ZarrFormat, path: str | None) -> store=store, path=parent_path, zarr_format=zarr_format ) + @staticmethod + @pytest.mark.parametrize("endianness", ENDIANNESS_STR) + def test_default_endianness( + store: Store, zarr_format: ZarrFormat, endianness: EndiannessStr + ) -> None: + """ + Test that that endianness is correctly set when creating an array when not specifying a serializer + """ + dtype = Int16(endianness=endianness) + arr = zarr.create_array(store=store, shape=(1,), dtype=dtype, zarr_format=zarr_format) + byte_order: str = arr[:].dtype.byteorder # type: ignore[union-attr] + assert byte_order in NUMPY_ENDIANNESS_STR + assert endianness_from_numpy_str(byte_order) == endianness # type: ignore[arg-type] + @pytest.mark.parametrize("value", [1, 1.4, "a", b"a", np.array(1)]) @pytest.mark.parametrize("zarr_format", [2, 3]) +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") def test_scalar_array(value: Any, zarr_format: ZarrFormat) -> None: arr = zarr.array(value, zarr_format=zarr_format) assert arr[...] 
== value diff --git a/tests/test_codecs/test_endian.py b/tests/test_codecs/test_endian.py index c0c4dd4e75..ab64afb1b8 100644 --- a/tests/test_codecs/test_endian.py +++ b/tests/test_codecs/test_endian.py @@ -11,6 +11,7 @@ from .test_codecs import _AsyncArrayProxy +@pytest.mark.filterwarnings("ignore:The endianness of the requested serializer") @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) @pytest.mark.parametrize("endian", ["big", "little"]) async def test_endian(store: Store, endian: Literal["big", "little"]) -> None: @@ -32,6 +33,7 @@ async def test_endian(store: Store, endian: Literal["big", "little"]) -> None: assert np.array_equal(data, readback_data) +@pytest.mark.filterwarnings("ignore:The endianness of the requested serializer") @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) @pytest.mark.parametrize("dtype_input_endian", [">u2", " None: - bstrings = [b"hello", b"world", b"this", b"is", b"a", b"test"] - data = np.array(bstrings).reshape((2, 3)) - assert data.dtype == "|S5" - - sp = StorePath(store, path="string") - a = zarr.create_array( - sp, - shape=data.shape, - chunks=data.shape, - dtype=data.dtype, - fill_value=b"", - compressors=compressor, - ) - assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy - - # should also work if input array is an object array, provided we explicitly specified - # a bytesting-like dtype when creating the Array - if as_object_array: - data = data.astype("O") - a[:, :] = data - assert np.array_equal(data, a[:, :]) - assert a.metadata.data_type == DataType.bytes - assert a.dtype == "O" - - # test round trip - b = Array.open(sp) - assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy - assert np.array_equal(data, b[:, :]) - assert b.metadata.data_type == DataType.bytes - assert a.dtype == "O" + assert b.metadata.data_type == get_data_type_from_native_dtype(data.dtype) + assert a.dtype == data.dtype diff --git a/tests/test_config.py b/tests/test_config.py index 2cbf172752..1dc6f8bf4f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,6 @@ import os from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING, Any from unittest import mock from unittest.mock import Mock @@ -19,11 +19,13 @@ GzipCodec, ShardingCodec, ) +from zarr.core.array import create_array from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer from zarr.core.buffer.core import Buffer from zarr.core.codec_pipeline import BatchedCodecPipeline from zarr.core.config import BadConfigError, config +from zarr.core.dtype import Int8, VariableLengthUTF8 from zarr.core.indexing import SelectorTuple from zarr.registry import ( fully_qualified_name, @@ -44,67 +46,66 @@ TestNDArrayLike, ) +if TYPE_CHECKING: + from zarr.core.dtype.wrapper import ZDType + def test_config_defaults_set() -> None: # regression test for available defaults - assert config.defaults == [ - { - "default_zarr_format": 3, - "array": { - "order": "C", - "write_empty_chunks": False, - "v2_default_compressor": { - "numeric": {"id": "zstd", "level": 0, "checksum": False}, - "string": {"id": "zstd", "level": 0, "checksum": False}, - "bytes": {"id": "zstd", "level": 0, "checksum": False}, - }, - "v2_default_filters": { - "numeric": None, - "string": [{"id": "vlen-utf8"}], - "bytes": [{"id": "vlen-bytes"}], - "raw": None, + assert ( + config.defaults + == [ + { + "default_zarr_format": 3, + "array": { + "order": "C", + "write_empty_chunks": False, + "v2_default_compressor": { 
+ "default": {"id": "zstd", "level": 0, "checksum": False}, + "variable-length-string": {"id": "zstd", "level": 0, "checksum": False}, + }, + "v2_default_filters": { + "default": None, + "variable-length-string": [{"id": "vlen-utf8"}], + }, + "v3_default_filters": {"default": [], "variable-length-string": []}, + "v3_default_serializer": { + "default": {"name": "bytes", "configuration": {"endian": "little"}}, + "variable-length-string": {"name": "vlen-utf8"}, + }, + "v3_default_compressors": { + "default": [ + {"name": "zstd", "configuration": {"level": 0, "checksum": False}}, + ], + "variable-length-string": [ + {"name": "zstd", "configuration": {"level": 0, "checksum": False}} + ], + }, }, - "v3_default_filters": {"numeric": [], "string": [], "bytes": []}, - "v3_default_serializer": { - "numeric": {"name": "bytes", "configuration": {"endian": "little"}}, - "string": {"name": "vlen-utf8"}, - "bytes": {"name": "vlen-bytes"}, + "async": {"concurrency": 10, "timeout": None}, + "threading": {"max_workers": None}, + "json_indent": 2, + "codec_pipeline": { + "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", + "batch_size": 1, }, - "v3_default_compressors": { - "numeric": [ - {"name": "zstd", "configuration": {"level": 0, "checksum": False}}, - ], - "string": [ - {"name": "zstd", "configuration": {"level": 0, "checksum": False}}, - ], - "bytes": [ - {"name": "zstd", "configuration": {"level": 0, "checksum": False}}, - ], + "codecs": { + "blosc": "zarr.codecs.blosc.BloscCodec", + "gzip": "zarr.codecs.gzip.GzipCodec", + "zstd": "zarr.codecs.zstd.ZstdCodec", + "bytes": "zarr.codecs.bytes.BytesCodec", + "endian": "zarr.codecs.bytes.BytesCodec", # compatibility with earlier versions of ZEP1 + "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", + "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", + "transpose": "zarr.codecs.transpose.TransposeCodec", + "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", + "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec", }, - }, - "async": {"concurrency": 10, "timeout": None}, - "threading": {"max_workers": None}, - "json_indent": 2, - "codec_pipeline": { - "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", - "batch_size": 1, - }, - "buffer": "zarr.core.buffer.cpu.Buffer", - "ndbuffer": "zarr.core.buffer.cpu.NDBuffer", - "codecs": { - "blosc": "zarr.codecs.blosc.BloscCodec", - "gzip": "zarr.codecs.gzip.GzipCodec", - "zstd": "zarr.codecs.zstd.ZstdCodec", - "bytes": "zarr.codecs.bytes.BytesCodec", - "endian": "zarr.codecs.bytes.BytesCodec", - "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", - "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", - "transpose": "zarr.codecs.transpose.TransposeCodec", - "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", - "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec", - }, - } - ] + "buffer": "zarr.core.buffer.cpu.Buffer", + "ndbuffer": "zarr.core.buffer.cpu.NDBuffer", + } + ] + ) assert config.get("array.order") == "C" assert config.get("async.concurrency") == 10 assert config.get("async.timeout") is None @@ -304,28 +305,29 @@ class NewCodec2(BytesCodec): get_codec_class("new_codec") -@pytest.mark.parametrize("dtype", ["int", "bytes", "str"]) -async def test_default_codecs(dtype: str) -> None: - with config.set( - { - "array.v3_default_compressors": { # test setting non-standard codecs - "numeric": [ - {"name": "gzip", "configuration": {"level": 5}}, - ], - "string": [ - {"name": "gzip", "configuration": {"level": 5}}, - ], - "bytes": [ - {"name": "gzip", "configuration": {"level": 5}}, - ], - } - } - ): - arr = 
await zarr.api.asynchronous.create_array(
+@pytest.mark.parametrize("dtype_category", ["variable-length-string", "default"])
+@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
+async def test_default_codecs(dtype_category: str) -> None:
+    """
+    Test that the default compressors are sensitive to the current setting of the config.
+    """
+    zdtype: ZDType[Any, Any]
+    if dtype_category == "variable-length-string":
+        zdtype = VariableLengthUTF8()
+    else:
+        zdtype = Int8()
+    expected_compressors = (GzipCodec(),)
+    new_conf = {
+        f"array.v3_default_compressors.{dtype_category}": [
+            c.to_dict() for c in expected_compressors
+        ]
+    }
+    with config.set(new_conf):
+        arr = await create_array(
            shape=(100,),
            chunks=(100,),
-            dtype=np.dtype(dtype),
+            dtype=zdtype,
            zarr_format=3,
            store=MemoryStore(),
        )
-    assert arr.compressors == (GzipCodec(),)
+    assert arr.compressors == expected_compressors
diff --git a/tests/test_dtype/__init__.py b/tests/test_dtype/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/test_dtype/conftest.py b/tests/test_dtype/conftest.py
new file mode 100644
index 0000000000..0be1c60088
--- /dev/null
+++ b/tests/test_dtype/conftest.py
@@ -0,0 +1,71 @@
+# Generate a collection of zdtype instances for use in testing.
+import warnings
+from typing import Any
+
+import numpy as np
+
+from zarr.core.dtype import data_type_registry
+from zarr.core.dtype.common import HasLength
+from zarr.core.dtype.npy.structured import Structured
+from zarr.core.dtype.npy.time import DateTime64, TimeDelta64
+from zarr.core.dtype.wrapper import ZDType
+
+zdtype_examples: tuple[ZDType[Any, Any], ...] = ()
+for wrapper_cls in data_type_registry.contents.values():
+    # The Structured dtype has to be constructed with some actual fields
+    if wrapper_cls is Structured:
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            zdtype_examples += (
+                wrapper_cls.from_native_dtype(np.dtype([("a", np.float64), ("b", np.int8)])),
+            )
+    elif issubclass(wrapper_cls, HasLength):
+        zdtype_examples += (wrapper_cls(length=1),)
+    elif issubclass(wrapper_cls, DateTime64 | TimeDelta64):
+        zdtype_examples += (wrapper_cls(unit="s", scale_factor=10),)
+    else:
+        zdtype_examples += (wrapper_cls(),)
+
+
+def pytest_generate_tests(metafunc: Any) -> None:
+    """
+    This is a pytest hook to parametrize class-scoped fixtures.
+
+    This hook allows us to define class-scoped fixtures as class attributes and then
+    generate the parametrize calls for pytest. This allows the fixtures to be
+    reused across multiple tests within the same class.
+
+    For example, if you had a regular pytest class like this:
+
+    class TestClass:
+        @pytest.mark.parametrize("param_a", [1, 2, 3])
+        def test_method(self, param_a):
+            ...
+
+    Child classes inheriting from ``TestClass`` would not be able to override the ``param_a`` fixture.
+
+    This implementation of ``pytest_generate_tests`` allows you to define class-scoped fixtures as
+    class attributes, which allows the following to work:
+
+    class TestExample:
+        param_a = [1, 2, 3]
+
+        def test_example(self, param_a):
+            ...
+ + # this class will have its test_example method parametrized with the values of TestB.param_a + class TestB(TestExample): + param_a = [1, 2, 100, 10] + + """ + # Iterate over all the fixtures defined in the class + # and parametrize them with the values defined in the class + # This allows us to define class-scoped fixtures as class attributes + # and then generate the parametrize calls for pytest + for fixture_name in metafunc.fixturenames: + if hasattr(metafunc.cls, fixture_name): + params = getattr(metafunc.cls, fixture_name) + if len(params) == 0: + msg = f"{metafunc.cls}.{fixture_name} is empty. Please provide a non-empty sequence of values." + raise ValueError(msg) + metafunc.parametrize(fixture_name, params, scope="class") diff --git a/tests/test_dtype/test_npy/test_bool.py b/tests/test_dtype/test_npy/test_bool.py new file mode 100644 index 0000000000..010dec2e47 --- /dev/null +++ b/tests/test_dtype/test_npy/test_bool.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import numpy as np + +from tests.test_dtype.test_wrapper import BaseTestZDType +from zarr.core.dtype.npy.bool import Bool + + +class TestBool(BaseTestZDType): + test_cls = Bool + + valid_dtype = (np.dtype(np.bool_),) + invalid_dtype = ( + np.dtype(np.int8), + np.dtype(np.float64), + np.dtype(np.uint16), + ) + valid_json_v2 = ({"name": "|b1", "object_codec_id": None},) + valid_json_v3 = ("bool",) + invalid_json_v2 = ( + "|b1", + "bool", + "|f8", + ) + invalid_json_v3 = ( + "|b1", + "|f8", + {"name": "bool", "configuration": {"endianness": "little"}}, + ) + + scalar_v2_params = ((Bool(), True), (Bool(), False)) + scalar_v3_params = ((Bool(), True), (Bool(), False)) + + cast_value_params = ( + (Bool(), "true", np.True_), + (Bool(), True, np.True_), + (Bool(), False, np.False_), + (Bool(), np.True_, np.True_), + (Bool(), np.False_, np.False_), + ) + item_size_params = (Bool(),) diff --git a/tests/test_dtype/test_npy/test_bytes.py b/tests/test_dtype/test_npy/test_bytes.py new file mode 100644 index 0000000000..b7c16f573e --- /dev/null +++ b/tests/test_dtype/test_npy/test_bytes.py @@ -0,0 +1,154 @@ +import numpy as np +import pytest + +from tests.test_dtype.test_wrapper import BaseTestZDType +from zarr.core.dtype.common import UnstableSpecificationWarning +from zarr.core.dtype.npy.bytes import NullTerminatedBytes, RawBytes, VariableLengthBytes + + +class TestNullTerminatedBytes(BaseTestZDType): + test_cls = NullTerminatedBytes + valid_dtype = (np.dtype("|S10"), np.dtype("|S4")) + invalid_dtype = ( + np.dtype(np.int8), + np.dtype(np.float64), + np.dtype("|U10"), + ) + valid_json_v2 = ( + {"name": "|S0", "object_codec_id": None}, + {"name": "|S2", "object_codec_id": None}, + {"name": "|S4", "object_codec_id": None}, + ) + valid_json_v3 = ({"name": "null_terminated_bytes", "configuration": {"length_bytes": 10}},) + invalid_json_v2 = ( + "|S", + "|U10", + "|f8", + ) + invalid_json_v3 = ( + {"name": "fixed_length_ascii", "configuration": {"length_bits": 0}}, + {"name": "numpy.fixed_length_ascii", "configuration": {"length_bits": "invalid"}}, + ) + + scalar_v2_params = ( + (NullTerminatedBytes(length=0), ""), + (NullTerminatedBytes(length=2), "YWI="), + (NullTerminatedBytes(length=4), "YWJjZA=="), + ) + scalar_v3_params = ( + (NullTerminatedBytes(length=0), ""), + (NullTerminatedBytes(length=2), "YWI="), + (NullTerminatedBytes(length=4), "YWJjZA=="), + ) + cast_value_params = ( + (NullTerminatedBytes(length=0), "", np.bytes_("")), + (NullTerminatedBytes(length=2), "ab", np.bytes_("ab")), + 
(NullTerminatedBytes(length=4), "abcdefg", np.bytes_("abcd")), + ) + item_size_params = ( + NullTerminatedBytes(length=0), + NullTerminatedBytes(length=4), + NullTerminatedBytes(length=10), + ) + + +class TestRawBytes(BaseTestZDType): + test_cls = RawBytes + valid_dtype = (np.dtype("|V10"),) + invalid_dtype = ( + np.dtype(np.int8), + np.dtype(np.float64), + np.dtype("|S10"), + ) + valid_json_v2 = ({"name": "|V10", "object_codec_id": None},) + valid_json_v3 = ( + {"name": "raw_bytes", "configuration": {"length_bytes": 0}}, + {"name": "raw_bytes", "configuration": {"length_bytes": 8}}, + ) + + invalid_json_v2 = ( + "|V", + "|S10", + "|f8", + ) + invalid_json_v3 = ( + {"name": "r10"}, + {"name": "r-80"}, + ) + + scalar_v2_params = ( + (RawBytes(length=0), ""), + (RawBytes(length=2), "YWI="), + (RawBytes(length=4), "YWJjZA=="), + ) + scalar_v3_params = ( + (RawBytes(length=0), ""), + (RawBytes(length=2), "YWI="), + (RawBytes(length=4), "YWJjZA=="), + ) + cast_value_params = ( + (RawBytes(length=0), b"", np.void(b"")), + (RawBytes(length=2), b"ab", np.void(b"ab")), + (RawBytes(length=4), b"abcd", np.void(b"abcd")), + ) + item_size_params = ( + RawBytes(length=0), + RawBytes(length=4), + RawBytes(length=10), + ) + + +class TestVariableLengthBytes(BaseTestZDType): + test_cls = VariableLengthBytes + valid_dtype = (np.dtype("|O"),) + invalid_dtype = ( + np.dtype(np.int8), + np.dtype(np.float64), + np.dtype("|U10"), + ) + valid_json_v2 = ({"name": "|O", "object_codec_id": "vlen-bytes"},) + valid_json_v3 = ("variable_length_bytes",) + invalid_json_v2 = ( + "|S", + "|U10", + "|f8", + ) + invalid_json_v3 = ( + {"name": "fixed_length_ascii", "configuration": {"length_bits": 0}}, + {"name": "numpy.fixed_length_ascii", "configuration": {"length_bits": "invalid"}}, + ) + + scalar_v2_params = ( + (VariableLengthBytes(), ""), + (VariableLengthBytes(), "YWI="), + (VariableLengthBytes(), "YWJjZA=="), + ) + scalar_v3_params = ( + (VariableLengthBytes(), ""), + (VariableLengthBytes(), "YWI="), + (VariableLengthBytes(), "YWJjZA=="), + ) + cast_value_params = ( + (VariableLengthBytes(), "", b""), + (VariableLengthBytes(), "ab", b"ab"), + (VariableLengthBytes(), "abcdefg", b"abcdefg"), + ) + item_size_params = ( + VariableLengthBytes(), + VariableLengthBytes(), + VariableLengthBytes(), + ) + + +@pytest.mark.parametrize( + "zdtype", [NullTerminatedBytes(length=10), RawBytes(length=10), VariableLengthBytes()] +) +def test_unstable_dtype_warning( + zdtype: NullTerminatedBytes | RawBytes | VariableLengthBytes, +) -> None: + """ + Test that we get a warning when serializing a dtype without a zarr v3 spec to json + when zarr_format is 3 + """ + with pytest.raises(UnstableSpecificationWarning): + zdtype.to_json(zarr_format=3) diff --git a/tests/test_dtype/test_npy/test_common.py b/tests/test_dtype/test_npy/test_common.py new file mode 100644 index 0000000000..d39d308112 --- /dev/null +++ b/tests/test_dtype/test_npy/test_common.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import base64 +import math +import re +import sys +from typing import TYPE_CHECKING, Any, get_args + +import numpy as np +import pytest + +from zarr.core.dtype.common import ENDIANNESS_STR, JSONFloatV2, SpecialFloatStrings +from zarr.core.dtype.npy.common import ( + NumpyEndiannessStr, + bytes_from_json, + bytes_to_json, + check_json_bool, + check_json_complex_float_v2, + check_json_complex_float_v3, + check_json_float_v2, + check_json_float_v3, + check_json_int, + check_json_str, + complex_float_to_json_v2, + complex_float_to_json_v3, + 
endianness_from_numpy_str, + endianness_to_numpy_str, + float_from_json_v2, + float_from_json_v3, + float_to_json_v2, + float_to_json_v3, +) + +if TYPE_CHECKING: + from zarr.core.common import JSON, ZarrFormat + + +def nan_equal(a: object, b: object) -> bool: + """ + Convenience function for equality comparison between two values ``a`` and ``b``, that might both + be NaN. Returns True if both ``a`` and ``b`` are NaN, otherwise returns a == b + """ + if math.isnan(a) and math.isnan(b): # type: ignore[arg-type] + return True + return a == b + + +json_float_v2_roundtrip_cases: tuple[tuple[JSONFloatV2, float | np.floating[Any]], ...] = ( + ("Infinity", float("inf")), + ("Infinity", np.inf), + ("-Infinity", float("-inf")), + ("-Infinity", -np.inf), + ("NaN", float("nan")), + ("NaN", np.nan), + (1.0, 1.0), +) + +json_float_v3_cases = json_float_v2_roundtrip_cases + + +@pytest.mark.parametrize( + ("data", "expected"), + [(">", "big"), ("<", "little"), ("=", sys.byteorder), ("|", None), ("err", "")], +) +def test_endianness_from_numpy_str(data: str, expected: str | None) -> None: + """ + Test that endianness_from_numpy_str correctly converts a numpy str literal to a human-readable literal value. + This test also checks that an invalid string input raises a ``ValueError`` + """ + if data in get_args(NumpyEndiannessStr): + assert endianness_from_numpy_str(data) == expected # type: ignore[arg-type] + else: + msg = f"Invalid endianness: {data!r}. Expected one of {get_args(NumpyEndiannessStr)}" + with pytest.raises(ValueError, match=re.escape(msg)): + endianness_from_numpy_str(data) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("data", "expected"), + [("big", ">"), ("little", "<"), (None, "|"), ("err", "")], +) +def test_endianness_to_numpy_str(data: str | None, expected: str) -> None: + """ + Test that endianness_to_numpy_str correctly converts a human-readable literal value to a numpy str literal. + This test also checks that an invalid string input raises a ``ValueError`` + """ + if data in ENDIANNESS_STR: + assert endianness_to_numpy_str(data) == expected # type: ignore[arg-type] + else: + msg = f"Invalid endianness: {data!r}. Expected one of {ENDIANNESS_STR}" + with pytest.raises(ValueError, match=re.escape(msg)): + endianness_to_numpy_str(data) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("data", "expected"), json_float_v2_roundtrip_cases + (("SHOULD_ERR", ""),) +) +def test_float_from_json_v2(data: JSONFloatV2 | str, expected: float | str) -> None: + """ + Test that float_from_json_v2 correctly converts a JSON string representation of a float to a float. + This test also checks that an invalid string input raises a ``ValueError`` + """ + if data != "SHOULD_ERR": + assert nan_equal(float_from_json_v2(data), expected) # type: ignore[arg-type] + else: + msg = f"could not convert string to float: {data!r}" + with pytest.raises(ValueError, match=msg): + float_from_json_v2(data) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("data", "expected"), json_float_v3_cases + (("SHOULD_ERR", ""), ("0x", "")) +) +def test_float_from_json_v3(data: JSONFloatV2 | str, expected: float | str) -> None: + """ + Test that float_from_json_v3 correctly converts a JSON string representation of a float to a float. + This test also checks that an invalid string input raises a ``ValueError`` + """ + if data == "SHOULD_ERR": + msg = ( + f"Invalid float value: {data!r}. Expected a string starting with the hex prefix" + " '0x', or one of 'NaN', 'Infinity', or '-Infinity'." 
+        )
+        with pytest.raises(ValueError, match=msg):
+            float_from_json_v3(data)
+    elif data == "0x":
+        msg = (
+            f"Invalid hexadecimal float value: {data!r}. "
+            "Expected the '0x' prefix to be followed by 4, 8, or 16 numeral characters"
+        )
+
+        with pytest.raises(ValueError, match=msg):
+            float_from_json_v3(data)
+    else:
+        assert nan_equal(float_from_json_v3(data), expected)
+
+
+# note the order of parameters relative to the order of the parametrized variable.
+@pytest.mark.parametrize(("expected", "data"), json_float_v2_roundtrip_cases)
+def test_float_to_json_v2(data: float | np.floating[Any], expected: JSONFloatV2) -> None:
+    """
+    Test that floats are JSON-encoded properly for zarr v2
+    """
+    observed = float_to_json_v2(data)
+    assert observed == expected
+
+
+# note the order of parameters relative to the order of the parametrized variable.
+@pytest.mark.parametrize(("expected", "data"), json_float_v3_cases)
+def test_float_to_json_v3(data: float | np.floating[Any], expected: JSONFloatV2) -> None:
+    """
+    Test that floats are JSON-encoded properly for zarr v3
+    """
+    observed = float_to_json_v3(data)
+    assert observed == expected
+
+
+def test_bytes_from_json(zarr_format: ZarrFormat) -> None:
+    """
+    Test that a string is interpreted as base64-encoded bytes using the ascii alphabet.
+    This test takes zarr_format as a parameter but doesn't actually do anything with it, because at
+    present there is no zarr-format-specific logic in the code being tested, but such logic may
+    exist in the future.
+    """
+    data = "\00"
+    assert bytes_from_json(data, zarr_format=zarr_format) == base64.b64decode(data.encode("ascii"))
+
+
+def test_bytes_to_json(zarr_format: ZarrFormat) -> None:
+    """
+    Test that bytes are encoded with base64 using the ascii alphabet.
+
+    This test takes zarr_format as a parameter but doesn't actually do anything with it, because at
+    present there is no zarr-format-specific logic in the code being tested, but such logic may
+    exist in the future.
+    """
+
+    data = b"asdas"
+    assert bytes_to_json(data, zarr_format=zarr_format) == base64.b64encode(data).decode("ascii")
+
+
+# note the order of parameters relative to the order of the parametrized variable.
+@pytest.mark.parametrize(("json_expected", "float_data"), json_float_v2_roundtrip_cases)
+def test_complex_to_json_v2(
+    float_data: float | np.floating[Any], json_expected: JSONFloatV2
+) -> None:
+    """
+    Test that complex numbers are correctly converted to JSON in v2 format.
+
+    This uses the same test input as the float tests, but the conversion is tested
+    for complex numbers with real and imaginary parts equal to the float
+    values provided in the test cases.
+    """
+    cplx = complex(float_data, float_data)
+    cplx_npy = np.complex128(cplx)
+    assert complex_float_to_json_v2(cplx) == (json_expected, json_expected)
+    assert complex_float_to_json_v2(cplx_npy) == (json_expected, json_expected)
+
+
+# note the order of parameters relative to the order of the parametrized variable.
+@pytest.mark.parametrize(("json_expected", "float_data"), json_float_v3_cases)
+def test_complex_to_json_v3(
+    float_data: float | np.floating[Any], json_expected: JSONFloatV2
+) -> None:
+    """
+    Test that complex numbers are correctly converted to JSON in v3 format.
+
+    This uses the same test input as the float tests, but the conversion is tested
+    for complex numbers with real and imaginary parts equal to the float
+    values provided in the test cases.
+    """
+    cplx = complex(float_data, float_data)
+    cplx_npy = np.complex128(cplx)
+    assert complex_float_to_json_v3(cplx) == (json_expected, json_expected)
+    assert complex_float_to_json_v3(cplx_npy) == (json_expected, json_expected)
+
+
+@pytest.mark.parametrize(("json_expected", "float_data"), json_float_v3_cases)
+def test_complex_float_to_json(
+    float_data: float | np.floating[Any], json_expected: JSONFloatV2, zarr_format: ZarrFormat
+) -> None:
+    """
+    Test that complex numbers are correctly converted to JSON in v2 or v3 formats, depending
+    on the ``zarr_format`` keyword argument.
+
+    This uses the same test input as the float tests, but the conversion is tested
+    for complex numbers with real and imaginary parts equal to the float
+    values provided in the test cases.
+    """
+
+    cplx = complex(float_data, float_data)
+    cplx_npy = np.complex128(cplx)
+    if zarr_format == 2:
+        assert complex_float_to_json_v2(cplx) == (json_expected, json_expected)
+        assert complex_float_to_json_v2(cplx_npy) == (
+            json_expected,
+            json_expected,
+        )
+    elif zarr_format == 3:
+        assert complex_float_to_json_v3(cplx) == (json_expected, json_expected)
+        assert complex_float_to_json_v3(cplx_npy) == (
+            json_expected,
+            json_expected,
+        )
+    else:
+        raise ValueError("zarr_format must be 2 or 3")  # pragma: no cover
+
+
+check_json_float_cases = get_args(SpecialFloatStrings) + (1.0, 2)
+
+
+@pytest.mark.parametrize("data", check_json_float_cases)
+def test_check_json_float_v2_valid(data: JSONFloatV2 | int) -> None:
+    assert check_json_float_v2(data)
+
+
+def test_check_json_float_v2_invalid() -> None:
+    assert not check_json_float_v2("invalid")
+
+
+@pytest.mark.parametrize("data", check_json_float_cases)
+def test_check_json_float_v3_valid(data: JSONFloatV2 | int) -> None:
+    assert check_json_float_v3(data)
+
+
+def test_check_json_float_v3_invalid() -> None:
+    assert not check_json_float_v3("invalid")
+
+
+check_json_complex_float_true_cases: tuple[list[JSONFloatV2], ...] = (
+    [0.0, 1.0],
+    [0.0, 1.0],
+    [-1.0, "NaN"],
+    ["Infinity", 1.0],
+    ["Infinity", "NaN"],
+)
+
+
+check_json_complex_float_false_cases: tuple[object, ...]
= ( + 0.0, + "foo", + [0.0], + [1.0, 2.0, 3.0], + [1.0, "_infinity_"], + {"hello": 1.0}, +) + + +@pytest.mark.parametrize("data", check_json_complex_float_true_cases) +def test_check_json_complex_float_v2_true(data: JSON) -> None: + assert check_json_complex_float_v2(data) + + +@pytest.mark.parametrize("data", check_json_complex_float_false_cases) +def test_check_json_complex_float_v2_false(data: JSON) -> None: + assert not check_json_complex_float_v2(data) + + +@pytest.mark.parametrize("data", check_json_complex_float_true_cases) +def test_check_json_complex_float_v3_true(data: JSON) -> None: + assert check_json_complex_float_v3(data) + + +@pytest.mark.parametrize("data", check_json_complex_float_false_cases) +def test_check_json_complex_float_v3_false(data: JSON) -> None: + assert not check_json_complex_float_v3(data) + + +@pytest.mark.parametrize("data", check_json_complex_float_true_cases) +def test_check_json_complex_float_true(data: JSON, zarr_format: ZarrFormat) -> None: + if zarr_format == 2: + assert check_json_complex_float_v2(data) + elif zarr_format == 3: + assert check_json_complex_float_v3(data) + else: + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + +@pytest.mark.parametrize("data", check_json_complex_float_false_cases) +def test_check_json_complex_float_false(data: JSON, zarr_format: ZarrFormat) -> None: + if zarr_format == 2: + assert not check_json_complex_float_v2(data) + elif zarr_format == 3: + assert not check_json_complex_float_v3(data) + else: + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + +def test_check_json_int() -> None: + assert check_json_int(0) + assert not check_json_int(1.0) + + +def test_check_json_str() -> None: + assert check_json_str("0") + assert not check_json_str(1.0) + + +def test_check_json_bool() -> None: + assert check_json_bool(True) + assert check_json_bool(False) + assert not check_json_bool(1.0) + assert not check_json_bool("True") diff --git a/tests/test_dtype/test_npy/test_complex.py b/tests/test_dtype/test_npy/test_complex.py new file mode 100644 index 0000000000..b6a1e799eb --- /dev/null +++ b/tests/test_dtype/test_npy/test_complex.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import math + +import numpy as np + +from tests.test_dtype.test_wrapper import BaseTestZDType +from zarr.core.dtype.npy.complex import Complex64, Complex128 + + +class _BaseTestFloat(BaseTestZDType): + def scalar_equals(self, scalar1: object, scalar2: object) -> bool: + if np.isnan(scalar1) and np.isnan(scalar2): # type: ignore[call-overload] + return True + return super().scalar_equals(scalar1, scalar2) + + +class TestComplex64(_BaseTestFloat): + test_cls = Complex64 + valid_dtype = (np.dtype(">c8"), np.dtype("c8", "object_codec_id": None}, + {"name": "c16"), np.dtype("c16", "object_codec_id": None}, + {"name": " bool: + if np.isnan(scalar1) and np.isnan(scalar2): # type: ignore[call-overload] + return True + return super().scalar_equals(scalar1, scalar2) + + hex_string_params: tuple[tuple[str, float], ...] 
= () + + def test_hex_encoding(self, hex_string_params: tuple[str, float]) -> None: + """ + Test that hexadecimal strings can be read as NaN values + """ + hex_string, expected = hex_string_params + zdtype = self.test_cls() + observed = zdtype.from_json_scalar(hex_string, zarr_format=3) + assert self.scalar_equals(observed, expected) + + +class TestFloat16(_BaseTestFloat): + test_cls = Float16 + valid_dtype = (np.dtype(">f2"), np.dtype("f2", "object_codec_id": None}, + {"name": "f4"), np.dtype("f4", "object_codec_id": None}, + {"name": "f8"), np.dtype("f8", "object_codec_id": None}, + {"name": "i1", + "int8", + "|f8", + ) + invalid_json_v3 = ( + "|i1", + "|f8", + {"name": "int8", "configuration": {"endianness": "little"}}, + ) + + scalar_v2_params = ((Int8(), 1), (Int8(), -1)) + scalar_v3_params = ((Int8(), 1), (Int8(), -1)) + cast_value_params = ( + (Int8(), 1, np.int8(1)), + (Int8(), -1, np.int8(-1)), + ) + item_size_params = (Int8(),) + + +class TestInt16(BaseTestZDType): + test_cls = Int16 + scalar_type = np.int16 + valid_dtype = (np.dtype(">i2"), np.dtype("i2", "object_codec_id": None}, + {"name": "i4"), np.dtype("i4", "object_codec_id": None}, + {"name": "i8"), np.dtype("i8", "object_codec_id": None}, + {"name": "u2"), np.dtype("u2", "object_codec_id": None}, + {"name": "u4"), np.dtype("u4", "object_codec_id": None}, + {"name": "u8"), np.dtype("u8", "object_codec_id": None}, + {"name": "U10"), np.dtype("U10", "object_codec_id": None}, + {"name": " None: + """ + Test that we get a warning when serializing a dtype without a zarr v3 spec to json + when zarr_format is 3 + """ + with pytest.raises(UnstableSpecificationWarning): + zdtype.to_json(zarr_format=3) diff --git a/tests/test_dtype/test_npy/test_structured.py b/tests/test_dtype/test_npy/test_structured.py new file mode 100644 index 0000000000..e9c9ab11d0 --- /dev/null +++ b/tests/test_dtype/test_npy/test_structured.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from tests.test_dtype.test_wrapper import BaseTestZDType +from zarr.core.dtype import ( + Float16, + Float64, + Int32, + Int64, + Structured, +) + + +class TestStructured(BaseTestZDType): + test_cls = Structured + valid_dtype = ( + np.dtype([("field1", np.int32), ("field2", np.float64)]), + np.dtype([("field1", np.int64), ("field2", np.int32)]), + ) + invalid_dtype = ( + np.dtype(np.int8), + np.dtype(np.float64), + np.dtype("|S10"), + ) + valid_json_v2 = ( + {"name": [["field1", ">i4"], ["field2", ">f8"]], "object_codec_id": None}, + {"name": [["field1", ">i8"], ["field2", ">i4"]], "object_codec_id": None}, + ) + valid_json_v3 = ( + { + "name": "structured", + "configuration": { + "fields": [ + ["field1", "int32"], + ["field2", "float64"], + ] + }, + }, + { + "name": "structured", + "configuration": { + "fields": [ + [ + "field1", + { + "name": "numpy.datetime64", + "configuration": {"unit": "s", "scale_factor": 1}, + }, + ], + [ + "field2", + {"name": "fixed_length_utf32", "configuration": {"length_bytes": 32}}, + ], + ] + }, + }, + ) + invalid_json_v2 = ( + [("field1", "|i1"), ("field2", "|f8")], + [("field1", "|S10"), ("field2", "|f8")], + ) + invalid_json_v3 = ( + { + "name": "structured", + "configuration": { + "fields": [ + ("field1", {"name": "int32", "configuration": {"endianness": "invalid"}}), + ("field2", {"name": "float64", "configuration": {"endianness": "big"}}), + ] + }, + }, + {"name": "invalid_name"}, + ) + + scalar_v2_params = ( + (Structured(fields=(("field1", Int32()), ("field2", Float64()))), 
"AQAAAAAAAAAAAPA/"), + (Structured(fields=(("field1", Float16()), ("field2", Int32()))), "AQAAAAAA"), + ) + scalar_v3_params = ( + (Structured(fields=(("field1", Int32()), ("field2", Float64()))), "AQAAAAAAAAAAAPA/"), + (Structured(fields=(("field1", Int64()), ("field2", Int32()))), "AQAAAAAAAAAAAPA/"), + ) + + cast_value_params = ( + ( + Structured(fields=(("field1", Int32()), ("field2", Float64()))), + (1, 2.0), + np.array((1, 2.0), dtype=[("field1", np.int32), ("field2", np.float64)]), + ), + ( + Structured(fields=(("field1", Int64()), ("field2", Int32()))), + (3, 4.5), + np.array((3, 4.5), dtype=[("field1", np.int64), ("field2", np.int32)]), + ), + ) + + def scalar_equals(self, scalar1: Any, scalar2: Any) -> bool: + if hasattr(scalar1, "shape") and hasattr(scalar2, "shape"): + return np.array_equal(scalar1, scalar2) + return super().scalar_equals(scalar1, scalar2) + + item_size_params = ( + Structured(fields=(("field1", Int32()), ("field2", Float64()))), + Structured(fields=(("field1", Int64()), ("field2", Int32()))), + ) diff --git a/tests/test_dtype/test_npy/test_time.py b/tests/test_dtype/test_npy/test_time.py new file mode 100644 index 0000000000..e201be5cf6 --- /dev/null +++ b/tests/test_dtype/test_npy/test_time.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import re +from typing import get_args + +import numpy as np +import pytest + +from tests.test_dtype.test_wrapper import BaseTestZDType +from zarr.core.dtype.npy.common import DateTimeUnit +from zarr.core.dtype.npy.time import DateTime64, TimeDelta64, datetime_from_int + + +class _TestTimeBase(BaseTestZDType): + def json_scalar_equals(self, scalar1: object, scalar2: object) -> bool: + # This method gets overridden here to support the equivalency between NaT and + # -9223372036854775808 fill values + nat_scalars = (-9223372036854775808, "NaT") + if scalar1 in nat_scalars and scalar2 in nat_scalars: + return True + return scalar1 == scalar2 + + def scalar_equals(self, scalar1: object, scalar2: object) -> bool: + if np.isnan(scalar1) and np.isnan(scalar2): # type: ignore[call-overload] + return True + return super().scalar_equals(scalar1, scalar2) + + +class TestDateTime64(_TestTimeBase): + test_cls = DateTime64 + valid_dtype = (np.dtype("datetime64[10ns]"), np.dtype("datetime64[us]"), np.dtype("datetime64")) + invalid_dtype = ( + np.dtype(np.int8), + np.dtype(np.float64), + np.dtype("timedelta64[ns]"), + ) + valid_json_v2 = ( + {"name": ">M8", "object_codec_id": None}, + {"name": ">M8[s]", "object_codec_id": None}, + {"name": "m8", "object_codec_id": None}, + {"name": ">m8[s]", "object_codec_id": None}, + {"name": " None: + """ + Test that an invalid unit raises a ValueError. + """ + unit = "invalid" + msg = f"unit must be one of ('Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'μs', 'ns', 'ps', 'fs', 'as', 'generic'), got {unit!r}." + with pytest.raises(ValueError, match=re.escape(msg)): + DateTime64(unit=unit) # type: ignore[arg-type] + with pytest.raises(ValueError, match=re.escape(msg)): + TimeDelta64(unit=unit) # type: ignore[arg-type] + + +def test_time_scale_factor_too_low() -> None: + """ + Test that an invalid unit raises a ValueError. + """ + scale_factor = 0 + msg = f"scale_factor must be > 0, got {scale_factor}." + with pytest.raises(ValueError, match=msg): + DateTime64(scale_factor=scale_factor) + with pytest.raises(ValueError, match=msg): + TimeDelta64(scale_factor=scale_factor) + + +def test_time_scale_factor_too_high() -> None: + """ + Test that an invalid unit raises a ValueError. 
+ """ + scale_factor = 2**31 + msg = f"scale_factor must be < 2147483648, got {scale_factor}." + with pytest.raises(ValueError, match=msg): + DateTime64(scale_factor=scale_factor) + with pytest.raises(ValueError, match=msg): + TimeDelta64(scale_factor=scale_factor) + + +@pytest.mark.parametrize("unit", get_args(DateTimeUnit)) +@pytest.mark.parametrize("scale_factor", [1, 10]) +@pytest.mark.parametrize("value", [0, 1, 10]) +def test_datetime_from_int(unit: DateTimeUnit, scale_factor: int, value: int) -> None: + """ + Test datetime_from_int. + """ + expected = np.int64(value).view(f"datetime64[{scale_factor}{unit}]") + assert datetime_from_int(value, unit=unit, scale_factor=scale_factor) == expected diff --git a/tests/test_dtype/test_wrapper.py b/tests/test_dtype/test_wrapper.py new file mode 100644 index 0000000000..8f461f1a77 --- /dev/null +++ b/tests/test_dtype/test_wrapper.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +import pytest + +from zarr.core.dtype.common import DTypeSpec_V2, DTypeSpec_V3, HasItemSize + +if TYPE_CHECKING: + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + + +""" +class _TestZDTypeSchema: + # subclasses define the URL for the schema, if available + schema_url: ClassVar[str] = "" + + @pytest.fixture(scope="class") + def get_schema(self) -> object: + response = requests.get(self.schema_url) + response.raise_for_status() + return json_schema.loads(response.text) + + def test_schema(self, schema: json_schema.Schema) -> None: + assert schema.is_valid(self.test_cls.to_json(zarr_format=2)) +""" + + +class BaseTestZDType: + """ + A base class for testing ZDType subclasses. This class works in conjunction with the custom + pytest collection function ``pytest_generate_tests`` defined in conftest.py, which applies the + following procedure when generating tests: + + At test generation time, for each test fixture referenced by a method on this class + pytest will look for an attribute with the same name as that fixture. Pytest will assume that + this class attribute is a tuple of values to be used for generating a parametrized test fixture. + + This means that child classes can, by using different values for these class attributes, have + customized test parametrization. + + Attributes + ---------- + test_cls : type[ZDType[TBaseDType, TBaseScalar]] + The ZDType subclass being tested. + scalar_type : ClassVar[type[TBaseScalar]] + The expected scalar type for the ZDType. + valid_dtype : ClassVar[tuple[TBaseDType, ...]] + A tuple of valid numpy dtypes for the ZDType. + invalid_dtype : ClassVar[tuple[TBaseDType, ...]] + A tuple of invalid numpy dtypes for the ZDType. + valid_json_v2 : ClassVar[tuple[str | dict[str, object] | list[object], ...]] + A tuple of valid JSON representations for Zarr format version 2. + invalid_json_v2 : ClassVar[tuple[str | dict[str, object] | list[object], ...]] + A tuple of invalid JSON representations for Zarr format version 2. + valid_json_v3 : ClassVar[tuple[str | dict[str, object], ...]] + A tuple of valid JSON representations for Zarr format version 3. + invalid_json_v3 : ClassVar[tuple[str | dict[str, object], ...]] + A tuple of invalid JSON representations for Zarr format version 3. + cast_value_params : ClassVar[tuple[tuple[Any, Any, Any], ...]] + A tuple of (dtype, value, expected) tuples for testing ZDType.cast_value. 
+ """ + + test_cls: type[ZDType[TBaseDType, TBaseScalar]] + scalar_type: ClassVar[type[TBaseScalar]] + valid_dtype: ClassVar[tuple[TBaseDType, ...]] = () + invalid_dtype: ClassVar[tuple[TBaseDType, ...]] = () + + valid_json_v2: ClassVar[tuple[DTypeSpec_V2, ...]] = () + invalid_json_v2: ClassVar[tuple[str | dict[str, object] | list[object], ...]] = () + + valid_json_v3: ClassVar[tuple[DTypeSpec_V3, ...]] = () + invalid_json_v3: ClassVar[tuple[str | dict[str, object], ...]] = () + + # for testing scalar round-trip serialization, we need a tuple of (data type json, scalar json) + # pairs. the first element of the pair is used to create a dtype instance, and the second + # element is the json serialization of the scalar that we want to round-trip. + + scalar_v2_params: ClassVar[tuple[tuple[Any, Any], ...]] = () + scalar_v3_params: ClassVar[tuple[tuple[Any, Any], ...]] = () + cast_value_params: ClassVar[tuple[tuple[Any, Any, Any], ...]] + item_size_params: ClassVar[tuple[ZDType[Any, Any], ...]] + + def json_scalar_equals(self, scalar1: object, scalar2: object) -> bool: + # An equality check for json-encoded scalars. This defaults to regular equality, + # but some classes may need to override this for special cases + return scalar1 == scalar2 + + def scalar_equals(self, scalar1: object, scalar2: object) -> bool: + # An equality check for scalars. This defaults to regular equality, + # but some classes may need to override this for special cases + return scalar1 == scalar2 + + def test_check_dtype_valid(self, valid_dtype: TBaseDType) -> None: + assert self.test_cls._check_native_dtype(valid_dtype) + + def test_check_dtype_invalid(self, invalid_dtype: object) -> None: + assert not self.test_cls._check_native_dtype(invalid_dtype) # type: ignore[arg-type] + + def test_from_dtype_roundtrip(self, valid_dtype: Any) -> None: + zdtype = self.test_cls.from_native_dtype(valid_dtype) + assert zdtype.to_native_dtype() == valid_dtype + + def test_from_json_roundtrip_v2(self, valid_json_v2: DTypeSpec_V2) -> None: + zdtype = self.test_cls.from_json(valid_json_v2, zarr_format=2) + assert zdtype.to_json(zarr_format=2) == valid_json_v2 + + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") + def test_from_json_roundtrip_v3(self, valid_json_v3: DTypeSpec_V3) -> None: + zdtype = self.test_cls.from_json(valid_json_v3, zarr_format=3) + assert zdtype.to_json(zarr_format=3) == valid_json_v3 + + def test_scalar_roundtrip_v2(self, scalar_v2_params: tuple[ZDType[Any, Any], Any]) -> None: + zdtype, scalar_json = scalar_v2_params + scalar = zdtype.from_json_scalar(scalar_json, zarr_format=2) + assert self.json_scalar_equals(scalar_json, zdtype.to_json_scalar(scalar, zarr_format=2)) + + def test_scalar_roundtrip_v3(self, scalar_v3_params: tuple[ZDType[Any, Any], Any]) -> None: + zdtype, scalar_json = scalar_v3_params + scalar = zdtype.from_json_scalar(scalar_json, zarr_format=3) + assert self.json_scalar_equals(scalar_json, zdtype.to_json_scalar(scalar, zarr_format=3)) + + def test_cast_value(self, cast_value_params: tuple[ZDType[Any, Any], Any, Any]) -> None: + zdtype, value, expected = cast_value_params + observed = zdtype.cast_scalar(value) + assert self.scalar_equals(expected, observed) + + def test_item_size(self, item_size_params: ZDType[Any, Any]) -> None: + """ + Test that the item_size attribute matches the numpy dtype itemsize attribute, for dtypes + with a fixed scalar size. 
+ """ + if isinstance(item_size_params, HasItemSize): + assert item_size_params.item_size == item_size_params.to_native_dtype().itemsize + else: + pytest.skip(f"Dtype {item_size_params} does not implement HasItemSize") diff --git a/tests/test_dtype_registry.py b/tests/test_dtype_registry.py new file mode 100644 index 0000000000..c7d5f90065 --- /dev/null +++ b/tests/test_dtype_registry.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import re +import sys +from pathlib import Path +from typing import TYPE_CHECKING, Any, get_args + +import numpy as np +import pytest + +import zarr +from tests.conftest import skip_object_dtype +from zarr.core.config import config +from zarr.core.dtype import ( + AnyDType, + Bool, + DataTypeRegistry, + DateTime64, + FixedLengthUTF32, + Int8, + Int16, + TBaseDType, + TBaseScalar, + ZDType, + data_type_registry, + get_data_type_from_json, + parse_data_type, +) + +if TYPE_CHECKING: + from collections.abc import Generator + + from zarr.core.common import ZarrFormat + +from .test_dtype.conftest import zdtype_examples + + +@pytest.fixture +def data_type_registry_fixture() -> DataTypeRegistry: + return DataTypeRegistry() + + +class TestRegistry: + @staticmethod + def test_register(data_type_registry_fixture: DataTypeRegistry) -> None: + """ + Test that registering a dtype in a data type registry works. + """ + data_type_registry_fixture.register(Bool._zarr_v3_name, Bool) + assert data_type_registry_fixture.get(Bool._zarr_v3_name) == Bool + assert isinstance(data_type_registry_fixture.match_dtype(np.dtype("bool")), Bool) + + @staticmethod + def test_override(data_type_registry_fixture: DataTypeRegistry) -> None: + """ + Test that registering a new dtype with the same name works (overriding the previous one). + """ + data_type_registry_fixture.register(Bool._zarr_v3_name, Bool) + + class NewBool(Bool): + def default_scalar(self) -> np.bool_: + return np.True_ + + data_type_registry_fixture.register(NewBool._zarr_v3_name, NewBool) + assert isinstance(data_type_registry_fixture.match_dtype(np.dtype("bool")), NewBool) + + @staticmethod + @pytest.mark.parametrize( + ("wrapper_cls", "dtype_str"), [(Bool, "bool"), (FixedLengthUTF32, "|U4")] + ) + def test_match_dtype( + data_type_registry_fixture: DataTypeRegistry, + wrapper_cls: type[ZDType[TBaseDType, TBaseScalar]], + dtype_str: str, + ) -> None: + """ + Test that match_dtype resolves a numpy dtype into an instance of the correspond wrapper for that dtype. + """ + data_type_registry_fixture.register(wrapper_cls._zarr_v3_name, wrapper_cls) + assert isinstance(data_type_registry_fixture.match_dtype(np.dtype(dtype_str)), wrapper_cls) + + @staticmethod + def test_unregistered_dtype(data_type_registry_fixture: DataTypeRegistry) -> None: + """ + Test that match_dtype raises an error if the dtype is not registered. + """ + outside_dtype_name = "int8" + outside_dtype = np.dtype(outside_dtype_name) + msg = f"No Zarr data type found that matches dtype '{outside_dtype!r}'" + with pytest.raises(ValueError, match=re.escape(msg)): + data_type_registry_fixture.match_dtype(outside_dtype) + + with pytest.raises(KeyError): + data_type_registry_fixture.get(outside_dtype_name) + + @staticmethod + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") + @pytest.mark.parametrize("zdtype", zdtype_examples) + def test_registered_dtypes_match_dtype(zdtype: ZDType[TBaseDType, TBaseScalar]) -> None: + """ + Test that the registered dtypes can be retrieved from the registry. 
+ """ + skip_object_dtype(zdtype) + assert data_type_registry.match_dtype(zdtype.to_native_dtype()) == zdtype + + @staticmethod + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") + @pytest.mark.parametrize("zdtype", zdtype_examples) + def test_registered_dtypes_match_json( + zdtype: ZDType[TBaseDType, TBaseScalar], zarr_format: ZarrFormat + ) -> None: + assert ( + data_type_registry.match_json( + zdtype.to_json(zarr_format=zarr_format), zarr_format=zarr_format + ) + == zdtype + ) + + @staticmethod + @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") + @pytest.mark.parametrize("zdtype", zdtype_examples) + def test_match_dtype_unique( + zdtype: ZDType[Any, Any], + data_type_registry_fixture: DataTypeRegistry, + zarr_format: ZarrFormat, + ) -> None: + """ + Test that the match_dtype method uniquely specifies a registered data type. We create a local registry + that excludes the data type class being tested, and ensure that an instance of the wrapped data type + fails to match anything in the registry + """ + skip_object_dtype(zdtype) + for _cls in get_args(AnyDType): + if _cls is not type(zdtype): + data_type_registry_fixture.register(_cls._zarr_v3_name, _cls) + + dtype_instance = zdtype.to_native_dtype() + + msg = f"No Zarr data type found that matches dtype '{dtype_instance!r}'" + with pytest.raises(ValueError, match=re.escape(msg)): + data_type_registry_fixture.match_dtype(dtype_instance) + + instance_dict = zdtype.to_json(zarr_format=zarr_format) + msg = f"No Zarr data type found that matches {instance_dict!r}" + with pytest.raises(ValueError, match=re.escape(msg)): + data_type_registry_fixture.match_json(instance_dict, zarr_format=zarr_format) + + +# this is copied from the registry tests -- we should deduplicate +here = str(Path(__file__).parent.absolute()) + + +@pytest.fixture +def set_path() -> Generator[None, None, None]: + sys.path.append(here) + zarr.registry._collect_entrypoints() + yield + sys.path.remove(here) + registries = zarr.registry._collect_entrypoints() + for registry in registries: + registry.lazy_load_list.clear() + config.reset() + + +@pytest.mark.usefixtures("set_path") +def test_entrypoint_dtype(zarr_format: ZarrFormat) -> None: + from package_with_entrypoint import TestDataType + + data_type_registry.lazy_load() + instance = TestDataType() + dtype_json = instance.to_json(zarr_format=zarr_format) + assert get_data_type_from_json(dtype_json, zarr_format=zarr_format) == instance + data_type_registry.unregister(TestDataType._zarr_v3_name) + + +@pytest.mark.parametrize( + ("dtype_params", "expected", "zarr_format"), + [ + ("int8", Int8(), 3), + (Int8(), Int8(), 3), + (">i2", Int16(endianness="big"), 2), + ("datetime64[10s]", DateTime64(unit="s", scale_factor=10), 2), + ( + {"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}}, + DateTime64(unit="s", scale_factor=10), + 3, + ), + ], +) +def test_parse_data_type( + dtype_params: Any, expected: ZDType[Any, Any], zarr_format: ZarrFormat +) -> None: + """ + Test that parse_data_type accepts alternative representations of ZDType instances, and resolves + those inputs to the expected ZDType instance. 
+ """ + observed = parse_data_type(dtype_params, zarr_format=zarr_format) + assert observed == expected diff --git a/tests/test_group.py b/tests/test_group.py index 7cf29c30d9..60a1fcb9bf 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -23,6 +23,8 @@ from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype from zarr.core.config import config as zarr_config +from zarr.core.dtype.common import unpack_dtype_json +from zarr.core.dtype.npy.int import UInt8 from zarr.core.group import ( ConsolidatedMetadata, GroupMetadata, @@ -494,7 +496,7 @@ def test_group_child_iterators(store: Store, zarr_format: ZarrFormat, consolidat expected_groups = list(zip(expected_group_keys, expected_group_values, strict=False)) fill_value = 3 - dtype = "uint8" + dtype = UInt8() expected_group_values[0].create_group("subgroup") expected_group_values[0].create_array( @@ -515,7 +517,7 @@ def test_group_child_iterators(store: Store, zarr_format: ZarrFormat, consolidat metadata = { "subarray": { "attributes": {}, - "dtype": dtype, + "dtype": unpack_dtype_json(dtype.to_json(zarr_format=zarr_format)), "fill_value": fill_value, "shape": (1,), "chunks": (1,), @@ -551,7 +553,7 @@ def test_group_child_iterators(store: Store, zarr_format: ZarrFormat, consolidat {"configuration": {"endian": "little"}, "name": "bytes"}, {"configuration": {}, "name": "zstd"}, ), - "data_type": dtype, + "data_type": unpack_dtype_json(dtype.to_json(zarr_format=zarr_format)), "fill_value": fill_value, "node_type": "array", "shape": (1,), diff --git a/tests/test_info.py b/tests/test_info.py index 04e6339092..0abaff9ae7 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -1,11 +1,11 @@ import textwrap -import numpy as np import pytest from zarr.codecs.bytes import BytesCodec from zarr.core._info import ArrayInfo, GroupInfo, human_readable_size from zarr.core.common import ZarrFormat +from zarr.core.dtype.npy.int import Int32 ZARR_FORMATS = [2, 3] @@ -53,7 +53,7 @@ def test_group_info_complete(zarr_format: ZarrFormat) -> None: def test_array_info(zarr_format: ZarrFormat) -> None: info = ArrayInfo( _zarr_format=zarr_format, - _data_type=np.dtype("int32"), + _data_type=Int32(), _fill_value=0, _shape=(100, 100), _chunk_shape=(10, 100), @@ -66,7 +66,7 @@ def test_array_info(zarr_format: ZarrFormat) -> None: assert result == textwrap.dedent(f"""\ Type : Array Zarr format : {zarr_format} - Data type : int32 + Data type : Int32(endianness='little') Fill value : 0 Shape : (100, 100) Chunk shape : (10, 100) @@ -93,7 +93,7 @@ def test_array_info_complete( ) = bytes_things info = ArrayInfo( _zarr_format=zarr_format, - _data_type=np.dtype("int32"), + _data_type=Int32(), _fill_value=0, _shape=(100, 100), _chunk_shape=(10, 100), @@ -109,7 +109,7 @@ def test_array_info_complete( assert result == textwrap.dedent(f"""\ Type : Array Zarr format : {zarr_format} - Data type : int32 + Data type : Int32(endianness='little') Fill value : 0 Shape : (100, 100) Chunk shape : (10, 100) diff --git a/tests/test_metadata/test_consolidated.py b/tests/test_metadata/test_consolidated.py index ff4fe6a780..395e036db2 100644 --- a/tests/test_metadata/test_consolidated.py +++ b/tests/test_metadata/test_consolidated.py @@ -18,6 +18,7 @@ open_consolidated, ) from zarr.core.buffer import cpu, default_buffer_prototype +from zarr.core.dtype import parse_data_type from zarr.core.group import ConsolidatedMetadata, GroupMetadata from zarr.core.metadata import ArrayV3Metadata from zarr.core.metadata.v2 import ArrayV2Metadata @@ -503,7 +504,7 
@@ async def test_consolidated_metadata_backwards_compatibility( async def test_consolidated_metadata_v2(self): store = zarr.storage.MemoryStore() g = await AsyncGroup.from_store(store, attributes={"key": "root"}, zarr_format=2) - dtype = "uint8" + dtype = parse_data_type("uint8", zarr_format=2) await g.create_array(name="a", shape=(1,), attributes={"key": "a"}, dtype=dtype) g1 = await g.create_group(name="g1", attributes={"key": "g1"}) await g1.create_group(name="g2", attributes={"key": "g2"}) @@ -624,7 +625,6 @@ async def test_consolidated_metadata_encodes_special_chars( memory_store: Store, zarr_format: ZarrFormat, fill_value: float ): root = await group(store=memory_store, zarr_format=zarr_format) - _child = await root.create_group("child", attributes={"test": fill_value}) _time = await root.create_array("time", shape=(12,), dtype=np.float64, fill_value=fill_value) await zarr.api.asynchronous.consolidate_metadata(memory_store) @@ -638,18 +638,11 @@ async def test_consolidated_metadata_encodes_special_chars( "consolidated_metadata" ]["metadata"] - if np.isnan(fill_value): - expected_fill_value = "NaN" - elif np.isneginf(fill_value): - expected_fill_value = "-Infinity" - elif np.isinf(fill_value): - expected_fill_value = "Infinity" + expected_fill_value = _time._zdtype.to_json_scalar(fill_value, zarr_format=2) if zarr_format == 2: - assert root_metadata["child/.zattrs"]["test"] == expected_fill_value assert root_metadata["time/.zarray"]["fill_value"] == expected_fill_value elif zarr_format == 3: - assert root_metadata["child"]["attributes"]["test"] == expected_fill_value assert root_metadata["time"]["fill_value"] == expected_fill_value diff --git a/tests/test_metadata/test_v2.py b/tests/test_metadata/test_v2.py index 08b9cb2507..a2894529aa 100644 --- a/tests/test_metadata/test_v2.py +++ b/tests/test_metadata/test_v2.py @@ -10,6 +10,8 @@ import zarr.storage from zarr.core.buffer import cpu from zarr.core.buffer.core import default_buffer_prototype +from zarr.core.dtype.npy.float import Float32, Float64 +from zarr.core.dtype.npy.int import Int16 from zarr.core.group import ConsolidatedMetadata, GroupMetadata from zarr.core.metadata import ArrayV2Metadata from zarr.core.metadata.v2 import parse_zarr_format @@ -19,8 +21,6 @@ from zarr.abc.codec import Codec -import numcodecs - def test_parse_zarr_format_valid() -> None: assert parse_zarr_format(2) == 2 @@ -33,8 +33,8 @@ def test_parse_zarr_format_invalid(data: Any) -> None: @pytest.mark.parametrize("attributes", [None, {"foo": "bar"}]) -@pytest.mark.parametrize("filters", [None, (numcodecs.GZip(),)]) -@pytest.mark.parametrize("compressor", [None, numcodecs.GZip()]) +@pytest.mark.parametrize("filters", [None, [{"id": "gzip", "level": 1}]]) +@pytest.mark.parametrize("compressor", [None, {"id": "gzip", "level": 1}]) @pytest.mark.parametrize("fill_value", [None, 0, 1]) @pytest.mark.parametrize("order", ["C", "F"]) @pytest.mark.parametrize("dimension_separator", [".", "/", None]) @@ -86,7 +86,7 @@ def test_filters_empty_tuple_warns() -> None: "zarr_format": 2, "shape": (1,), "chunks": (1,), - "dtype": "uint8", + "dtype": "|u1", "order": "C", "compressor": None, "filters": (), @@ -128,7 +128,7 @@ async def v2_consolidated_metadata( "chunks": [730], "compressor": None, "dtype": " None: expected = ArrayV2Metadata( attributes={"key": "value"}, shape=(8,), - dtype="float64", + dtype=Float64(), chunks=(8,), fill_value=0.0, order="C", @@ -318,12 +318,11 @@ def test_zstd_checksum() -> None: assert "checksum" not in metadata["compressor"] 
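The consolidated-metadata hunk above replaces hand-rolled NaN/Infinity branching with a call into the data type object, which now owns scalar JSON encoding. A minimal sketch of that API (not itself part of the patch), assuming a build with this change applied and using only names that appear in these hunks:

    import numpy as np

    from zarr.core.dtype import parse_data_type

    zdtype = parse_data_type("float64", zarr_format=2)
    # Special floats serialize to the spec-mandated strings in zarr v2 metadata,
    # matching the values the deleted np.isnan/np.isinf branches expected.
    assert zdtype.to_json_scalar(float("nan"), zarr_format=2) == "NaN"
    assert zdtype.to_json_scalar(float("inf"), zarr_format=2) == "Infinity"
    # Decoding is assumed symmetric, returning a NaN scalar for "NaN".
    assert np.isnan(zdtype.from_json_scalar("NaN", zarr_format=2))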
-@pytest.mark.parametrize( - "fill_value", [None, np.void((0, 0), np.dtype([("foo", "i4"), ("bar", "i4")]))] -) +@pytest.mark.parametrize("fill_value", [np.void((0, 0), np.dtype([("foo", "i4"), ("bar", "i4")]))]) def test_structured_dtype_fill_value_serialization(tmp_path, fill_value): + zarr_format = 2 group_path = tmp_path / "test.zarr" - root_group = zarr.open_group(group_path, mode="w", zarr_format=2) + root_group = zarr.open_group(group_path, mode="w", zarr_format=zarr_format) dtype = np.dtype([("foo", "i4"), ("bar", "i4")]) root_group.create_array( name="structured_dtype", @@ -333,11 +332,7 @@ def test_structured_dtype_fill_value_serialization(tmp_path, fill_value): fill_value=fill_value, ) - zarr.consolidate_metadata(root_group.store, zarr_format=2) + zarr.consolidate_metadata(root_group.store, zarr_format=zarr_format) root_group = zarr.open_group(group_path, mode="r") - assert ( - root_group.metadata.consolidated_metadata.to_dict()["metadata"]["structured_dtype"][ - "fill_value" - ] - == fill_value - ) + observed = root_group.metadata.consolidated_metadata.metadata["structured_dtype"].fill_value + assert observed == fill_value diff --git a/tests/test_metadata/test_v3.py b/tests/test_metadata/test_v3.py index 13549b10a4..4f385afa6d 100644 --- a/tests/test_metadata/test_v3.py +++ b/tests/test_metadata/test_v3.py @@ -11,13 +11,13 @@ from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.core.config import config +from zarr.core.dtype import get_data_type_from_native_dtype +from zarr.core.dtype.npy.string import _NUMPY_SUPPORTS_VLEN_STRING +from zarr.core.dtype.npy.time import DateTime64 from zarr.core.group import GroupMetadata, parse_node_type from zarr.core.metadata.v3 import ( ArrayV3Metadata, - DataType, - default_fill_value, parse_dimension_names, - parse_fill_value, parse_zarr_format, ) from zarr.errors import MetadataValidationError, NodeTypeValidationError @@ -54,9 +54,20 @@ ) complex_dtypes = ("complex64", "complex128") -vlen_dtypes = ("string", "bytes") - -dtypes = (*bool_dtypes, *int_dtypes, *float_dtypes, *complex_dtypes, *vlen_dtypes) +flexible_dtypes = ("str", "bytes", "void") +if _NUMPY_SUPPORTS_VLEN_STRING: + vlen_string_dtypes = ("T",) +else: + vlen_string_dtypes = ("O",) + +dtypes = ( + *bool_dtypes, + *int_dtypes, + *float_dtypes, + *complex_dtypes, + *flexible_dtypes, + *vlen_string_dtypes, +) @pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"]) @@ -110,90 +121,19 @@ def parse_dimension_names_valid(data: Sequence[str] | None) -> None: assert parse_dimension_names(data) == data -@pytest.mark.parametrize("dtype_str", dtypes) -def test_default_fill_value(dtype_str: str) -> None: - """ - Test that parse_fill_value(None, dtype) results in the 0 value for the given dtype. 
- """ - dtype = DataType(dtype_str) - fill_value = default_fill_value(dtype) - if dtype == DataType.string: - assert fill_value == "" - elif dtype == DataType.bytes: - assert fill_value == b"" - else: - assert fill_value == dtype.to_numpy().type(0) - - -@pytest.mark.parametrize( - ("fill_value", "dtype_str"), - [ - (True, "bool"), - (False, "bool"), - (-8, "int8"), - (0, "int16"), - (1e10, "uint64"), - (-999, "float32"), - (1e32, "float64"), - (float("NaN"), "float64"), - (np.nan, "float64"), - (np.inf, "float64"), - (-1 * np.inf, "float64"), - (0j, "complex64"), - ], -) -def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None: - """ - Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype. - """ - parsed = parse_fill_value(fill_value, dtype_str) - - if np.isnan(fill_value): - assert np.isnan(parsed) - else: - assert parsed == DataType(dtype_str).to_numpy().type(fill_value) - - -@pytest.mark.parametrize("fill_value", ["not a valid value"]) -@pytest.mark.parametrize("dtype_str", [*int_dtypes, *float_dtypes, *complex_dtypes]) -def test_parse_fill_value_invalid_value(fill_value: Any, dtype_str: str) -> None: - """ - Test that parse_fill_value(fill_value, dtype) raises ValueError for invalid values. - This test excludes bool because the bool constructor takes anything. - """ - with pytest.raises(ValueError): - parse_fill_value(fill_value, dtype_str) - - -@pytest.mark.parametrize("fill_value", [[1.0, 0.0], [0, 1], complex(1, 1), np.complex64(0)]) +@pytest.mark.parametrize("fill_value", [[1.0, 0.0], [0, 1]]) @pytest.mark.parametrize("dtype_str", [*complex_dtypes]) -def test_parse_fill_value_complex(fill_value: Any, dtype_str: str) -> None: +def test_jsonify_fill_value_complex(fill_value: Any, dtype_str: str) -> None: """ Test that parse_fill_value(fill_value, dtype) correctly handles complex values represented as length-2 sequences """ - dtype = DataType(dtype_str) - if isinstance(fill_value, list): - expected = dtype.to_numpy().type(complex(*fill_value)) - else: - expected = dtype.to_numpy().type(fill_value) - assert expected == parse_fill_value(fill_value, dtype_str) - - -@pytest.mark.parametrize("fill_value", [[1.0, 0.0, 3.0], [0, 1, 3], [1]]) -@pytest.mark.parametrize("dtype_str", [*complex_dtypes]) -def test_parse_fill_value_complex_invalid(fill_value: Any, dtype_str: str) -> None: - """ - Test that parse_fill_value(fill_value, dtype) correctly rejects sequences with length not - equal to 2 - """ - match = ( - f"Got an invalid fill value for complex data type {dtype_str}." - f"Expected a sequence with 2 elements, but {fill_value} has " - f"length {len(fill_value)}." - ) - with pytest.raises(ValueError, match=re.escape(match)): - parse_fill_value(fill_value=fill_value, dtype=dtype_str) + zarr_format = 3 + dtype = get_data_type_from_native_dtype(dtype_str) + expected = dtype.to_native_dtype().type(complex(*fill_value)) + observed = dtype.from_json_scalar(fill_value, zarr_format=zarr_format) + assert observed == expected + assert dtype.to_json_scalar(observed, zarr_format=zarr_format) == tuple(fill_value) @pytest.mark.parametrize("fill_value", [{"foo": 10}]) @@ -203,8 +143,9 @@ def test_parse_fill_value_invalid_type(fill_value: Any, dtype_str: str) -> None: Test that parse_fill_value(fill_value, dtype) raises TypeError for invalid non-sequential types. This test excludes bool because the bool constructor takes anything. 
""" - with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): - parse_fill_value(fill_value, dtype_str) + dtype_instance = get_data_type_from_native_dtype(dtype_str) + with pytest.raises(TypeError, match=f"Invalid type: {fill_value}"): + dtype_instance.from_json_scalar(fill_value, zarr_format=3) @pytest.mark.parametrize( @@ -223,14 +164,14 @@ def test_parse_fill_value_invalid_type_sequence(fill_value: Any, dtype_str: str) This test excludes bool because the bool constructor takes anything, and complex because complex values can be created from length-2 sequences. """ - match = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype_str}" - with pytest.raises(TypeError, match=re.escape(match)): - parse_fill_value(fill_value, dtype_str) + dtype_instance = get_data_type_from_native_dtype(dtype_str) + with pytest.raises(TypeError, match=re.escape(f"Invalid type: {fill_value}")): + dtype_instance.from_json_scalar(fill_value, zarr_format=3) @pytest.mark.parametrize("chunk_grid", ["regular"]) @pytest.mark.parametrize("attributes", [None, {"foo": "bar"}]) -@pytest.mark.parametrize("codecs", [[BytesCodec()]]) +@pytest.mark.parametrize("codecs", [[BytesCodec(endian=None)]]) @pytest.mark.parametrize("fill_value", [0, 1]) @pytest.mark.parametrize("chunk_key_encoding", ["v2", "default"]) @pytest.mark.parametrize("dimension_separator", [".", "/", None]) @@ -247,7 +188,7 @@ def test_metadata_to_dict( storage_transformers: tuple[dict[str, JSON]] | None, ) -> None: shape = (1, 2, 3) - data_type = DataType.uint8 + data_type_str = "uint8" if chunk_grid == "regular": cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}} @@ -271,7 +212,7 @@ def test_metadata_to_dict( "node_type": "array", "shape": shape, "chunk_grid": cgrid, - "data_type": data_type, + "data_type": data_type_str, "chunk_key_encoding": cke, "codecs": tuple(c.to_dict() for c in codecs), "fill_value": fill_value, @@ -315,50 +256,32 @@ def test_json_indent(indent: int): assert d == json.dumps(json.loads(d), indent=indent).encode() -# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897]) -# @pytest.mark.parametrize("precision", ["ns", "D"]) -# async def test_datetime_metadata(fill_value: int, precision: str) -> None: -# metadata_dict = { -# "zarr_format": 3, -# "node_type": "array", -# "shape": (1,), -# "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, -# "data_type": f" None: +@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897]) +@pytest.mark.parametrize("precision", ["ns", "D"]) +async def test_datetime_metadata(fill_value: int, precision: str) -> None: + dtype = DateTime64(unit=precision) metadata_dict = { "zarr_format": 3, "node_type": "array", "shape": (1,), "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, - "data_type": " None: metadata_dict = { @@ -368,10 +291,11 @@ async def test_invalid_fill_value_raises(data_type: str, fill_value: float) -> N "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, "data_type": data_type, "chunk_key_encoding": {"name": "default", "separator": "."}, - "codecs": (), + "codecs": ({"name": "bytes"},), "fill_value": fill_value, # this is not a valid fill value for uint8 } - with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): + # multiple things can go wrong here, so we don't match on the error message. 
+ with pytest.raises(TypeError): ArrayV3Metadata.from_dict(metadata_dict) @@ -399,17 +323,3 @@ async def test_special_float_fill_values(fill_value: str) -> None: elif fill_value == "-Infinity": assert np.isneginf(m.fill_value) assert d["fill_value"] == "-Infinity" - - -@pytest.mark.parametrize("dtype_str", dtypes) -def test_dtypes(dtype_str: str) -> None: - dt = DataType(dtype_str) - np_dtype = dt.to_numpy() - if dtype_str not in vlen_dtypes: - # we can round trip "normal" dtypes - assert dt == DataType.from_numpy(np_dtype) - assert dt.byte_count == np_dtype.itemsize - assert dt.has_endianness == (dt.byte_count > 1) - else: - # return type for vlen types may vary depending on numpy version - assert dt.byte_count is None diff --git a/tests/test_properties.py b/tests/test_properties.py index d48dfe2fef..b8d50ef0b1 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,4 +1,3 @@ -import dataclasses import json import numbers from typing import Any @@ -76,6 +75,7 @@ def deep_equal(a: Any, b: Any) -> bool: return a == b +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") @given(data=st.data(), zarr_format=zarr_formats) def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None: nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format))) @@ -83,6 +83,7 @@ def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None: assert_array_equal(nparray, zarray[:]) +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") @given(array=arrays()) def test_array_creates_implicit_groups(array): path = array.path @@ -102,7 +103,10 @@ def test_array_creates_implicit_groups(array): # this decorator removes timeout; not ideal but it should avoid intermittent CI failures + + @settings(deadline=None) +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") @given(data=st.data()) def test_basic_indexing(data: st.DataObject) -> None: zarray = data.draw(simple_arrays()) @@ -118,6 +122,7 @@ def test_basic_indexing(data: st.DataObject) -> None: @given(data=st.data()) +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") def test_oindex(data: st.DataObject) -> None: # integer_array_indices can't handle 0-size dimensions. zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1))) @@ -139,6 +144,7 @@ def test_oindex(data: st.DataObject) -> None: @given(data=st.data()) +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") def test_vindex(data: st.DataObject) -> None: # integer_array_indices can't handle 0-size dimensions. zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1))) @@ -162,6 +168,7 @@ def test_vindex(data: st.DataObject) -> None: @given(store=stores, meta=array_metadata()) # type: ignore[misc] +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") async def test_roundtrip_array_metadata_from_store( store: Store, meta: ArrayV2Metadata | ArrayV3Metadata ) -> None: @@ -181,6 +188,7 @@ async def test_roundtrip_array_metadata_from_store( @given(data=st.data(), zarr_format=zarr_formats) +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: int) -> None: """ Verify that JSON serialization and deserialization of metadata is lossless. 
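The next hunk swaps dataclasses.asdict for to_dict on both sides of the comparison, since the dtype field is now a ZDType instance rather than a value asdict can flatten faithfully. A sketch of the round-trip pattern the test converges on (the helper name is ours; meta stands in for any ArrayV3Metadata instance):

    import json

    from zarr.core.metadata import ArrayV3Metadata

    def roundtrip_v3(meta: ArrayV3Metadata) -> ArrayV3Metadata:
        # to_dict() emits the JSON-ready form, dtype and fill value included;
        # from_dict() parses the stored document back into a metadata object.
        doc = json.loads(json.dumps(meta.to_dict()))
        return ArrayV3Metadata.from_dict(doc)

The test compares to_dict() outputs with its deep_equal helper rather than plain ==, which would misreport NaN fill values as unequal.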
@@ -209,8 +217,8 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in
     zarray_dict = json.loads(buffer_dict[ZARR_JSON].to_bytes().decode())
     metadata_roundtripped = ArrayV3Metadata.from_dict(zarray_dict)
-    orig = dataclasses.asdict(metadata)
-    rt = dataclasses.asdict(metadata_roundtripped)
+    orig = metadata.to_dict()
+    rt = metadata_roundtripped.to_dict()
     assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}"
@@ -239,6 +247,29 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in
 # assert_array_equal(nparray, zarray[:])
+def serialized_complex_float_is_valid(
+    serialized: tuple[numbers.Real | str, numbers.Real | str],
+) -> bool:
+    """
+    Validate that the serialized representation of a complex float conforms to the spec.
+
+    The specification requires that a serialized complex float is a two-element
+    array whose elements are each either a JSON number or one of the strings
+    "NaN", "Infinity", or "-Infinity".
+
+    Args:
+        serialized: The value produced by JSON serialization for a complex floating point number.
+
+    Returns:
+        bool: True if the serialized value is valid according to the spec, False otherwise.
+    """
+    return (
+        isinstance(serialized, tuple)
+        and len(serialized) == 2
+        and all(serialized_float_is_valid(x) for x in serialized)
+    )
+
+
 def serialized_float_is_valid(serialized: numbers.Real | str) -> bool:
     """
     Validate that the serialized representation of a float conforms to the spec.
@@ -259,6 +290,7 @@ def serialized_float_is_valid(serialized: numbers.Real | str) -> bool:
 @given(meta=array_metadata()) # type: ignore[misc]
+@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
 def test_array_metadata_meets_spec(meta: ArrayV2Metadata | ArrayV3Metadata) -> None:
     """
     Validate that the array metadata produced by the library conforms to the relevant spec (V2 vs V3).
@@ -294,11 +326,11 @@ def test_array_metadata_meets_spec(meta: ArrayV2Metadata | ArrayV3Metadata) -> N
     assert asdict_dict["zarr_format"] == 3
     # version-agnostic validations
-    if meta.dtype.kind == "f":
+    dtype_native = meta.dtype.to_native_dtype()
+    if dtype_native.kind == "f":
         assert serialized_float_is_valid(asdict_dict["fill_value"])
-    elif meta.dtype.kind == "c":
+    elif dtype_native.kind == "c":
         # fill_value should be a two-element array [real, imag].
- assert serialized_float_is_valid(asdict_dict["fill_value"].real) - assert serialized_float_is_valid(asdict_dict["fill_value"].imag) - elif meta.dtype.kind == "M" and np.isnat(meta.fill_value): - assert asdict_dict["fill_value"] == "NaT" + assert serialized_complex_float_is_valid(asdict_dict["fill_value"]) + elif dtype_native.kind in ("M", "m") and np.isnat(meta.fill_value): + assert asdict_dict["fill_value"] == -9223372036854775808 diff --git a/tests/test_regression/__init__.py b/tests/test_regression/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_regression/scripts/__init__.py b/tests/test_regression/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_regression/scripts/v2.18.py b/tests/test_regression/scripts/v2.18.py new file mode 100644 index 0000000000..39e1c5210c --- /dev/null +++ b/tests/test_regression/scripts/v2.18.py @@ -0,0 +1,81 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "zarr==2.18", +# "numcodecs==0.15" +# ] +# /// + +import argparse + +import zarr +from zarr._storage.store import BaseStore + + +def copy_group( + *, node: zarr.hierarchy.Group, store: zarr.storage.BaseStore, path: str, overwrite: bool +) -> zarr.hierarchy.Group: + result = zarr.group(store=store, path=path, overwrite=overwrite) + result.attrs.put(node.attrs.asdict()) + for key, child in node.items(): + child_path = f"{path}/{key}" + if isinstance(child, zarr.hierarchy.Group): + copy_group(node=child, store=store, path=child_path, overwrite=overwrite) + elif isinstance(child, zarr.core.Array): + copy_array(node=child, store=store, overwrite=overwrite, path=child_path) + return result + + +def copy_array( + *, node: zarr.core.Array, store: BaseStore, path: str, overwrite: bool +) -> zarr.core.Array: + result = zarr.create( + shape=node.shape, + dtype=node.dtype, + fill_value=node.fill_value, + chunks=node.chunks, + compressor=node.compressor, + filters=node.filters, + order=node.order, + dimension_separator=node._dimension_separator, + store=store, + path=path, + overwrite=overwrite, + ) + result.attrs.put(node.attrs.asdict()) + result[:] = node[:] + return result + + +def copy_node( + node: zarr.hierarchy.Group | zarr.core.Array, store: BaseStore, path: str, overwrite: bool +) -> zarr.hierarchy.Group | zarr.core.Array: + if isinstance(node, zarr.hierarchy.Group): + return copy_group(node=node, store=store, path=path, overwrite=overwrite) + elif isinstance(node, zarr.core.Array): + return copy_array(node=node, store=store, path=path, overwrite=overwrite) + else: + raise TypeError(f"Unexpected node type: {type(node)}") # pragma: no cover + + +def cli() -> None: + parser = argparse.ArgumentParser( + description="Copy a zarr hierarchy from one location to another" + ) + parser.add_argument("source", type=str, help="Path to the source zarr hierarchy") + parser.add_argument("destination", type=str, help="Path to the destination zarr hierarchy") + args = parser.parse_args() + + src, dst = args.source, args.destination + root_src = zarr.open(src, mode="r") + result = copy_node(node=root_src, store=zarr.NestedDirectoryStore(dst), path="", overwrite=True) + + print(f"successfully created {result} at {dst}") + + +def main() -> None: + cli() + + +if __name__ == "__main__": + main() diff --git a/tests/test_regression/test_regression.py b/tests/test_regression/test_regression.py new file mode 100644 index 0000000000..34c48a6933 --- /dev/null +++ b/tests/test_regression/test_regression.py @@ -0,0 +1,156 @@ +import 
subprocess +from dataclasses import dataclass +from itertools import product +from pathlib import Path +from typing import TYPE_CHECKING + +import numcodecs +import numpy as np +import pytest +from numcodecs import LZ4, LZMA, Blosc, GZip, VLenBytes, VLenUTF8, Zstd + +import zarr +from zarr.core.array import Array +from zarr.core.chunk_key_encodings import V2ChunkKeyEncoding +from zarr.core.dtype.npy.bytes import VariableLengthBytes +from zarr.core.dtype.npy.string import VariableLengthUTF8 +from zarr.storage import LocalStore + +if TYPE_CHECKING: + from zarr.core.dtype import ZDTypeLike + + +def runner_installed() -> bool: + """ + Check if a PEP-723 compliant python script runner is installed. + """ + try: + subprocess.check_output(["uv", "--version"]) + return True # noqa: TRY300 + except FileNotFoundError: + return False + + +@dataclass(kw_only=True) +class ArrayParams: + values: np.ndarray[tuple[int], np.dtype[np.generic]] + fill_value: np.generic | str | int | bytes + filters: tuple[numcodecs.abc.Codec, ...] = () + compressor: numcodecs.abc.Codec + + +basic_codecs = GZip(), Blosc(), LZ4(), LZMA(), Zstd() +basic_dtypes = "|b", ">i2", ">i4", ">f4", ">f8", "c8", "c16", "M8[10us]", "m8[4ps]" +string_dtypes = "U4" +bytes_dtypes = ">S1", "V10", " Array: + dest = tmp_path / "in" + store = LocalStore(dest) + array_params: ArrayParams = request.param + compressor = array_params.compressor + chunk_key_encoding = V2ChunkKeyEncoding(separator="/") + dtype: ZDTypeLike + if array_params.values.dtype == np.dtype("|O") and array_params.filters == (VLenUTF8(),): + dtype = VariableLengthUTF8() # type: ignore[assignment] + elif array_params.values.dtype == np.dtype("|O") and array_params.filters == (VLenBytes(),): + dtype = VariableLengthBytes() + else: + dtype = array_params.values.dtype + z = zarr.create_array( + store, + shape=array_params.values.shape, + dtype=dtype, + chunks=array_params.values.shape, + compressors=compressor, + filters=array_params.filters, + fill_value=array_params.fill_value, + order="C", + chunk_key_encoding=chunk_key_encoding, + write_data=True, + zarr_format=2, + ) + z[:] = array_params.values + return z + + +# TODO: make this dynamic based on the installed scripts +script_paths = [Path(__file__).resolve().parent / "scripts" / "v2.18.py"] + + +@pytest.mark.skipif(not runner_installed(), reason="no python script runner installed") +@pytest.mark.parametrize( + "source_array", array_cases, indirect=True, ids=tuple(map(str, array_cases)) +) +@pytest.mark.parametrize("script_path", script_paths) +def test_roundtrip(source_array: Array, tmp_path: Path, script_path: Path) -> None: + out_path = tmp_path / "out" + copy_op = subprocess.run( + [ + "uv", + "run", + script_path, + str(source_array.store).removeprefix("file://"), + str(out_path), + ], + capture_output=True, + text=True, + ) + assert copy_op.returncode == 0 + out_array = zarr.open_array(store=out_path, mode="r", zarr_format=2) + assert source_array.metadata.to_dict() == out_array.metadata.to_dict() + assert np.array_equal(source_array[:], out_array[:]) diff --git a/tests/test_store/test_stateful.py b/tests/test_store/test_stateful.py index a17d7a55be..c0997c3df3 100644 --- a/tests/test_store/test_stateful.py +++ b/tests/test_store/test_stateful.py @@ -15,6 +15,7 @@ ] +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") def test_zarr_hierarchy(sync_store: Store): def mk_test_instance_sync() -> ZarrHierarchyStateMachine: return ZarrHierarchyStateMachine(sync_store) diff --git 
a/tests/test_strings.py b/tests/test_strings.py deleted file mode 100644 index dca0570a25..0000000000 --- a/tests/test_strings.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Tests for the strings module.""" - -import numpy as np -import pytest - -from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING, _STRING_DTYPE, cast_to_string_dtype - - -def test_string_defaults() -> None: - if _NUMPY_SUPPORTS_VLEN_STRING: - assert _STRING_DTYPE == np.dtypes.StringDType() - else: - assert _STRING_DTYPE == np.dtypes.ObjectDType() - - -def test_cast_to_string_dtype() -> None: - d1 = np.array(["a", "b", "c"]) - assert d1.dtype == np.dtype(" None: assert np.array_equal(data, a[:, :]) -@pytest.mark.parametrize("store", ["memory"], indirect=True) -@pytest.mark.parametrize( - ("dtype", "fill_value"), - [ - ("bool", False), - ("int64", 0), - ("float64", 0.0), - ("|S1", b""), - ("|U1", ""), - ("object", ""), - (str, ""), - ], -) -def test_implicit_fill_value(store: MemoryStore, dtype: str, fill_value: Any) -> None: - arr = zarr.create(store=store, shape=(4,), fill_value=None, zarr_format=2, dtype=dtype) - assert arr.metadata.fill_value is None - assert arr.metadata.to_dict()["fill_value"] is None - result = arr[:] - if dtype is str: - # special case - numpy_dtype = np.dtype(object) - else: - numpy_dtype = np.dtype(dtype) - expected = np.full(arr.shape, fill_value, dtype=numpy_dtype) - np.testing.assert_array_equal(result, expected) - - def test_codec_pipeline() -> None: # https://github.com/zarr-developers/zarr-python/issues/2243 store = MemoryStore() @@ -86,14 +61,14 @@ def test_codec_pipeline() -> None: @pytest.mark.parametrize( - ("dtype", "expected_dtype", "fill_value", "fill_value_encoding"), + ("dtype", "expected_dtype", "fill_value", "fill_value_json"), [ - ("|S", "|S0", b"X", "WA=="), - ("|V", "|V0", b"X", "WA=="), + ("|S1", "|S1", b"X", "WA=="), + ("|V1", "|V1", b"X", "WA=="), ("|V10", "|V10", b"X", "WAAAAAAAAAAAAA=="), ], ) -async def test_v2_encode_decode(dtype, expected_dtype, fill_value, fill_value_encoding) -> None: +async def test_v2_encode_decode(dtype, expected_dtype, fill_value, fill_value_json) -> None: with config.set( { "array.v2_default_filters.bytes": [{"id": "vlen-bytes"}], @@ -114,8 +89,8 @@ async def test_v2_encode_decode(dtype, expected_dtype, fill_value, fill_value_en "chunks": [3], "compressor": None, "dtype": expected_dtype, - "fill_value": fill_value_encoding, - "filters": [{"id": "vlen-bytes"}] if dtype == "|S" else None, + "fill_value": fill_value_json, + "filters": None, "order": "C", "shape": [3], "zarr_format": 2, @@ -128,37 +103,24 @@ async def test_v2_encode_decode(dtype, expected_dtype, fill_value, fill_value_en np.testing.assert_equal(data, expected) -@pytest.mark.parametrize("dtype_value", [["|S", b"Y"], ["|U", "Y"], ["O", b"Y"]]) -def test_v2_encode_decode_with_data(dtype_value): - dtype, value = dtype_value - with config.set( - { - "array.v2_default_filters": { - "string": [{"id": "vlen-utf8"}], - "bytes": [{"id": "vlen-bytes"}], - }, - } - ): - expected = np.full((3,), value, dtype=dtype) - a = zarr.create( - shape=(3,), - zarr_format=2, - dtype=dtype, - ) - a[:] = expected - data = a[:] - np.testing.assert_equal(data, expected) - - -@pytest.mark.parametrize("dtype", [str, "str"]) -async def test_create_dtype_str(dtype: Any) -> None: - arr = zarr.create(shape=3, dtype=dtype, zarr_format=2) - assert arr.dtype.kind == "O" - assert arr.metadata.to_dict()["dtype"] == "|O" - assert arr.metadata.filters == (numcodecs.vlen.VLenBytes(),) - arr[:] = [b"a", b"bb", b"ccc"] - result = 
arr[:]
-    np.testing.assert_array_equal(result, np.array([b"a", b"bb", b"ccc"], dtype="object"))
+@pytest.mark.parametrize(
+    ("dtype", "value"),
+    [
+        (NullTerminatedBytes(length=1), b"Y"),
+        (FixedLengthUTF32(length=1), "Y"),
+        (VariableLengthUTF8(), "Y"),
+    ],
+)
+def test_v2_encode_decode_with_data(dtype: ZDType[Any, Any], value: str | bytes) -> None:
+    expected = np.full((3,), value, dtype=dtype.to_native_dtype())
+    a = zarr.create(
+        shape=(3,),
+        zarr_format=2,
+        dtype=dtype,
+    )
+    a[:] = expected
+    data = a[:]
+    np.testing.assert_equal(data, expected)
 @pytest.mark.parametrize("filters", [[], [numcodecs.Delta(dtype=" None:
-    with config.set(
-        {
-            "array.v2_default_compressor": {
-                "numeric": {"id": "zstd", "level": "0"},
-                "string": {"id": "zstd", "level": "0"},
-                "bytes": {"id": "zstd", "level": "0"},
-            },
-            "array.v2_default_filters": {
-                "numeric": [],
-                "string": [{"id": "vlen-utf8"}],
-                "bytes": [{"id": "vlen-bytes"}],
-            },
-        }
-    ):
-        dtype, expected_compressor, expected_filter = dtype_expected
-        arr = zarr.create(shape=(3,), path="foo", store={}, zarr_format=2, dtype=dtype)
-        assert arr.metadata.compressor.codec_id == expected_compressor
-        if expected_filter is not None:
-            assert arr.metadata.filters[0].codec_id == expected_filter
-
-
 @pytest.mark.parametrize("fill_value", [None, (b"", 0, 0.0)], ids=["no_fill", "fill"])
 def test_structured_dtype_roundtrip(fill_value, tmp_path) -> None:
     a = np.array(
@@ -339,35 +269,18 @@ def test_structured_dtype_roundtrip(fill_value, tmp_path) -> None:
             np.dtype([("x", "i4"), ("y", "i4")]),
             np.array([(1, 2)], dtype=[("x", "i4"), ("y", "i4")])[0],
         ),
-        (
-            "BQAAAA==",
-            np.dtype([("val", "i4")]),
-            np.array([(5,)], dtype=[("val", "i4")])[0],
-        ),
-        (
-            {"x": 1, "y": 2},
-            np.dtype([("location", "O")]),
-            np.array([({"x": 1, "y": 2},)], dtype=[("location", "O")])[0],
-        ),
-        (
-            {"x": 1, "y": 2, "z": 3},
-            np.dtype([("location", "O")]),
-            np.array([({"x": 1, "y": 2, "z": 3},)], dtype=[("location", "O")])[0],
-        ),
     ],
     ids=[
         "tuple_input",
         "list_input",
         "bytes_input",
-        "string_input",
-        "dictionary_input",
-        "dictionary_input_extra_fields",
     ],
 )
 def test_parse_structured_fill_value_valid(
     fill_value: Any, dtype: np.dtype[Any], expected_result: Any
 ) -> None:
-    result = _parse_structured_fill_value(fill_value, dtype)
+    zdtype = Structured.from_native_dtype(dtype)
+    result = zdtype.cast_scalar(fill_value)
     assert result.dtype == expected_result.dtype
     assert result == expected_result
     if isinstance(expected_result, np.void):
@@ -375,31 +288,6 @@ def test_parse_structured_fill_value_valid(
         assert result[name] == expected_result[name]
-@pytest.mark.parametrize(
-    (
-        "fill_value",
-        "dtype",
-    ),
-    [
-        (("Alice", 30), np.dtype([("name", "U10"), ("age", "i4"), ("city", "U20")])),
-        (b"\x01\x00\x00\x00", np.dtype([("x", "i4"), ("y", "i4")])),
-        ("this_is_not_base64", np.dtype([("val", "i4")])),
-        ("hello", np.dtype([("age", "i4")])),
-        ({"x": 1, "y": 2}, np.dtype([("location", "i4")])),
-    ],
-    ids=[
-        "tuple_list_wrong_length",
-        "bytes_wrong_length",
-        "invalid_base64",
-        "wrong_data_type",
-        "wrong_dictionary",
-    ],
-)
-def test_parse_structured_fill_value_invalid(fill_value: Any, dtype: np.dtype[Any]) -> None:
-    with pytest.raises(ValueError):
-        _parse_structured_fill_value(fill_value, dtype)
-
-
 @pytest.mark.parametrize("fill_value", [None, b"x"], ids=["no_fill", "fill"])
 def test_other_dtype_roundtrip(fill_value, tmp_path) -> None:
     a = np.array([b"a\0\0", b"bb", b"ccc"], dtype="V7")
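To close out this file, a brief sketch of the Structured data type that replaces _parse_structured_fill_value in the hunks above. The import path is an assumption, mirroring the other zarr.core.dtype imports in this diff; from_native_dtype and cast_scalar are the names the rewritten test exercises:

    import numpy as np

    from zarr.core.dtype import Structured  # import path assumed

    npdtype = np.dtype([("foo", "i4"), ("bar", "i4")])
    zdtype = Structured.from_native_dtype(npdtype)
    # cast_scalar coerces tuple, list, or bytes input to a np.void scalar of this dtype.
    fill = zdtype.cast_scalar((1, 2))
    assert fill.dtype == npdtype
    assert (fill["foo"], fill["bar"]) == (1, 2)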