diff --git a/changes/3400.feature.rst b/changes/3400.feature.rst new file mode 100644 index 0000000000..ed7ddd701a --- /dev/null +++ b/changes/3400.feature.rst @@ -0,0 +1,3 @@ +Add a runtime type checker for ``JSON`` types, and a variety of TypedDict classes necessary for +modelling Zarr metadata documents. This increases the type safety of our internal metadata routines, +and provides Zarr users with types they can use to model Zarr metadata. \ No newline at end of file diff --git a/examples/custom_dtype.py b/examples/custom_dtype.py index a98f3414f6..9e87a1d66a 100644 --- a/examples/custom_dtype.py +++ b/examples/custom_dtype.py @@ -28,8 +28,8 @@ DataTypeValidationError, DTypeConfig_V2, DTypeJSON, - check_dtype_spec_v2, ) +from zarr.core.type_check import guard_type # This is the int2 array data type int2_dtype_cls = type(np.dtype("int2")) @@ -67,7 +67,7 @@ def to_native_dtype(self: Self) -> int2_dtype_cls: return self.dtype_cls() @classmethod - def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[Literal["|b1"], None]]: + def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[Literal["int2"], None]]: """ Type check for Zarr v2-flavored JSON. @@ -84,9 +84,7 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[Literal["|b See the Zarr docs for more information about the JSON encoding for data types. """ - return ( - check_dtype_spec_v2(data) and data["name"] == "int2" and data["object_codec_id"] is None - ) + return guard_type(data, DTypeConfig_V2[Literal["int2"], None]) @classmethod def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["int2"]]: diff --git a/pyproject.toml b/pyproject.toml index bea8d77127..7589c7da6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ 'packaging>=22.0', 'numpy>=1.26', 'numcodecs[crc32c]>=0.14', - 'typing_extensions>=4.9', + 'typing_extensions>=4.13', 'donfig>=0.8', ] @@ -226,7 +226,6 @@ dependencies = [ 'fsspec==2023.10.0', 's3fs==2023.10.0', 'universal_pathlib==0.0.22', - 'typing_extensions==4.9.*', 'donfig==0.8.*', 'obstore==0.5.*', # test deps diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..6a9c820f78 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,14 +1,16 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Mapping -from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar - -from typing_extensions import ReadOnly, TypedDict +from typing import TYPE_CHECKING, Generic, TypeVar from zarr.abc.metadata import Metadata from zarr.core.buffer import Buffer, NDBuffer -from zarr.core.common import NamedConfig, concurrent_map +from zarr.core.common import ( # noqa: F401 CodecJSON re-exported for backwards compatibility + CodecJSON, + CodecJSON_V2, + CodecJSON_V3, + concurrent_map, +) from zarr.core.config import config if TYPE_CHECKING: @@ -37,27 +39,6 @@ CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer) CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer) -TName = TypeVar("TName", bound=str, covariant=True) - - -class CodecJSON_V2(TypedDict, Generic[TName]): - """The JSON representation of a codec for Zarr V2""" - - id: ReadOnly[TName] - - -def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]: - return isinstance(data, Mapping) and "id" in data and isinstance(data["id"], str) - - -CodecJSON_V3 = str | NamedConfig[str, Mapping[str, object]] -"""The JSON representation of a codec for Zarr V3.""" - -# The widest type we will *accept* for a 
codec JSON -# This covers v2 and v3 -CodecJSON = str | Mapping[str, object] -"""The widest type of JSON-like input that could specify a codec.""" - class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]): """Generic base class for codecs. diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 409601e474..2a7c37c2e9 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -3,7 +3,7 @@ import asyncio import dataclasses import warnings -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal import numpy as np import numpy.typing as npt @@ -37,7 +37,7 @@ GroupMetadata, create_hierarchy, ) -from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata +from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.errors import ( ArrayNotFoundError, GroupNotFoundError, @@ -353,13 +353,12 @@ async def open( try: metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format) - # TODO: remove this cast when we fix typing for array metadata dicts - _metadata_dict = cast("ArrayMetadataDict", metadata_dict) # for v2, the above would already have raised an exception if not an array - zarr_format = _metadata_dict["zarr_format"] - is_v3_array = zarr_format == 3 and _metadata_dict.get("node_type") == "array" + zarr_format = metadata_dict["zarr_format"] + is_v3_array = zarr_format == 3 and metadata_dict.get("node_type") == "array" if is_v3_array or zarr_format == 2: return AsyncArray( - store_path=store_path, metadata=metadata_dict, config=kwargs.get("config") + store_path=store_path, metadata=metadata_dict, config=kwargs.get("config") ) except (AssertionError, FileNotFoundError, NodeTypeValidationError): pass diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index ce19f99ba0..da7abb6b3c 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -53,6 +53,8 @@ ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON, + ArrayMetadataJSON_V2, + ArrayMetadataJSON_V3, DimensionNames, MemoryOrder, ShapeLike, @@ -103,11 +105,8 @@ ) from zarr.core.metadata import ( ArrayMetadata, - ArrayMetadataDict, ArrayV2Metadata, - ArrayV2MetadataDict, ArrayV3Metadata, - ArrayV3MetadataDict, T_ArrayMetadata, ) from zarr.core.metadata.v2 import ( @@ -116,11 +115,12 @@ parse_compressor, parse_filters, ) -from zarr.core.metadata.v3 import parse_node_type_array from zarr.core.sync import sync +from zarr.core.type_check import check_type from zarr.errors import ( ArrayNotFoundError, MetadataValidationError, + NodeTypeValidationError, ZarrDeprecationWarning, ZarrUserWarning, ) @@ -175,25 +175,32 @@ class DefaultFillValue: DEFAULT_FILL_VALUE = DefaultFillValue() -def parse_array_metadata(data: Any) -> ArrayMetadata: +@overload +def parse_array_metadata(data: ArrayV2Metadata | ArrayMetadataJSON_V2) -> ArrayV2Metadata: ... + + +@overload +def parse_array_metadata(data: ArrayV3Metadata | ArrayMetadataJSON_V3) -> ArrayV3Metadata: ... + + +def parse_array_metadata( + data: ArrayV2Metadata | ArrayMetadataJSON_V2 | ArrayV3Metadata | ArrayMetadataJSON_V3, +) -> ArrayV2Metadata | ArrayV3Metadata: + """ + If the input is a dict representation of a Zarr metadata document, instantiate the right metadata + class from that dict. If the input is a metadata object, return it. 
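+ + For example, ``{"zarr_format": 2, ...}`` is dispatched to ``ArrayV2Metadata.from_dict``, + while ``{"zarr_format": 3, "node_type": "array", ...}`` is dispatched to + ``ArrayV3Metadata.from_dict``; any other ``zarr_format`` value raises ``ValueError``.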
+ """ + if isinstance(data, ArrayMetadata): return data - elif isinstance(data, dict): - zarr_format = data.get("zarr_format") + else: + zarr_format = data["zarr_format"] if zarr_format == 3: - meta_out = ArrayV3Metadata.from_dict(data) - if len(meta_out.storage_transformers) > 0: - msg = ( - f"Array metadata contains storage transformers: {meta_out.storage_transformers}." - "Arrays with storage transformers are not supported in zarr-python at this time." - ) - raise ValueError(msg) - return meta_out + return ArrayV3Metadata.from_dict(data) # type: ignore[arg-type] elif zarr_format == 2: - return ArrayV2Metadata.from_dict(data) + return ArrayV2Metadata.from_dict(data) # type: ignore[arg-type] else: raise ValueError(f"Invalid zarr_format: {zarr_format}. Expected 2 or 3") - raise TypeError # pragma: no cover def create_codec_pipeline(metadata: ArrayMetadata, *, store: Store | None = None) -> CodecPipeline: @@ -213,9 +220,27 @@ def create_codec_pipeline(metadata: ArrayMetadata, *, store: Store | None = None raise TypeError # pragma: no cover +@overload +async def get_array_metadata( + store_path: StorePath, zarr_format: Literal[3] +) -> ArrayMetadataJSON_V3: ... + + +@overload +async def get_array_metadata( + store_path: StorePath, zarr_format: Literal[2] +) -> ArrayMetadataJSON_V2: ... + + +@overload +async def get_array_metadata( + store_path: StorePath, zarr_format: None +) -> ArrayMetadataJSON_V3 | ArrayMetadataJSON_V2: ... + + async def get_array_metadata( store_path: StorePath, zarr_format: ZarrFormat | None = 3 -) -> dict[str, JSON]: +) -> ArrayMetadataJSON_V3 | ArrayMetadataJSON_V2: if zarr_format == 2: zarray_bytes, zattrs_bytes = await gather( (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype), @@ -260,19 +285,25 @@ async def get_array_metadata( msg = f"Invalid value for 'zarr_format'. Expected 2, 3, or None. Got '{zarr_format}'." # type: ignore[unreachable] raise MetadataValidationError(msg) - metadata_dict: dict[str, JSON] + metadata_dict: ArrayMetadataJSON_V2 | ArrayMetadataJSON_V3 if zarr_format == 2: # V2 arrays are comprised of a .zarray and .zattrs objects assert zarray_bytes is not None metadata_dict = json.loads(zarray_bytes.to_bytes()) zattrs_dict = json.loads(zattrs_bytes.to_bytes()) if zattrs_bytes is not None else {} metadata_dict["attributes"] = zattrs_dict + tycheck = check_type(metadata_dict, ArrayMetadataJSON_V2) + if not tycheck.success: + msg = "The .zarray object at {store_path} is not a valid Zarr array metadata object. " + raise NodeTypeValidationError("zarray", "Zarr array metadata object", metadata_dict) else: # V3 arrays are comprised of a zarr.json object assert zarr_json_bytes is not None metadata_dict = json.loads(zarr_json_bytes.to_bytes()) - - parse_node_type_array(metadata_dict.get("node_type")) + tycheck = check_type(metadata_dict, ArrayMetadataJSON_V3) + if not tycheck.success: + msg = "The zarr.json object at {store_path} is not a valid Zarr array metadata object. " + raise NodeTypeValidationError("zarr.json", "Zarr array metadata object", metadata_dict) return metadata_dict @@ -311,7 +342,7 @@ class AsyncArray(Generic[T_ArrayMetadata]): @overload def __init__( self: AsyncArray[ArrayV2Metadata], - metadata: ArrayV2Metadata | ArrayV2MetadataDict, + metadata: ArrayV2Metadata | ArrayMetadataJSON_V2, store_path: StorePath, config: ArrayConfigLike | None = None, ) -> None: ... 
@@ -319,14 +350,14 @@ def __init__( @overload def __init__( self: AsyncArray[ArrayV3Metadata], - metadata: ArrayV3Metadata | ArrayV3MetadataDict, + metadata: ArrayV3Metadata | ArrayMetadataJSON_V3, store_path: StorePath, config: ArrayConfigLike | None = None, ) -> None: ... def __init__( self, - metadata: ArrayMetadata | ArrayMetadataDict, + metadata: ArrayMetadata | ArrayMetadataJSON_V2 | ArrayMetadataJSON_V3, store_path: StorePath, config: ArrayConfigLike | None = None, ) -> None: @@ -945,7 +976,7 @@ def from_dict( ValueError If the dictionary data is invalid or incompatible with either Zarr format 2 or 3 array creation. """ - metadata = parse_array_metadata(data) + metadata = parse_array_metadata(data) # type: ignore[call-overload] return cls(metadata=metadata, store_path=store_path) @classmethod @@ -978,9 +1009,7 @@ async def open( """ store_path = await make_store_path(store) metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format) - # TODO: remove this cast when we have better type hints - _metadata_dict = cast("ArrayV3MetadataDict", metadata_dict) - return cls(store_path=store_path, metadata=_metadata_dict) + return cls(store_path=store_path, metadata=metadata_dict) @property def store(self) -> Store: diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index bebccb65fc..3e9699b274 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -14,6 +14,7 @@ Final, Generic, Literal, + NotRequired, TypedDict, TypeVar, cast, @@ -47,11 +48,11 @@ ANY_ACCESS_MODE: Final = "r", "r+", "a", "w", "w-" DimensionNames = Iterable[str | None] | None -TName = TypeVar("TName", bound=str) +TName = TypeVar("TName", bound=str, covariant=True) TConfig = TypeVar("TConfig", bound=Mapping[str, object]) -class NamedConfig(TypedDict, Generic[TName, TConfig]): +class NamedRequiredConfig(TypedDict, Generic[TName, TConfig]): """ A typed dictionary representing an object with a name and configuration, where the configuration is a mapping of string keys to values, e.g. another typed dictionary or a JSON object. @@ -67,6 +68,132 @@ class NamedConfig(TypedDict, Generic[TName, TConfig]): """The configuration of the object.""" +class NamedConfig(TypedDict, Generic[TName, TConfig]): + """ + A typed dictionary representing an object with a name and configuration, where the configuration + is a mapping of string keys to values, e.g. another typed dictionary or a JSON object. + + The configuration key is not required. + + This class is generic with two type parameters: the type of the name (``TName``) and the type of + the configuration (``TConfig``). + """ + + name: ReadOnly[TName] + """The name of the object.""" + + configuration: ReadOnly[NotRequired[TConfig]] + """The configuration of the object.""" + + +class ArrayMetadataJSON_V2(TypedDict): + """ + A typed dictionary model for Zarr V2 array metadata. + """ + + zarr_format: Literal[2] + dtype: str | StructuredName_V2 + shape: Sequence[int] + chunks: Sequence[int] + dimension_separator: NotRequired[Literal[".", "/"]] + fill_value: Any + filters: Sequence[CodecJSON_V2[str]] | None + order: Literal["C", "F"] + compressor: CodecJSON_V2[str] | None + attributes: NotRequired[Mapping[str, JSON]] + + +class GroupMetadataJSON_V2(TypedDict): + """ + A typed dictionary model for Zarr V2 group metadata. 
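+ + Only the ``zarr_format`` key is required here, so ``{"zarr_format": 2}`` is a minimal + valid instance; ``attributes`` and ``consolidated_metadata`` may be omitted.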
+ """ + + zarr_format: Literal[2] + attributes: NotRequired[Mapping[str, JSON]] + consolidated_metadata: NotRequired[ConsolidatedMetadata_JSON_V2] + + +class ArrayMetadataJSON_V3(TypedDict): + """ + A typed dictionary model for Zarr V3 array metadata. + """ + + zarr_format: Literal[3] + node_type: Literal["array"] + data_type: str | NamedConfig[str, Mapping[str, object]] + shape: Sequence[int] + chunk_grid: NamedConfig[str, Mapping[str, object]] + chunk_key_encoding: NamedConfig[str, Mapping[str, object]] + fill_value: object + codecs: Sequence[str | NamedConfig[str, Mapping[str, object]]] + attributes: NotRequired[Mapping[str, object]] + storage_transformers: NotRequired[Sequence[NamedConfig[str, Mapping[str, object]]]] + dimension_names: NotRequired[Sequence[str | None]] + + +class GroupMetadataJSON_V3(TypedDict): + """ + A typed dictionary model for Zarr V3 group metadata. + """ + + zarr_format: Literal[3] + node_type: Literal["group"] + attributes: NotRequired[Mapping[str, JSON]] + consolidated_metadata: NotRequired[ConsolidatedMetadata_JSON_V3 | None] + + +# TODO: use just 1 generic class and parametrize the type of the value type of the metadata +# I.e., ConsolidatedMetadata_JSON[ArrayMetadataJSON_V2 | GroupMetadataJSON_V2] +class ConsolidatedMetadata_JSON_V2(TypedDict): + """ + A typed dictionary model for Zarr consolidated metadata. + + This model is parameterized by the type of the metadata itself. + """ + + kind: Literal["inline"] + must_understand: Literal["false"] + metadata: Mapping[str, ArrayMetadataJSON_V2 | GroupMetadataJSON_V2] + + +class ConsolidatedMetadata_JSON_V3(TypedDict): + """ + A typed dictionary model for Zarr consolidated metadata. + + This model is parameterized by the type of the metadata itself. + """ + + kind: Literal["inline"] + must_understand: Literal["false"] + metadata: Mapping[str, ArrayMetadataJSON_V3 | GroupMetadataJSON_V3] + + +class CodecJSON_V2(TypedDict, Generic[TName]): + """The JSON representation of a codec for Zarr V2""" + + id: ReadOnly[TName] + + +CodecJSON_V3 = str | NamedConfig[str, Mapping[str, object]] +"""The JSON representation of a codec for Zarr V3.""" + +# The widest type we will *accept* for a codec JSON +# This covers v2 and v3 +CodecJSON = str | Mapping[str, object] +"""The widest type of JSON-like input that could specify a codec.""" + + +# By comparison, The JSON representation of a dtype in zarr v3 is much simpler. +# It's either a string, or a structured dict +DTypeSpec_V3 = str | NamedConfig[str, Mapping[str, object]] + +# This is the JSON representation of a structured dtype in zarr v2 +StructuredName_V2 = Sequence["str | StructuredName_V2"] + +# This models the type of the name a dtype might have in zarr v2 array metadata +DTypeName_V2 = StructuredName_V2 | str + + def product(tup: tuple[int, ...]) -> int: return functools.reduce(operator.mul, tup, 1) diff --git a/src/zarr/core/dtype/common.py b/src/zarr/core/dtype/common.py index 652b5fdbe3..39af1aa164 100644 --- a/src/zarr/core/dtype/common.py +++ b/src/zarr/core/dtype/common.py @@ -15,7 +15,7 @@ from typing_extensions import ReadOnly -from zarr.core.common import NamedConfig +from zarr.core.common import DTypeName_V2, DTypeSpec_V3, StructuredName_V2 from zarr.errors import UnstableSpecificationWarning EndiannessStr = Literal["little", "big"] @@ -46,13 +46,6 @@ # in a single dict, and pass that to the data type inference logic. # These type variables have a very wide bound because the individual zdtype # classes can perform a very specific type check. 
- -# This is the JSON representation of a structured dtype in zarr v2 -StructuredName_V2 = Sequence["str | StructuredName_V2"] - -# This models the type of the name a dtype might have in zarr v2 array metadata -DTypeName_V2 = StructuredName_V2 | str - TDTypeNameV2_co = TypeVar("TDTypeNameV2_co", bound=DTypeName_V2, covariant=True) TObjectCodecID_co = TypeVar("TObjectCodecID_co", bound=None | str, covariant=True) @@ -107,40 +100,6 @@ def check_dtype_name_v2(data: object) -> TypeGuard[DTypeName_V2]: return False -def check_dtype_spec_v2(data: object) -> TypeGuard[DTypeSpec_V2]: - """ - Type guard for narrowing a python object to an instance of DTypeSpec_V2 - """ - if not isinstance(data, Mapping): - return False - if set(data.keys()) != {"name", "object_codec_id"}: - return False - if not check_dtype_name_v2(data["name"]): - return False - return isinstance(data["object_codec_id"], str | None) - - -# By comparison, The JSON representation of a dtype in zarr v3 is much simpler. -# It's either a string, or a structured dict -DTypeSpec_V3 = str | NamedConfig[str, Mapping[str, object]] - - -def check_dtype_spec_v3(data: object) -> TypeGuard[DTypeSpec_V3]: - """ - Type guard for narrowing the type of a python object to an instance of - DTypeSpec_V3, i.e either a string or a dict with a "name" field that's a string and a - "configuration" field that's a mapping with string keys. - """ - if isinstance(data, str) or ( # noqa: SIM103 - isinstance(data, Mapping) - and set(data.keys()) == {"name", "configuration"} - and isinstance(data["configuration"], Mapping) - and all(isinstance(k, str) for k in data["configuration"]) - ): - return True - return False - - def unpack_dtype_json(data: DTypeSpec_V2 | DTypeSpec_V3) -> DTypeJSON: """ Return the array metadata form of the dtype JSON representation. For the Zarr V3 form of dtype diff --git a/src/zarr/core/dtype/npy/bool.py b/src/zarr/core/dtype/npy/bool.py index 37371cd0cd..85141af7af 100644 --- a/src/zarr/core/dtype/npy/bool.py +++ b/src/zarr/core/dtype/npy/bool.py @@ -10,9 +10,9 @@ DTypeConfig_V2, DTypeJSON, HasItemSize, - check_dtype_spec_v2, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.core.type_check import guard_type if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat @@ -103,11 +103,7 @@ def _check_json_v2( ``TypeGuard[DTypeConfig_V2[Literal["|b1"], None]]`` True if the input is a valid JSON representation, False otherwise. """ - return ( - check_dtype_spec_v2(data) - and data["name"] == cls._zarr_v2_name - and data["object_codec_id"] is None - ) + return guard_type(data, DTypeConfig_V2[Literal["|b1"], None]) @classmethod def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["bool"]]: @@ -173,7 +169,7 @@ def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self: """ if cls._check_json_v3(data): return cls() - msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}" + msg = f"Invalid JSON representation of {cls.__name__}. 
Got {data!r}, expected the string {cls._zarr_v3_name!r}" raise DataTypeValidationError(msg) @overload diff --git a/src/zarr/core/dtype/npy/bytes.py b/src/zarr/core/dtype/npy/bytes.py index b7c764dcd9..3c2b590390 100644 --- a/src/zarr/core/dtype/npy/bytes.py +++ b/src/zarr/core/dtype/npy/bytes.py @@ -15,11 +15,11 @@ HasItemSize, HasLength, HasObjectCodec, - check_dtype_spec_v2, v3_unstable_dtype_warning, ) from zarr.core.dtype.npy.common import check_json_str from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.core.type_check import check_type, guard_type BytesLike = np.bytes_ | str | bytes | int @@ -46,28 +46,28 @@ class FixedLengthBytesConfig(TypedDict): length_bytes: int -class NullterminatedBytesJSON_V2(DTypeConfig_V2[str, None]): - """ - A wrapper around the JSON representation of the ``NullTerminatedBytes`` data type in Zarr V2. +NullterminatedBytesJSON_V2 = DTypeConfig_V2[str, None] +""" +A wrapper around the JSON representation of the ``NullTerminatedBytes`` data type in Zarr V2. - The ``name`` field of this class contains the value that would appear under the - ``dtype`` field in Zarr V2 array metadata. +The ``name`` field of this class contains the value that would appear under the +``dtype`` field in Zarr V2 array metadata. - References - ---------- - The structure of the ``name`` field is defined in the Zarr V2 - `specification document `__. +References +---------- +The structure of the ``name`` field is defined in the Zarr V2 +`specification document `__. - Examples - -------- - .. code-block:: python +Examples +-------- +.. code-block:: python - { - "name": "|S10", - "object_codec_id": None - } - """ + { + "name": "|S10", + "object_codec_id": None + } +""" class NullTerminatedBytesJSON_V3( @@ -262,12 +262,9 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[NullterminatedBytesJSON_V2 bool True if the input data is a valid representation, False otherwise. """ - return ( - check_dtype_spec_v2(data) - and isinstance(data["name"], str) - and re.match(r"^\|S\d+$", data["name"]) is not None - and data["object_codec_id"] is None - ) + return ( + guard_type(data, NullterminatedBytesJSON_V2) + and re.match(r"^\|S\d+$", data["name"]) is not None + ) @classmethod @@ -286,14 +283,7 @@ def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[NullTerminatedBytesJSON_V3 True if the input is a valid representation of this class in Zarr V3, False otherwise. """ - return ( - isinstance(data, dict) - and set(data.keys()) == {"name", "configuration"} - and data["name"] == cls._zarr_v3_name - and isinstance(data["configuration"], dict) - and "length_bytes" in data["configuration"] - and isinstance(data["configuration"]["length_bytes"], int) - ) + return check_type(data, NullTerminatedBytesJSON_V3).success @classmethod def _from_json_v2(cls, data: DTypeJSON) -> Self: @@ -665,12 +655,7 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[RawBytesJSON_V2]: True if the input is a valid representation of this class in Zarr V3, False otherwise. """ - return ( - check_dtype_spec_v2(data) - and isinstance(data["name"], str) - and re.match(r"^\|V\d+$", data["name"]) is not None - and data["object_codec_id"] is None - ) + return guard_type(data, RawBytesJSON_V2) and re.match(r"^\|V\d+$", data["name"]) is not None @classmethod def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[RawBytesJSON_V3]: @@ -1019,13 +1004,7 @@ def _check_json_v2( otherwise. """ # Check that the input is a valid JSON representation of a Zarr v2 data type spec. 
- if not check_dtype_spec_v2(data): - return False - - # Check that the object codec id is appropriate for variable-length bytes strings. - if data["name"] != "|O": - return False - return data["object_codec_id"] == cls.object_codec_id + return guard_type(data, VariableLengthBytesJSON_V2) @classmethod def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["variable_length_bytes"]]: diff --git a/src/zarr/core/dtype/npy/complex.py b/src/zarr/core/dtype/npy/complex.py index 2f432a9e0a..e9bbfc2186 100644 --- a/src/zarr/core/dtype/npy/complex.py +++ b/src/zarr/core/dtype/npy/complex.py @@ -18,7 +18,6 @@ DTypeJSON, HasEndianness, HasItemSize, - check_dtype_spec_v2, ) from zarr.core.dtype.npy.common import ( ComplexLike, @@ -34,6 +33,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.core.type_check import guard_type if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat @@ -105,11 +105,7 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]] bool True if the input is a valid JSON representation, False otherwise. """ - return ( - check_dtype_spec_v2(data) - and data["name"] in cls._zarr_v2_names - and data["object_codec_id"] is None - ) + return guard_type(data, DTypeConfig_V2[str, None]) and data["name"] in cls._zarr_v2_names @classmethod def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[str]: diff --git a/src/zarr/core/dtype/npy/float.py b/src/zarr/core/dtype/npy/float.py index 3113bc5b61..598002286d 100644 --- a/src/zarr/core/dtype/npy/float.py +++ b/src/zarr/core/dtype/npy/float.py @@ -11,7 +11,6 @@ DTypeJSON, HasEndianness, HasItemSize, - check_dtype_spec_v2, ) from zarr.core.dtype.npy.common import ( FloatLike, @@ -27,6 +26,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.core.type_check import guard_type if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat @@ -89,11 +89,7 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]] TypeGuard[DTypeConfig_V2[str, None]] True if the input is a valid JSON representation of this data type, False otherwise. """ - return ( - check_dtype_spec_v2(data) - and data["name"] in cls._zarr_v2_names - and data["object_codec_id"] is None - ) + return guard_type(data, DTypeConfig_V2[str, None]) and data["name"] in cls._zarr_v2_names @classmethod def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[str]: diff --git a/src/zarr/core/dtype/npy/int.py b/src/zarr/core/dtype/npy/int.py index 01a79142a3..4991541453 100644 --- a/src/zarr/core/dtype/npy/int.py +++ b/src/zarr/core/dtype/npy/int.py @@ -21,7 +21,6 @@ DTypeJSON, HasEndianness, HasItemSize, - check_dtype_spec_v2, ) from zarr.core.dtype.npy.common import ( check_json_int, @@ -29,6 +28,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.core.type_check import guard_type if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat @@ -83,11 +83,7 @@ def _check_json_v2(cls, data: object) -> TypeGuard[DTypeConfig_V2[str, None]]: False otherwise. 
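+ + For example, ``{"name": "<i4", "object_codec_id": None}`` passes for a little-endian + 32-bit signed integer type, assuming ``"<i4"`` is among ``cls._zarr_v2_names``.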
""" - return ( - check_dtype_spec_v2(data) - and data["name"] in cls._zarr_v2_names - and data["object_codec_id"] is None - ) + return guard_type(data, DTypeConfig_V2[str, None]) and data["name"] in cls._zarr_v2_names @classmethod def _check_json_v3(cls, data: object) -> TypeGuard[str]: diff --git a/src/zarr/core/dtype/npy/string.py b/src/zarr/core/dtype/npy/string.py index 32375a1c71..a39a7996eb 100644 --- a/src/zarr/core/dtype/npy/string.py +++ b/src/zarr/core/dtype/npy/string.py @@ -25,7 +25,6 @@ HasItemSize, HasLength, HasObjectCodec, - check_dtype_spec_v2, v3_unstable_dtype_warning, ) from zarr.core.dtype.npy.common import ( @@ -34,6 +33,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TDType_co, ZDType +from zarr.core.type_check import guard_type if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat @@ -189,10 +189,8 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[FixedLengthUTF32JSON_V2]: Whether the input is a valid JSON representation of a NumPy U dtype. """ return ( - check_dtype_spec_v2(data) - and isinstance(data["name"], str) + guard_type(data, FixedLengthUTF32JSON_V2) and re.match(r"^[><]U\d+$", data["name"]) is not None - and data["object_codec_id"] is None ) @classmethod @@ -520,11 +518,7 @@ def _check_json_v2( Whether the input is a valid JSON representation of a NumPy "object" data type, and that the object codec id is appropriate for variable-length UTF-8 strings. """ - return ( - check_dtype_spec_v2(data) - and data["name"] == "|O" - and data["object_codec_id"] == cls.object_codec_id - ) + return guard_type(data, VariableLengthUTF8JSON_V2) @classmethod def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["variable_length_utf8"]]: diff --git a/src/zarr/core/dtype/npy/structured.py b/src/zarr/core/dtype/npy/structured.py index a0e3b0fbd4..2ea78f6612 100644 --- a/src/zarr/core/dtype/npy/structured.py +++ b/src/zarr/core/dtype/npy/structured.py @@ -1,20 +1,17 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, ClassVar, Literal, Self, TypeGuard, cast, overload import numpy as np -from zarr.core.common import NamedConfig +from zarr.core.common import NamedConfig, NamedRequiredConfig, StructuredName_V2 from zarr.core.dtype.common import ( DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasItemSize, - StructuredName_V2, - check_dtype_spec_v2, - check_structured_dtype_name_v2, v3_unstable_dtype_warning, ) from zarr.core.dtype.npy.common import ( @@ -23,6 +20,7 @@ check_json_str, ) from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType +from zarr.core.type_check import guard_type if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat @@ -58,7 +56,10 @@ class StructuredJSON_V2(DTypeConfig_V2[StructuredName_V2, None]): class StructuredJSON_V3( - NamedConfig[Literal["structured"], dict[str, Sequence[Sequence[str | DTypeJSON]]]] + NamedRequiredConfig[ + Literal["structured"], + Mapping[str, Sequence[list[str | NamedConfig[str, Mapping[str, object]]]]], + ] ): """ A JSON representation of a structured data type in Zarr V3. @@ -211,12 +212,7 @@ def _check_json_v2( True if the input is a valid JSON representation of a Structured data type for Zarr V2, False otherwise. 
""" - return ( - check_dtype_spec_v2(data) - and not isinstance(data["name"], str) - and check_structured_dtype_name_v2(data["name"]) - and data["object_codec_id"] is None - ) + return guard_type(data, StructuredJSON_V2) @classmethod def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[StructuredJSON_V3]: @@ -235,13 +231,7 @@ def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[StructuredJSON_V3]: False otherwise. """ - return ( - isinstance(data, dict) - and set(data.keys()) == {"name", "configuration"} - and data["name"] == cls._zarr_v3_name - and isinstance(data["configuration"], dict) - and set(data["configuration"].keys()) == {"fields"} - ) + return guard_type(data, StructuredJSON_V3) @classmethod def _from_json_v2(cls, data: DTypeJSON) -> Self: diff --git a/src/zarr/core/dtype/npy/time.py b/src/zarr/core/dtype/npy/time.py index d523e16940..8c6b06a2c1 100644 --- a/src/zarr/core/dtype/npy/time.py +++ b/src/zarr/core/dtype/npy/time.py @@ -25,7 +25,6 @@ DTypeJSON, HasEndianness, HasItemSize, - check_dtype_spec_v2, ) from zarr.core.dtype.npy.common import ( DATETIME_UNIT, @@ -35,6 +34,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.core.type_check import check_type, guard_type if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat @@ -377,13 +377,13 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[TimeDelta64JSON_V2]: True if the JSON input is a valid representation of this class, otherwise False. """ - if not check_dtype_spec_v2(data): + if not guard_type(data, TimeDelta64JSON_V2): return False + # We now know it's TimeDelta64JSON_V2, but there are constraints on the name that can't be + # expressed via type annotations name = data["name"] # match m[M], etc # consider making this a standalone function - if not isinstance(name, str): - return False if not name.startswith(cls._zarr_v2_names): return False if len(name) == 3: @@ -394,7 +394,7 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[TimeDelta64JSON_V2]: return name[4:-1].endswith(DATETIME_UNIT) and name[-1] == "]" @classmethod - def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[DateTime64JSON_V3]: + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[TimeDelta64JSON_V3]: """ Check that the input is a valid JSON representation of this class in Zarr V3. @@ -404,13 +404,7 @@ def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[DateTime64JSON_V3]: True if the JSON input is a valid representation of this class, otherwise False. """ - return ( - isinstance(data, dict) - and set(data.keys()) == {"name", "configuration"} - and data["name"] == cls._zarr_v3_name - and isinstance(data["configuration"], dict) - and set(data["configuration"].keys()) == {"unit", "scale_factor"} - ) + return check_type(data, TimeDelta64JSON_V3).success @classmethod def _from_json_v2(cls, data: DTypeJSON) -> Self: @@ -644,11 +638,9 @@ def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DateTime64JSON_V2]: True if the input is a valid JSON representation of a NumPy datetime64 data type, otherwise False. """ - if not check_dtype_spec_v2(data): + if not guard_type(data, DateTime64JSON_V2): return False name = data["name"] - if not isinstance(name, str): - return False if not name.startswith(cls._zarr_v2_names): return False if len(name) == 3: @@ -673,14 +665,7 @@ def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[DateTime64JSON_V3]: TypeGuard[DateTime64JSON_V3] True if the input is a valid JSON representation of a numpy datetime64 data type in Zarr V3, False otherwise. 
""" - - return ( - isinstance(data, dict) - and set(data.keys()) == {"name", "configuration"} - and data["name"] == cls._zarr_v3_name - and isinstance(data["configuration"], dict) - and set(data["configuration"].keys()) == {"unit", "scale_factor"} - ) + return check_type(data, DateTime64JSON_V3).success @classmethod def _from_json_v2(cls, data: DTypeJSON) -> Self: diff --git a/src/zarr/core/dtype/wrapper.py b/src/zarr/core/dtype/wrapper.py index 776aea81d8..fefe92fb23 100644 --- a/src/zarr/core/dtype/wrapper.py +++ b/src/zarr/core/dtype/wrapper.py @@ -39,8 +39,8 @@ import numpy as np if TYPE_CHECKING: - from zarr.core.common import JSON, ZarrFormat - from zarr.core.dtype.common import DTypeJSON, DTypeSpec_V2, DTypeSpec_V3 + from zarr.core.common import JSON, DTypeSpec_V3, ZarrFormat + from zarr.core.dtype.common import DTypeJSON, DTypeSpec_V2 # This the upper bound for the scalar types we support. It's numpy scalars + str, # because the new variable-length string dtype in numpy does not have a corresponding scalar type diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 15a256fb5d..2843206b00 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -7,9 +7,9 @@ import unicodedata import warnings from collections import defaultdict -from dataclasses import asdict, dataclass, field, fields, replace +from dataclasses import asdict, dataclass, field, replace from itertools import accumulate -from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload +from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload import numpy as np import numpy.typing as npt @@ -41,7 +41,13 @@ ZATTRS_JSON, ZGROUP_JSON, ZMETADATA_V2_JSON, + ArrayMetadataJSON_V2, + ArrayMetadataJSON_V3, + ConsolidatedMetadata_JSON_V2, + ConsolidatedMetadata_JSON_V3, DimensionNames, + GroupMetadataJSON_V2, + GroupMetadataJSON_V3, NodeType, ShapeLike, ZarrFormat, @@ -144,7 +150,7 @@ class ConsolidatedMetadata: kind: Literal["inline"] = "inline" must_understand: Literal[False] = False - def to_dict(self) -> dict[str, JSON]: + def to_dict(self) -> ConsolidatedMetadata_JSON_V2 | ConsolidatedMetadata_JSON_V3: return { "kind": self.kind, "must_understand": self.must_understand, @@ -158,13 +164,12 @@ def to_dict(self) -> dict[str, JSON]: ), ) }, - } + } # type: ignore[return-value, misc] @classmethod - def from_dict(cls, data: dict[str, JSON]) -> ConsolidatedMetadata: - data = dict(data) + def from_dict(cls, data: ConsolidatedMetadata_JSON_V2 | ConsolidatedMetadata_JSON_V3) -> Self: + kind = data["kind"] - kind = data.get("kind") if kind != "inline": raise ValueError(f"Consolidated metadata kind='{kind}' is not supported.") @@ -180,20 +185,21 @@ def from_dict(cls, data: dict[str, JSON]) -> ConsolidatedMetadata: f"Invalid value for metadata items. key='{k}', type='{type(v).__name__}'" ) - # zarr_format is present in v2 and v3. 
- zarr_format = parse_zarr_format(v["zarr_format"]) + zarr_format = v["zarr_format"] if zarr_format == 3: - node_type = parse_node_type(v.get("node_type", None)) + v = cast(ArrayMetadataJSON_V3 | GroupMetadataJSON_V3, v) + node_type = v["node_type"] if node_type == "group": - metadata[k] = GroupMetadata.from_dict(v) + metadata[k] = GroupMetadata.from_dict(v) # type: ignore[arg-type] elif node_type == "array": - metadata[k] = ArrayV3Metadata.from_dict(v) + metadata[k] = ArrayV3Metadata.from_dict(v) # type: ignore[arg-type] else: assert_never(node_type) elif zarr_format == 2: + v = cast(ArrayMetadataJSON_V2 | GroupMetadataJSON_V2, v) if "shape" in v: - metadata[k] = ArrayV2Metadata.from_dict(v) + metadata[k] = ArrayV2Metadata.from_dict(v) # type: ignore[arg-type] else: metadata[k] = GroupMetadata.from_dict(v) else: @@ -409,22 +415,21 @@ def __init__( object.__setattr__(self, "consolidated_metadata", consolidated_metadata) @classmethod - def from_dict(cls, data: dict[str, Any]) -> GroupMetadata: - data = dict(data) - assert data.pop("node_type", None) in ("group", None) - consolidated_metadata = data.pop("consolidated_metadata", None) - if consolidated_metadata: - data["consolidated_metadata"] = ConsolidatedMetadata.from_dict(consolidated_metadata) - - zarr_format = data.get("zarr_format") - if zarr_format == 2 or zarr_format is None: - # zarr v2 allowed arbitrary keys here. - # We don't want the GroupMetadata constructor to fail just because someone put an - # extra key in the metadata. - expected = {x.name for x in fields(cls)} - data = {k: v for k, v in data.items() if k in expected} - - return cls(**data) + def from_dict(cls, data: GroupMetadataJSON_V2 | GroupMetadataJSON_V3) -> GroupMetadata: # type: ignore[override] + """ + Create an instance of GroupMetadata from a dict model of Zarr group metadata. 
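+ + For example, ``GroupMetadata.from_dict({"zarr_format": 3, "node_type": "group"})`` returns + a group metadata object with empty attributes and no consolidated metadata.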
+ """ + if "consolidated_metadata" in data and data["consolidated_metadata"] is not None: + consolidated_metadata = ConsolidatedMetadata.from_dict(data["consolidated_metadata"]) + else: + consolidated_metadata = None + zarr_format = data["zarr_format"] + attributes = data.get("attributes", {}) + return cls( + attributes=attributes, # type: ignore[arg-type] + zarr_format=zarr_format, + consolidated_metadata=consolidated_metadata, + ) def to_dict(self) -> dict[str, Any]: result = asdict(replace(self, consolidated_metadata=None)) @@ -674,7 +679,7 @@ def from_dict( data: dict[str, Any], ) -> AsyncGroup: return cls( - metadata=GroupMetadata.from_dict(data), + metadata=GroupMetadata.from_dict(data), # type: ignore[arg-type] store_path=store_path, ) @@ -3557,9 +3562,9 @@ def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMet raise MetadataValidationError(msg) match zarr_json: case {"node_type": "array"}: - return ArrayV3Metadata.from_dict(zarr_json) + return ArrayV3Metadata.from_dict(zarr_json) # type: ignore[arg-type] case {"node_type": "group"}: - return GroupMetadata.from_dict(zarr_json) + return GroupMetadata.from_dict(zarr_json) # type: ignore[arg-type] case _: # pragma: no cover raise ValueError( "invalid value for `node_type` key in metadata document" @@ -3576,7 +3581,7 @@ def _build_metadata_v2( case {"shape": _}: return ArrayV2Metadata.from_dict(zarr_json | {"attributes": attrs_json}) case _: # pragma: no cover - return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json}) + return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json}) # type: ignore[arg-type] @overload diff --git a/src/zarr/core/metadata/__init__.py b/src/zarr/core/metadata/__init__.py index 43b5ec98fe..4975eacf96 100644 --- a/src/zarr/core/metadata/__init__.py +++ b/src/zarr/core/metadata/__init__.py @@ -1,17 +1,13 @@ from typing import TypeAlias, TypeVar -from .v2 import ArrayV2Metadata, ArrayV2MetadataDict -from .v3 import ArrayV3Metadata, ArrayV3MetadataDict +from .v2 import ArrayV2Metadata +from .v3 import ArrayV3Metadata ArrayMetadata: TypeAlias = ArrayV2Metadata | ArrayV3Metadata -ArrayMetadataDict: TypeAlias = ArrayV2MetadataDict | ArrayV3MetadataDict T_ArrayMetadata = TypeVar("T_ArrayMetadata", ArrayV2Metadata, ArrayV3Metadata) __all__ = [ "ArrayMetadata", - "ArrayMetadataDict", "ArrayV2Metadata", - "ArrayV2MetadataDict", "ArrayV3Metadata", - "ArrayV3MetadataDict", ] diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 3204543426..7f108d2680 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -10,6 +10,7 @@ from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.dtype import get_data_type_from_json from zarr.core.dtype.common import OBJECT_CODEC_IDS, DTypeSpec_V2 +from zarr.core.type_check import guard_type from zarr.errors import ZarrUserWarning from zarr.registry import get_numcodec @@ -38,6 +39,7 @@ JSON, ZARRAY_JSON, ZATTRS_JSON, + CodecJSON_V2, MemoryOrder, parse_shapelike, ) @@ -273,8 +275,8 @@ def parse_filters(data: object) -> tuple[Numcodec, ...] | None: for idx, val in enumerate(data): if _is_numcodec(val): out.append(val) - elif isinstance(val, dict): - out.append(get_numcodec(val)) # type: ignore[arg-type] + elif guard_type(val, CodecJSON_V2[str]): + out.append(get_numcodec(val)) else: msg = f"Invalid filter at index {idx}. Expected a numcodecs.abc.Codec or a dict representation of numcodecs.abc.Codec. Got {type(val)} instead." 
raise TypeError(msg) @@ -296,8 +298,8 @@ def parse_compressor(data: object) -> Numcodec | None: """ if data is None or _is_numcodec(data): return data - if isinstance(data, dict): - return get_numcodec(data) # type: ignore[arg-type] + if guard_type(data, CodecJSON_V2[str]): + return get_numcodec(data) msg = f"Invalid compressor. Expected None, a numcodecs.abc.Codec, or a dict representation of a numcodecs.abc.Codec. Got {type(data)} instead." raise ValueError(msg) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 649a490409..4903dded91 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, TypedDict +from collections.abc import Sequence +from typing import TYPE_CHECKING from zarr.abc.metadata import Metadata from zarr.core.buffer.core import default_buffer_prototype from zarr.core.dtype import VariableLengthUTF8, ZDType, get_data_type_from_json -from zarr.core.dtype.common import check_dtype_spec_v3 +from zarr.errors import UnknownCodecError if TYPE_CHECKING: from typing import Self @@ -28,30 +29,16 @@ from zarr.core.common import ( JSON, ZARR_JSON, + ArrayMetadataJSON_V3, DimensionNames, parse_named_configuration, parse_shapelike, ) from zarr.core.config import config from zarr.core.metadata.common import parse_attributes -from zarr.errors import MetadataValidationError, NodeTypeValidationError, UnknownCodecError from zarr.registry import get_codec_class -def parse_zarr_format(data: object) -> Literal[3]: - if data == 3: - return 3 - msg = f"Invalid value for 'zarr_format'. Expected '3'. Got '{data}'." - raise MetadataValidationError(msg) - - -def parse_node_type_array(data: object) -> Literal["array"]: - if data == "array": - return "array" - msg = f"Invalid value for 'node_type'. Expected 'array'. Got '{data}'." - raise NodeTypeValidationError(msg) - - def parse_codecs(data: object) -> tuple[Codec, ...]: out: tuple[Codec, ...] = () @@ -74,6 +61,35 @@ def parse_codecs(data: object) -> tuple[Codec, ...]: return out +def parse_dimension_names(data: DimensionNames) -> tuple[str | None, ...] | None: + if data is None: + return None + return tuple(data) + + +def parse_storage_transformers(data: object) -> tuple[dict[str, JSON], ...]: + """ + Parse storage_transformers. Zarr Python does not support storage transformers + at this time, so any non-empty value is rejected. + """ + + if data is None: + return () + + if isinstance(data, Sequence): + if len(data) > 0: + msg = ( + f"Array metadata contains storage transformers: {data}. " + "Arrays with storage transformers are not supported in zarr-python at this time." + ) + raise ValueError(msg) + else: + return () + raise TypeError( + f"Invalid storage_transformers. Expected an iterable of dicts. Got {type(data)} instead." + ) + + def validate_array_bytes_codec(codecs: tuple[Codec, ...]) -> ArrayBytesCodec: # ensure that we have at least one ArrayBytesCodec abcs: list[ArrayBytesCodec] = [codec for codec in codecs if isinstance(codec, ArrayBytesCodec)] @@ -105,42 +121,6 @@ def validate_codecs(codecs: tuple[Codec, ...], dtype: ZDType[TBaseDType, TBaseSc ) -def parse_dimension_names(data: object) -> tuple[str | None, ...] 
| None: - if data is None: - return data - elif isinstance(data, Iterable) and all(isinstance(x, type(None) | str) for x in data): - return tuple(data) - else: - msg = f"Expected either None or a iterable of str, got {type(data)}" - raise TypeError(msg) - - -def parse_storage_transformers(data: object) -> tuple[dict[str, JSON], ...]: - """ - Parse storage_transformers. Zarr python cannot use storage transformers - at this time, so this function doesn't attempt to validate them. - """ - if data is None: - return () - if isinstance(data, Iterable): - if len(tuple(data)) >= 1: - return data # type: ignore[return-value] - else: - return () - raise TypeError( - f"Invalid storage_transformers. Expected an iterable of dicts. Got {type(data)} instead." - ) - - -class ArrayV3MetadataDict(TypedDict): - """ - A typed dictionary model for zarr v3 metadata. - """ - - zarr_format: Literal[3] - attributes: dict[str, JSON] - - @dataclass(frozen=True, kw_only=True) class ArrayV3Metadata(Metadata): shape: tuple[int, ...] @@ -297,43 +277,45 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: ) } + # This type annotation violates the Liskov substitution principle; that's a problem we need to fix at the base class @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: - # make a copy because we are modifying the dict - _data = data.copy() - - # check that the zarr_format attribute is correct - _ = parse_zarr_format(_data.pop("zarr_format")) - # check that the node_type attribute is correct - _ = parse_node_type_array(_data.pop("node_type")) - - data_type_json = _data.pop("data_type") - if not check_dtype_spec_v3(data_type_json): - raise ValueError(f"Invalid data_type: {data_type_json!r}") + def from_dict(cls, data: ArrayMetadataJSON_V3) -> Self: # type: ignore[override] + data_type_json = data["data_type"] data_type = get_data_type_from_json(data_type_json, zarr_format=3) # check that the fill value is consistent with the data type + fill = data["fill_value"] try: - fill = _data.pop("fill_value") - fill_value_parsed = data_type.from_json_scalar(fill, zarr_format=3) + fill_value_parsed = data_type.from_json_scalar(fill, zarr_format=3) # type: ignore[arg-type] except ValueError as e: raise TypeError(f"Invalid fill_value: {fill!r}") from e # dimension_names key is optional, normalize missing to `None` - _data["dimension_names"] = _data.pop("dimension_names", None) + dimension_names = data.get("dimension_names", None) # attributes key is optional, normalize missing to `None` - _data["attributes"] = _data.pop("attributes", None) - - return cls(**_data, fill_value=fill_value_parsed, data_type=data_type) # type: ignore[arg-type] + attributes = data.get("attributes", None) + + # storage transformers key is optional, normalize missing to `None` + storage_transformers = data.get("storage_transformers", None) + + return cls( + shape=data["shape"], + chunk_grid=data["chunk_grid"], # type: ignore[arg-type] + chunk_key_encoding=data["chunk_key_encoding"], # type: ignore[arg-type] + codecs=data["codecs"], # type: ignore[arg-type] + attributes=attributes, # type: ignore[arg-type] + data_type=data_type, + fill_value=fill_value_parsed, + dimension_names=dimension_names, + storage_transformers=storage_transformers, # type: ignore[arg-type] + ) - def to_dict(self) -> dict[str, JSON]: + def to_dict(self) -> ArrayMetadataJSON_V3: # type: ignore[override] out_dict = super().to_dict() out_dict["fill_value"] = self.data_type.to_json_scalar( self.fill_value, zarr_format=self.zarr_format ) - if not isinstance(out_dict, 
dict): - raise TypeError(f"Expected dict. Got {type(out_dict)}.") # if `dimension_names` is `None`, we do not include it in # the metadata document @@ -348,7 +330,7 @@ def to_dict(self) -> dict[str, JSON]: if isinstance(dtype_meta, ZDType): out_dict["data_type"] = dtype_meta.to_json(zarr_format=3) # type: ignore[unreachable] - return out_dict + return out_dict # type: ignore[return-value] def update_shape(self, shape: tuple[int, ...]) -> Self: return replace(self, shape=shape) diff --git a/src/zarr/core/type_check.py b/src/zarr/core/type_check.py new file mode 100644 index 0000000000..8587594cba --- /dev/null +++ b/src/zarr/core/type_check.py @@ -0,0 +1,504 @@ +import collections +import collections.abc +import sys +import types +import typing +from dataclasses import dataclass +from typing import ( + Any, + ForwardRef, + Literal, + NotRequired, + TypeGuard, + TypeVar, + cast, + get_args, + get_origin, + get_type_hints, +) + +from typing_extensions import ReadOnly, evaluate_forward_ref + + +@dataclass(frozen=True) +class TypeCheckResult: + """ + Result of a type-checking operation. + """ + + success: bool + errors: list[str] + + +# ---------- helpers ---------- +def _type_name(tp: Any) -> str: + """Get a readable name for a type hint.""" + if isinstance(tp, type): + return tp.__name__ + return str(tp) + + +def _is_typeddict_class(tp: object) -> bool: + """ + Check if a type is a TypedDict class. + """ + return isinstance(tp, type) and hasattr(tp, "__annotations__") and hasattr(tp, "__total__") + + +def _find_generic_typeddict_base(cls: type) -> tuple[type | None, tuple[Any, ...] | None]: + """ + Find the base class of a generic TypedDict class. + + This is necessary because the `__origin__` of a TypedDict is always `dict` + and the `__args__` is always `(, )`. The actual base class is stored in + `__orig_bases__`. + + Returns a tuple of `(base, args)` where `base` is the base class and `args` + is a tuple of arguments to the base class (i.e. the key and value types of + the TypedDict). + + Returns `(None, None)` if no base class is found. + """ + for base in getattr(cls, "__orig_bases__", ()): + origin = get_origin(base) + if origin is None: + continue + if isinstance(origin, type) and hasattr(origin, "__annotations__"): + return origin, get_args(base) + return None, None + + +def _resolve_type( + tp: Any, + type_map: dict[TypeVar, Any] | None = None, + globalns: dict[str, Any] | None = None, + localns: dict[str, Any] | None = None, + _seen: set[Any] | None = None, +) -> Any: + """ + Resolve type hints and ForwardRef. Maintains a cache of resolved types to avoid infinite recursion. + """ + if _seen is None: + _seen = set() + + # Use a more robust tracking mechanism + type_repr = repr(tp) + if type_repr in _seen: + # Return Any for recursive types to break the cycle + return Any + + _seen.add(type_repr) + + try: + return _resolve_type_impl(tp, type_map, globalns, localns, _seen) + finally: + _seen.discard(type_repr) + + +def _resolve_type_impl( + tp: Any, + type_map: dict[TypeVar, Any] | None, + globalns: dict[str, Any] | None, + localns: dict[str, Any] | None, + _seen: set[str], +) -> Any: + """ + Internal implementation of type resolution. 
+ """ + # Substitute TypeVar + if isinstance(tp, TypeVar): + resolved = type_map.get(tp, tp) if type_map else tp + if isinstance(resolved, TypeVar) and resolved is tp: + return tp # <-- keep literal TypeVar until check + return _resolve_type(resolved, type_map, globalns, localns, _seen) + + # Handle string-based unions safely + if isinstance(tp, str) and " | " in tp: + parts = [p.strip() for p in tp.split("|")] + resolved_parts = tuple( + _resolve_type( + eval(p, globalns or {}, localns or {}), type_map, globalns, localns, _seen + ) + for p in parts + ) + return typing.Union[resolved_parts] # noqa: UP007 + + # Evaluate ForwardRef + if isinstance(tp, (ForwardRef, str)): + ref = tp if isinstance(tp, ForwardRef) else ForwardRef(tp) + # Use frozenset to avoid issues with mutable default arguments + tp = evaluate_forward_ref(ref, globals=globalns, locals=localns) + + # Recurse into Literal + origin = get_origin(tp) + args = get_args(tp) + if origin is Literal: + # Pass literal arguments through as-is, they are values, not types to resolve. + return Literal.__getitem__(args) + + # Handle types.UnionType (Python 3.10+ union syntax like str | int) + if isinstance(tp, types.UnionType): + # Don't try to reconstruct UnionType, convert to typing.Union + resolved_args = tuple(_resolve_type(a, type_map, globalns, localns, _seen) for a in args) + return typing.Union[resolved_args] # noqa: UP007 + + # Recurse into other generics + if origin and args: + new_args = tuple(_resolve_type(a, type_map, globalns, localns, _seen) for a in args) + # Special handling for single-argument generics like NotRequired, ReadOnly + if len(new_args) == 1: + return origin[new_args[0]] # Pass single argument, not tuple + else: + return origin[new_args] # Pass tuple for multi-argument generics + + return tp + + +def check_type( + obj: Any, expected_type: type | types.UnionType | ForwardRef | None, path: str = "value" +) -> TypeCheckResult: + """ + Check if `obj` is of type `expected_type`. 
+ """ + origin = get_origin(expected_type) + + if origin in (NotRequired, ReadOnly): + args = get_args(expected_type) + inner_type = args[0] if args else Any + return check_type(obj, inner_type, path) + + if expected_type is Any: + return TypeCheckResult(True, []) + + if origin in (typing.Union, types.UnionType): + return check_union(obj, expected_type, path) + + if origin is typing.Literal: + return check_literal(obj, expected_type, path) + + if expected_type is None or expected_type is type(None): + return check_none(obj, path) + + # Check for TypedDict (now unified) + if (origin and _is_typeddict_class(origin)) or _is_typeddict_class(expected_type): + return check_typeddict(obj, expected_type, path) + + if origin is tuple: + return check_tuple(obj, expected_type, path) + + if origin in (collections.abc.Sequence, list): + return check_sequence_or_list(obj, expected_type, path) + + if origin in (dict, typing.Mapping, collections.abc.Mapping) or expected_type in ( + dict, + typing.Mapping, + collections.abc.Mapping, + ): + return check_mapping(obj, expected_type, path) + + if expected_type is int: + return check_int(obj, path) + + if expected_type in (float, str, bool): + return check_primitive(obj, expected_type, path) # type: ignore[arg-type] + + # Fallback + try: + if isinstance(obj, expected_type): # type: ignore[arg-type] + return TypeCheckResult(True, []) + tn = _type_name(expected_type) + return TypeCheckResult(False, [f"{path} expected {tn} but got {type(obj).__name__}"]) + except TypeError: + return TypeCheckResult(False, [f"{path} cannot be checked against {expected_type}"]) + + +T = TypeVar("T") + + +def ensure_type(obj: object, expected_type: type[T], path: str = "value") -> T: + """ + Check if obj is assignable to expected type. If so, return obj. Otherwise a TypeError is raised. + """ + if check_type(obj, expected_type, path).success: + return cast(T, obj) + raise TypeError( + f"Expected an instance of {expected_type} but got {obj!r} with type {type(obj)}" + ) + + +def guard_type(obj: object, expected_type: type[T], path: str = "value") -> TypeGuard[T]: + """ + A type guard function that checks if obj is assignable to expected type. + """ + return check_type(obj, expected_type, path).success + + +def check_typeddict( + obj: Any, + td_type: Any, + path: str, +) -> TypeCheckResult: + """ + Check if an object matches a TypedDict, handling both generic + and non-generic cases. + + This function determines if the provided TypedDict is a generic + with parameters (e.g., MyTD[str]) or a regular class, and then + performs a unified validation check. 
+ """ + if not isinstance(obj, dict): + return TypeCheckResult( + False, [f"{path} expected dict for TypedDict but got {type(obj).__name__}"] + ) + + # --- Now get the metadata in a single, unified step --- + td_cls, type_map, globalns, localns = _get_typeddict_metadata(td_type) + + if td_cls is None: + # Fallback if it's not a TypedDict type at all + return TypeCheckResult(False, [f"{path} expected a TypedDict but got {td_type!r}"]) + + if type_map is not None and len(getattr(td_cls, "__parameters__", ())) != len( + get_args(td_type) + ): + return TypeCheckResult(False, [f"{path} type parameter count mismatch"]) + + if type_map is None and len(get_args(td_type)) > 0: + base_origin, base_args = _find_generic_typeddict_base(td_cls) + if ( + base_origin is not None + and base_args is not None + and len(getattr(base_origin, "__parameters__", ())) != len(base_args) + ): + return TypeCheckResult(False, [f"{path} type parameter count mismatch in generic base"]) + + # --- Now call the shared validation logic --- + errors = _validate_typeddict_fields(obj, td_cls, type_map, globalns, localns, path) + + return TypeCheckResult(not errors, errors) + + +def check_mapping(obj: Any, expected_type: Any, path: str) -> TypeCheckResult: + """ + Check if an object is assignable to a mapping type. + """ + if not isinstance(obj, collections.abc.Mapping): + return TypeCheckResult( + False, [f"{path} expected collections.abc.Mapping but got {type(obj).__name__}"] + ) + args = get_args(expected_type) + key_t = args[0] if args else Any + val_t = args[1] if len(args) > 1 else Any + errors: list[str] = [] + for k, v in obj.items(): + rk = check_type(k, key_t, f"{path}[key {k!r}]") + rv = check_type(v, val_t, f"{path}[{k!r}]") + if not rk.success: + errors.extend(rk.errors) + if not rv.success: + errors.extend(rv.errors) + return TypeCheckResult(len(errors) == 0, errors) + + +def check_sequence_or_list(obj: Any, expected_type: Any, path: str) -> TypeCheckResult: + """ + Check if an object is assignable to a sequence or list type. + """ + args = get_args(expected_type) + if not isinstance(obj, typing.Sequence | collections.abc.Sequence) or isinstance( + obj, (str, bytes) + ): + return TypeCheckResult(False, [f"{path} expected sequence but got {type(obj).__name__}"]) + elem_type = args[0] if args else Any + errors: list[str] = [] + for i, item in enumerate(obj): + res = check_type(item, elem_type, f"{path}[{i}]") + if not res.success: + errors.extend(res.errors) + return TypeCheckResult(len(errors) == 0, errors) + + +def check_union(obj: Any, expected_type: Any, path: str) -> TypeCheckResult: + """ + Check if an object is assignable to a union type. + """ + args = get_args(expected_type) + errors: list[str] = [] + for arg in args: + res = check_type(obj, arg, path) + if res.success: + return TypeCheckResult(True, []) + errors.extend(res.errors) + return TypeCheckResult(False, errors or [f"{path} did not match any type in {expected_type}"]) + + +def check_tuple(obj: Any, expected_type: Any, path: str) -> TypeCheckResult: + """ + Check if an object is assignable to a tuple type. + """ + if not isinstance(obj, tuple): + return TypeCheckResult(False, [f"{path} expected tuple but got {type(obj).__name__}"]) + args = get_args(expected_type) + targs = args + errors: list[str] = [] + + # Variadic tuple like tuple[int, ...] 
+    if len(targs) == 2 and targs[1] is Ellipsis:
+        elem_t = targs[0]
+        for i, item in enumerate(obj):
+            res = check_type(item, elem_t, f"{path}[{i}]")
+            if not res.success:
+                errors.extend(res.errors)
+        return TypeCheckResult(len(errors) == 0, errors)
+
+    # Fixed-length tuple like tuple[int, str, None]
+    if len(obj) != len(targs):
+        return TypeCheckResult(
+            False, [f"{path} expected tuple of length {len(targs)} but got {len(obj)}"]
+        )
+    for i, (item, tp) in enumerate(zip(obj, targs, strict=True)):
+        expected = type(None) if tp is None else tp
+        res = check_type(item, expected, f"{path}[{i}]")
+        if not res.success:
+            errors.extend(res.errors)
+    return TypeCheckResult(len(errors) == 0, errors)
+
+
+def check_literal(obj: object, expected_type: Any, path: str) -> TypeCheckResult:
+    """
+    Check if an object is assignable to a literal type.
+    """
+    allowed = get_args(expected_type)
+    for value in allowed:
+        # Compare by type as well as value so that e.g. True does not
+        # satisfy Literal[1] (True == 1 in Python, but the literals differ)
+        if obj is value or (type(obj) is type(value) and obj == value):
+            return TypeCheckResult(True, [])
+    msg = f"{path} expected literal in {allowed} but got {obj!r}"
+    return TypeCheckResult(False, [msg])
+
+
+def check_none(obj: object, path: str) -> TypeCheckResult:
+    """
+    Check if an object is None.
+    """
+    if obj is None:
+        return TypeCheckResult(True, [])
+    msg = f"{path} expected None but got {obj!r}"
+    return TypeCheckResult(False, [msg])
+
+
+def check_primitive(obj: object, expected_type: type, path: str) -> TypeCheckResult:
+    """
+    Check if an object is an instance of a primitive type, i.e. one that can be
+    checked directly with isinstance().
+    """
+    if isinstance(obj, expected_type):
+        return TypeCheckResult(True, [])
+    msg = f"{path} expected an instance of {expected_type} but got {obj!r} with type {type(obj)}"
+    return TypeCheckResult(False, [msg])
+
+
+def check_int(obj: object, path: str) -> TypeCheckResult:
+    """
+    Check if an object is an int.
+    """
+    if isinstance(obj, int) and not isinstance(obj, bool):  # bool is a subclass of int
+        return TypeCheckResult(True, [])
+    msg = f"{path} expected int but got {obj!r} with type {type(obj)}"
+    return TypeCheckResult(False, [msg])
+
+
+def _get_typeddict_metadata(
+    td_type: Any,
+) -> tuple[
+    type | None,
+    dict[TypeVar, Any] | None,
+    dict[str, Any] | None,
+    dict[str, Any] | None,
+]:
+    """
+    Extract the TypedDict class, type variable map, and resolution namespaces.
+    """
+    origin = get_origin(td_type)
+
+    if origin and _is_typeddict_class(origin):
+        td_cls = origin
+        args = get_args(td_type)
+        tvars = getattr(td_cls, "__parameters__", ())
+        type_map = dict(zip(tvars, args, strict=False))
+
+        # Resolve names against the defining module's globals and the class namespace
+        mod = sys.modules.get(td_cls.__module__)
+        globalns = vars(mod) if mod else {}
+        localns = dict(vars(td_cls))
+
+        return td_cls, type_map, globalns, localns
+
+    elif _is_typeddict_class(td_type):
+        td_cls = td_type
+        base_origin, base_args = _find_generic_typeddict_base(td_cls)
+        if base_origin is not None:
+            tvars = getattr(base_origin, "__parameters__", ())
+            type_map = dict(zip(tvars, base_args, strict=False))  # type: ignore[arg-type]
+
+            mod = sys.modules.get(base_origin.__module__)
+            globalns = vars(mod) if mod else {}
+            localns = dict(vars(base_origin))
+        else:
+            type_map = None
+            mod = sys.modules.get(td_cls.__module__)
+            globalns = vars(mod) if mod else {}
+            localns = dict(vars(td_cls))
+
+        return td_cls, type_map, globalns, localns
+
+    return None, None, None, None
+
+
+def _validate_typeddict_fields(
+    obj: Any,
+    td_cls: type,
+    type_map: dict[TypeVar, Any] | None,
+    globalns: dict[str, Any] | None,
+    localns: dict[str, Any] | None,
+    path: str,
+) -> list[str]:
+    """
+    Validate the fields of a dictionary against a TypedDict's annotations.
+    """
+    annotations = get_type_hints(td_cls, globalns=globalns, localns=localns, include_extras=True)
+    errors: list[str] = []
+    is_total_false = getattr(td_cls, "__total__", True) is False
+    for key, typ in annotations.items():
+        # Handle keys that are absent from the object
+        if key not in obj:
+            # If total=False, all fields are optional unless explicitly Required
+            if is_total_false:
+                continue
+
+            # Walk the chain of parametrized types looking for a NotRequired;
+            # at each level of nesting only the first type parameter matters.
+ is_optional = False + if get_origin(typ) == NotRequired: + is_optional = True + else: + sub_args = get_args(typ) + while len(sub_args) > 0: + if get_origin(sub_args[0]) == NotRequired: + is_optional = True + break + sub_args = get_args(sub_args[0]) + + if not is_optional: + errors.append(f"{path} missing required key '{key}'") + continue + + # we have to further resolve this type because get_type_hints does not resolve + # generic aliases + resolved_typ = _resolve_type(typ, type_map, globalns=globalns, localns=localns) + res = check_type(obj[key], resolved_typ, f"{path}['{key}']") + if not res.success: + errors.extend(res.errors) + + # We allow extra keys of any type right now + # when PEP 728 is done, then we can refine this and do a type check on the keys + # errors.extend([f"{path} has unexpected key '{key}'" for key in obj if key not in annotations]) + + return errors diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 46216205f7..cff67fe013 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -17,12 +17,11 @@ ArrayBytesCodec, BytesBytesCodec, Codec, - CodecJSON_V2, CodecPipeline, ) from zarr.abc.numcodec import Numcodec from zarr.core.buffer import Buffer, NDBuffer - from zarr.core.common import JSON + from zarr.core.common import JSON, CodecJSON_V2 __all__ = [ "Registry", diff --git a/tests/test_abc/test_codec.py b/tests/test_abc/test_codec.py index e0f9ddb7bb..e69de29bb2 100644 --- a/tests/test_abc/test_codec.py +++ b/tests/test_abc/test_codec.py @@ -1,12 +0,0 @@ -from __future__ import annotations - -from zarr.abc.codec import _check_codecjson_v2 - - -def test_check_codecjson_v2_valid() -> None: - """ - Test that the _check_codecjson_v2 function works - """ - assert _check_codecjson_v2({"id": "gzip"}) - assert not _check_codecjson_v2({"id": 10}) - assert not _check_codecjson_v2([10, 11]) diff --git a/tests/test_array.py b/tests/test_array.py index 97aef9319b..f9995abeb7 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -46,7 +46,7 @@ from zarr.core.buffer import NDArrayLike, NDArrayLikeOrScalar, default_buffer_prototype from zarr.core.chunk_grids import _auto_partition from zarr.core.chunk_key_encodings import ChunkKeyEncodingParams -from zarr.core.common import JSON, ZarrFormat, ceildiv +from zarr.core.common import JSON, CodecJSON_V3, ZarrFormat, ceildiv from zarr.core.dtype import ( DateTime64, Float32, @@ -78,7 +78,6 @@ from .test_dtype.conftest import zdtype_examples if TYPE_CHECKING: - from zarr.abc.codec import CodecJSON_V3 from zarr.core.metadata.v3 import ArrayV3Metadata @@ -335,7 +334,7 @@ def test_storage_transformers(store: MemoryStore, zarr_format: ZarrFormat | str) "chunk_key_encoding": {"name": "v2", "configuration": {"separator": "/"}}, "codecs": (BytesCodec().to_dict(),), "fill_value": 0, - "storage_transformers": ({"test": "should_raise"}), + "storage_transformers": ({"name": "should_raise"},), } else: metadata_dict = { @@ -347,7 +346,7 @@ def test_storage_transformers(store: MemoryStore, zarr_format: ZarrFormat | str) "codecs": (BytesCodec().to_dict(),), "fill_value": 0, "order": "C", - "storage_transformers": ({"test": "should_raise"}), + "storage_transformers": ({"name": "should_raise"},), } if zarr_format == 3: match = "Arrays with storage transformers are not supported in zarr-python at this time." 
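For orientation, a minimal usage sketch of the checker added above (the ChunkGridJSON TypedDict is hypothetical, invented only for this example; check_type, ensure_type, and guard_type are the entry points defined in the new zarr.core.type_check module):

    from typing import Literal, NotRequired

    from typing_extensions import TypedDict

    from zarr.core.type_check import check_type, ensure_type, guard_type


    class ChunkGridJSON(TypedDict):
        # Hypothetical metadata document type, defined only for this example
        name: Literal["regular"]
        configuration: NotRequired[dict[str, object]]


    doc: object = {"name": "regular", "configuration": {"chunk_shape": [5, 5]}}

    # check_type accumulates per-field error paths instead of raising
    result = check_type(doc, ChunkGridJSON)
    assert result.success, result.errors

    # guard_type narrows `doc` for static type checkers; ensure_type returns
    # the value unchanged on success and raises TypeError on mismatch
    if guard_type(doc, ChunkGridJSON):
        print(doc["name"])
    validated = ensure_type(doc, ChunkGridJSON)

Because check_type collects errors rather than raising, the TypedDict and union checks above can report every offending field path in a single pass.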
diff --git a/tests/test_dtype/test_wrapper.py b/tests/test_dtype/test_wrapper.py index cc365e86d4..2ebb435ccc 100644 --- a/tests/test_dtype/test_wrapper.py +++ b/tests/test_dtype/test_wrapper.py @@ -5,9 +5,10 @@ import pytest -from zarr.core.dtype.common import DTypeSpec_V2, DTypeSpec_V3, HasItemSize +from zarr.core.dtype.common import DTypeSpec_V2, HasItemSize if TYPE_CHECKING: + from zarr.core.common import DTypeSpec_V3 from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType diff --git a/tests/test_metadata/test_consolidated.py b/tests/test_metadata/test_consolidated.py index 9e8b763ef7..8f6de82d07 100644 --- a/tests/test_metadata/test_consolidated.py +++ b/tests/test_metadata/test_consolidated.py @@ -105,7 +105,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None: "configuration": {"chunk_shape": (1, 2, 3)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "lat": ArrayV3Metadata.from_dict( @@ -115,7 +115,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None: "configuration": {"chunk_shape": (1,)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "lon": ArrayV3Metadata.from_dict( @@ -125,7 +125,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None: "configuration": {"chunk_shape": (2,)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "time": ArrayV3Metadata.from_dict( @@ -135,7 +135,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None: "configuration": {"chunk_shape": (3,)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "child": GroupMetadata( @@ -144,7 +144,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None: metadata={ "array": ArrayV3Metadata.from_dict( { - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] "attributes": {"key": "child"}, "shape": (4, 4), "chunk_grid": { @@ -166,7 +166,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None: ), "array": ArrayV3Metadata.from_dict( { - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] "attributes": {"key": "grandchild"}, "shape": (4, 4), "chunk_grid": { @@ -256,7 +256,7 @@ def test_consolidated_sync(self, memory_store: Store) -> None: "configuration": {"chunk_shape": (1, 2, 3)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "lat": ArrayV3Metadata.from_dict( @@ -266,7 +266,7 @@ def test_consolidated_sync(self, memory_store: Store) -> None: "configuration": {"chunk_shape": (1,)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "lon": ArrayV3Metadata.from_dict( @@ -276,7 +276,7 @@ def test_consolidated_sync(self, memory_store: Store) -> None: "configuration": {"chunk_shape": (2,)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "time": ArrayV3Metadata.from_dict( @@ -286,7 +286,7 @@ def test_consolidated_sync(self, memory_store: Store) -> None: "configuration": {"chunk_shape": (3,)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), }, @@ -322,25 +322,21 @@ async def test_non_root_node(self, memory_store_with_hierarchy: Store) -> None: def test_consolidated_metadata_from_dict(self) -> None: data: dict[str, JSON] = 
{"must_understand": False} - # missing kind - with pytest.raises(ValueError, match="kind='None'"): - ConsolidatedMetadata.from_dict(data) - # invalid kind data["kind"] = "invalid" with pytest.raises(ValueError, match="kind='invalid'"): - ConsolidatedMetadata.from_dict(data) + ConsolidatedMetadata.from_dict(data) # type: ignore[arg-type] # missing metadata data["kind"] = "inline" with pytest.raises(TypeError, match="Unexpected type for 'metadata'"): - ConsolidatedMetadata.from_dict(data) + ConsolidatedMetadata.from_dict(data) # type: ignore[arg-type] data["kind"] = "inline" # empty is fine data["metadata"] = {} - ConsolidatedMetadata.from_dict(data) + ConsolidatedMetadata.from_dict(data) # type: ignore[arg-type] def test_flatten(self) -> None: array_metadata: dict[str, Any] = { @@ -368,7 +364,7 @@ def test_flatten(self) -> None: "configuration": {"chunk_shape": (1, 2, 3)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "lat": ArrayV3Metadata.from_dict( @@ -378,7 +374,7 @@ def test_flatten(self) -> None: "configuration": {"chunk_shape": (1,)}, "name": "regular", }, - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] } ), "child": GroupMetadata( @@ -387,7 +383,7 @@ def test_flatten(self) -> None: metadata={ "array": ArrayV3Metadata.from_dict( { - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] "attributes": {"key": "child"}, "shape": (4, 4), "chunk_grid": { @@ -402,7 +398,7 @@ def test_flatten(self) -> None: metadata={ "array": ArrayV3Metadata.from_dict( { - **array_metadata, + **array_metadata, # type: ignore[typeddict-item] "attributes": {"key": "grandchild"}, "shape": (4, 4), "chunk_grid": { @@ -450,7 +446,7 @@ def test_invalid_metadata_raises(self) -> None: } with pytest.raises(TypeError, match="key='foo', type='list'"): - ConsolidatedMetadata.from_dict(payload) + ConsolidatedMetadata.from_dict(payload) # type: ignore[arg-type] def test_to_dict_empty(self) -> None: meta = ConsolidatedMetadata( diff --git a/tests/test_metadata/test_v3.py b/tests/test_metadata/test_v3.py index 1405bf533b..3ab1cc4d39 100644 --- a/tests/test_metadata/test_v3.py +++ b/tests/test_metadata/test_v3.py @@ -2,39 +2,31 @@ import json import re -from typing import TYPE_CHECKING, Literal +from collections.abc import Mapping +from typing import TYPE_CHECKING, Literal, cast import numpy as np import pytest -from zarr.codecs.bytes import BytesCodec from zarr.core.buffer import default_buffer_prototype -from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding +from zarr.core.common import NamedConfig from zarr.core.config import config from zarr.core.dtype import get_data_type_from_native_dtype from zarr.core.dtype.npy.string import _NUMPY_SUPPORTS_VLEN_STRING from zarr.core.dtype.npy.time import DateTime64 -from zarr.core.group import GroupMetadata, parse_node_type +from zarr.core.group import GroupMetadata from zarr.core.metadata.v3 import ( ArrayV3Metadata, parse_codecs, - parse_dimension_names, - parse_zarr_format, ) -from zarr.errors import MetadataValidationError, NodeTypeValidationError, UnknownCodecError +from zarr.errors import UnknownCodecError if TYPE_CHECKING: - from collections.abc import Sequence from typing import Any - from zarr.abc.codec import Codec - from zarr.core.common import JSON + from zarr.core.common import JSON, ArrayMetadataJSON_V3 -from zarr.core.metadata.v3 import ( - parse_node_type_array, -) - bool_dtypes = ("bool",) int_dtypes = ( @@ -71,57 +63,6 @@ ) 
-@pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"]) -def test_parse_zarr_format_invalid(data: Any) -> None: - with pytest.raises( - MetadataValidationError, - match=f"Invalid value for 'zarr_format'. Expected '3'. Got '{data}'.", - ): - parse_zarr_format(data) - - -def test_parse_zarr_format_valid() -> None: - assert parse_zarr_format(3) == 3 - - -def test_parse_node_type_valid() -> None: - assert parse_node_type("array") == "array" - assert parse_node_type("group") == "group" - - -@pytest.mark.parametrize("node_type", [None, 2, "other"]) -def test_parse_node_type_invalid(node_type: Any) -> None: - with pytest.raises( - MetadataValidationError, - match=f"Invalid value for 'node_type'. Expected 'array' or 'group'. Got '{node_type}'.", - ): - parse_node_type(node_type) - - -@pytest.mark.parametrize("data", [None, "group"]) -def test_parse_node_type_array_invalid(data: Any) -> None: - with pytest.raises( - NodeTypeValidationError, - match=f"Invalid value for 'node_type'. Expected 'array'. Got '{data}'.", - ): - parse_node_type_array(data) - - -def test_parse_node_typev_array_alid() -> None: - assert parse_node_type_array("array") == "array" - - -@pytest.mark.parametrize("data", [(), [1, 2, "a"], {"foo": 10}]) -def parse_dimension_names_invalid(data: Any) -> None: - with pytest.raises(TypeError, match="Expected either None or iterable of str,"): - parse_dimension_names(data) - - -@pytest.mark.parametrize("data", [None, ("a", "b", "c"), ["a", "a", "a"]]) -def parse_dimension_names_valid(data: Sequence[str] | None) -> None: - assert parse_dimension_names(data) == data - - @pytest.mark.parametrize("fill_value", [[1.0, 0.0], [0, 1]]) @pytest.mark.parametrize("dtype_str", [*complex_dtypes]) def test_jsonify_fill_value_complex(fill_value: Any, dtype_str: str) -> None: @@ -170,83 +111,99 @@ def test_parse_fill_value_invalid_type_sequence(fill_value: Any, dtype_str: str) dtype_instance.from_json_scalar(fill_value, zarr_format=3) -@pytest.mark.parametrize("chunk_grid", ["regular"]) -@pytest.mark.parametrize("attributes", [None, {"foo": "bar"}]) -@pytest.mark.parametrize("codecs", [[BytesCodec(endian=None)]]) +@pytest.mark.parametrize( + "chunk_grid", [{"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}}] +) +@pytest.mark.parametrize("codecs", [({"name": "bytes"},)]) @pytest.mark.parametrize("fill_value", [0, 1]) -@pytest.mark.parametrize("chunk_key_encoding", ["v2", "default"]) -@pytest.mark.parametrize("dimension_separator", [".", "/", None]) -@pytest.mark.parametrize("dimension_names", ["nones", "strings", "missing"]) -@pytest.mark.parametrize("storage_transformers", [None, ()]) +@pytest.mark.parametrize("data_type", ["int8", "uint8"]) +@pytest.mark.parametrize( + "chunk_key_encoding", + [ + {"name": "v2", "configuration": {"separator": "."}}, + {"name": "v2", "configuration": {"separator": "/"}}, + {"name": "v2"}, + {"name": "default", "configuration": {"separator": "."}}, + {"name": "default", "configuration": {"separator": "/"}}, + {"name": "default"}, + ], +) +@pytest.mark.parametrize("attributes", ["unset", {"foo": "bar"}]) +@pytest.mark.parametrize("dimension_names", [(None, None, None), ("a", "b", None), "unset"]) +@pytest.mark.parametrize("storage_transformers", [(), "unset"]) def test_metadata_to_dict( - chunk_grid: str, - codecs: list[Codec], + chunk_grid: NamedConfig[str, Mapping[str, object]], + codecs: tuple[NamedConfig[str, Mapping[str, object]]], + data_type: str, fill_value: Any, - chunk_key_encoding: Literal["v2", "default"], - dimension_separator: Literal[".", 
"/"] | None, - dimension_names: Literal["nones", "strings", "missing"], - attributes: dict[str, Any] | None, - storage_transformers: tuple[dict[str, JSON]] | None, + chunk_key_encoding: NamedConfig[str, Mapping[str, object]], + dimension_names: tuple[str | None, ...] | Literal["unset"], + attributes: Mapping[str, Any] | Literal["unset"], + storage_transformers: tuple[dict[str, JSON]] | Literal["unset"], ) -> None: shape = (1, 2, 3) - data_type_str = "uint8" - if chunk_grid == "regular": - cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}} - - cke: dict[str, Any] - cke_name_dict = {"name": chunk_key_encoding} - if dimension_separator is not None: - cke = cke_name_dict | {"configuration": {"separator": dimension_separator}} - else: - cke = cke_name_dict - dnames: tuple[str | None, ...] | None - - if dimension_names == "strings": - dnames = tuple(map(str, range(len(shape)))) - elif dimension_names == "missing": - dnames = None - elif dimension_names == "nones": - dnames = (None,) * len(shape) - - metadata_dict = { - "zarr_format": 3, - "node_type": "array", - "shape": shape, - "chunk_grid": cgrid, - "data_type": data_type_str, - "chunk_key_encoding": cke, - "codecs": tuple(c.to_dict() for c in codecs), - "fill_value": fill_value, - "storage_transformers": storage_transformers, - } - if attributes is not None: - metadata_dict["attributes"] = attributes - if dnames is not None: - metadata_dict["dimension_names"] = dnames + # These are the fields in the array metadata document that are optional + not_required: dict[str, object] = {} - metadata = ArrayV3Metadata.from_dict(metadata_dict) - observed = metadata.to_dict() - expected = metadata_dict.copy() + if dimension_names != "unset": + not_required["dimension_names"] = dimension_names - # if unset or None or (), storage_transformers gets normalized to () - assert observed["storage_transformers"] == () - observed.pop("storage_transformers") - expected.pop("storage_transformers") + if storage_transformers != "unset": + not_required["storage_transformers"] = storage_transformers - if attributes is None: - assert observed["attributes"] == {} - observed.pop("attributes") + if attributes != "unset": + not_required["attributes"] = attributes - if dimension_separator is None: - if chunk_key_encoding == "default": - expected_cke_dict = DefaultChunkKeyEncoding(separator="/").to_dict() + source_dict: ArrayMetadataJSON_V3 = { + "zarr_format": 3, + "node_type": "array", + "shape": shape, + "chunk_grid": chunk_grid, + "data_type": data_type, + "chunk_key_encoding": chunk_key_encoding, + "codecs": codecs, + "fill_value": fill_value, + } | not_required # type: ignore[assignment] + + metadata = ArrayV3Metadata.from_dict(source_dict) + parsed_dict = metadata.to_dict() + + for k, v in parsed_dict.items(): + if k in source_dict: + if k == "chunk_key_encoding": + v = cast(NamedConfig[str, Mapping[str, object]], v) + assert v["name"] == chunk_key_encoding["name"] + if chunk_key_encoding["name"] == "v2": + if "configuration" in chunk_key_encoding: + if "separator" in chunk_key_encoding["configuration"]: + assert "configuration" in v + assert ( + v["configuration"]["separator"] + == chunk_key_encoding["configuration"]["separator"] + ) + else: + assert v["configuration"]["separator"] == "." 
+ elif chunk_key_encoding["name"] == "default": + if "configuration" in chunk_key_encoding: + if "separator" in chunk_key_encoding["configuration"]: + assert "configuration" in v + assert ( + v["configuration"]["separator"] + == chunk_key_encoding["configuration"]["separator"] + ) + else: + assert "configuration" in v + assert v["configuration"]["separator"] == "/" + else: + assert source_dict[k] == v # type: ignore[literal-required] else: - expected_cke_dict = V2ChunkKeyEncoding(separator=".").to_dict() - assert observed["chunk_key_encoding"] == expected_cke_dict - observed.pop("chunk_key_encoding") - expected.pop("chunk_key_encoding") - assert observed == expected + if k == "attributes": + assert v == {} + elif k == "storage_transformers": + assert v == () + else: + assert v is None @pytest.mark.parametrize("indent", [2, 4, None]) @@ -261,14 +218,14 @@ def test_json_indent(indent: int) -> None: @pytest.mark.parametrize("precision", ["ns", "D"]) async def test_datetime_metadata(fill_value: int, precision: Literal["ns", "D"]) -> None: dtype = DateTime64(unit=precision) - metadata_dict: dict[str, Any] = { + metadata_dict: ArrayMetadataJSON_V3 = { "zarr_format": 3, "node_type": "array", "shape": (1,), "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, "data_type": dtype.to_json(zarr_format=3), - "chunk_key_encoding": {"name": "default", "separator": "."}, - "codecs": (BytesCodec(),), + "chunk_key_encoding": {"name": "default", "configuration": {"separator": "."}}, + "codecs": ({"name": "bytes"},), "fill_value": dtype.to_json_scalar( dtype.to_native_dtype().type(fill_value, dtype.unit), zarr_format=3 ), @@ -285,13 +242,13 @@ async def test_datetime_metadata(fill_value: int, precision: Literal["ns", "D"]) ("data_type", "fill_value"), [("uint8", {}), ("int32", [0, 1]), ("float32", "foo")] ) async def test_invalid_fill_value_raises(data_type: str, fill_value: float) -> None: - metadata_dict: dict[str, Any] = { + metadata_dict: ArrayMetadataJSON_V3 = { "zarr_format": 3, "node_type": "array", "shape": (1,), "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, "data_type": data_type, - "chunk_key_encoding": {"name": "default", "separator": "."}, + "chunk_key_encoding": {"name": "default", "configuration": {"separator": "."}}, "codecs": ({"name": "bytes"},), "fill_value": fill_value, # this is not a valid fill value for uint8 } @@ -302,13 +259,13 @@ async def test_invalid_fill_value_raises(data_type: str, fill_value: float) -> N @pytest.mark.parametrize("fill_value", [("NaN"), "Infinity", "-Infinity"]) async def test_special_float_fill_values(fill_value: str) -> None: - metadata_dict: dict[str, Any] = { + metadata_dict: ArrayMetadataJSON_V3 = { "zarr_format": 3, "node_type": "array", "shape": (1,), "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, "data_type": "float64", - "chunk_key_encoding": {"name": "default", "separator": "."}, + "chunk_key_encoding": {"name": "default", "configuration": {"separator": "."}}, "codecs": [{"name": "bytes"}], "fill_value": fill_value, # this is not a valid fill value for uint8 } diff --git a/tests/test_type_check.py b/tests/test_type_check.py new file mode 100644 index 0000000000..c07591ccee --- /dev/null +++ b/tests/test_type_check.py @@ -0,0 +1,602 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Annotated, Any, ForwardRef, Literal, NotRequired, TypeVar + +import pytest +from typing_extensions import ReadOnly, TypedDict + +from 
zarr.core.common import ArrayMetadataJSON_V3, DTypeSpec_V3, NamedConfig, StructuredName_V2
+from zarr.core.dtype.common import DTypeConfig_V2, DTypeSpec_V2
+from zarr.core.dtype.npy.structured import StructuredJSON_V2
+from zarr.core.dtype.npy.time import TimeConfig
+from zarr.core.type_check import (
+    TypeCheckResult,
+    _get_typeddict_metadata,
+    _resolve_type,
+    _type_name,
+    check_type,
+    check_typeddict,
+    ensure_type,
+    guard_type,
+)
+
+
+# --- Sample TypedDicts for testing ---
+class Address(TypedDict):
+    street: str
+    zipcode: int
+
+
+class User(TypedDict):
+    id: int
+    name: str
+    address: Address
+    tags: list[str]
+
+
+class PartialUser(TypedDict, total=False):
+    id: int
+    name: str
+
+
+@pytest.mark.parametrize(
+    ("inliers", "outliers", "typ"),
+    [
+        ((1, 2, 3), (), Any),
+        ((True, False), (1, "True", [True]), bool),
+        (("a", "1"), (1, True, ["a"]), str),
+        ((1.0, 2.0), (1, True, ["a"]), float),
+        ((1, 2), (1.0, True, ["a"]), int),
+        ((1, 2, None), (3, 4, "a"), Literal[1, 2, None]),
+        ((1, 2, None), ("a", 1.2), int | None),
+        ((("a", 1), ("hello", 2)), (True, 1, ("a", 1, 1), ()), tuple[str, int]),
+        (
+            (("a",), ("a", 1), ("hello", 2, 3), ()),
+            ((True, 1), ("a", 1, 1.0)),
+            tuple[str | int, ...],
+        ),
+        ((["a", "b"], ["x"], []), (["a", 1], "oops", 1), list[str]),
+        (({"a": 1, "b": 2}, {"x": 10}, {}), ({"a": "oops"}, [("a", 1)]), dict[str, int]),
+        (({"a": 1, "b": 2}, {10: 10}, {}), (), dict[Any, Any]),
+        (({"a": 1, "b": 2}, {"x": 10}, {}), ({"a": "oops"}, [("a", 1)]), Mapping[str, int]),
+    ],
+)
+def test_inliers_outliers(inliers: tuple[Any, ...], outliers: tuple[Any, ...], typ: type) -> None:
+    """
+    Given a set of inliers and outliers for a type, test that check_type correctly
+    identifies valid and invalid cases. This test covers types whose values can be
+    written down compactly.
+    """
+    assert all(check_type(val, typ).success for val in inliers)
+    assert all(not check_type(val, typ).success for val in outliers)
+
+
+def test_typeddict_valid() -> None:
+    """
+    Test that check_type correctly validates a complex nested TypedDict structure.
+
+    Verifies that TypedDict validation works with nested structures,
+    testing a User TypedDict containing an Address TypedDict and list fields.
+    All field types must match their annotations exactly.
+    """
+    user: User = {
+        "id": 1,
+        "name": "Alice",
+        "address": {"street": "Main St", "zipcode": 12345},
+        "tags": ["x", "y"],
+    }
+    result = check_type(user, User)
+    assert result.success
+
+
+def test_typeddict_invalid() -> None:
+    """
+    Test that check_type correctly rejects a TypedDict with invalid field types.
+
+    Verifies that TypedDict validation fails when nested fields have wrong types.
+    Tests a User with an Address where zipcode is "oops" (string) instead of
+    the expected int type.
+    """
+    bad_user = {
+        "id": 1,
+        "name": "Alice",
+        "address": {"street": "Main St", "zipcode": "oops"},
+        "tags": ["x", "y"],
+    }
+    result = check_type(bad_user, User)
+    assert not result.success
+    assert "expected int" in "".join(result.errors)
+
+
+def test_typeddict_fail_missing_required() -> None:
+    """
+    Test that check_type correctly rejects a TypedDict missing required fields.
+
+    Verifies that TypedDict validation enforces required fields, failing when
+    a required key is missing. Tests an Address TypedDict missing the required
+    'zipcode' field.
+    """
+    bad_user = {
+        "id": 1,
+        "name": "Alice",
+        "address": {"street": "Main St"},  # missing zipcode
+        "tags": ["x"],
+    }
+    result = check_type(bad_user, User)
+    assert not result.success
+    assert "missing required key 'zipcode'" in "".join(result.errors)
+
+
+def test_typeddict_partial_total_false_pass() -> None:
+    """
+    Test that check_type correctly handles TypedDict with total=False allowing empty dicts.
+
+    Verifies that TypedDict with total=False makes all fields optional,
+    allowing an empty dictionary {} to pass validation against PartialUser,
+    which has all optional fields.
+    """
+    result = check_type({}, PartialUser)
+    assert result.success
+
+
+def test_typeddict_partial_total_false_fail() -> None:
+    """
+    Test that check_type rejects TypedDict with total=False when present fields have wrong types.
+
+    Verifies that even with total=False making fields optional, any present
+    fields must still match their type annotations. Tests PartialUser with
+    {"id": "wrong-type"} where id should be int.
+    """
+    bad = {"id": "wrong-type"}
+    result = check_type(bad, PartialUser)
+    assert not result.success
+    assert "expected int but got 'wrong-type'" in "".join(result.errors)
+
+
+@pytest.mark.parametrize("data", [10, {"blame": "foo", "configuration": {"foo": "bar"}}])
+def test_typeddict_dtype_spec_invalid(data: DTypeSpec_V3) -> None:
+    """
+    Test that check_type correctly rejects invalid DTypeSpec_V3 structures.
+
+    Verifies that DTypeSpec_V3 validation fails for incorrect formats.
+    Tests with an integer (10) and a dict with a wrong field name
+    ("blame" instead of "name"); both should be rejected.
+    """
+    result = check_type(data, DTypeSpec_V3)
+    assert not result.success
+
+
+@pytest.mark.parametrize("data", ["foo", {"name": "foo", "configuration": {"foo": "bar"}}])
+def test_typeddict_dtype_spec_valid(data: DTypeSpec_V3) -> None:
+    """
+    Test that check_type correctly accepts valid DTypeSpec_V3 structures.
+
+    Verifies that DTypeSpec_V3 validation passes for correct formats.
+    Tests with both a simple string ("foo") and a proper dict structure
+    with "name" and "configuration" fields.
+    """
+    result = check_type(data, DTypeSpec_V3)
+    assert result.success
+
+
+class InheritedTD(DTypeConfig_V2[str, None]): ...
+
+
+@pytest.mark.parametrize("typ", [DTypeSpec_V2, DTypeConfig_V2[str, None], InheritedTD])
+def test_typeddict_dtype_spec_v2_valid(typ: type) -> None:
+    """
+    Test that check_type correctly validates various DTypeSpec_V2 and DTypeConfig_V2 types.
+
+    Verifies that version 2 dtype specifications work correctly with different
+    generic parameterizations. Tests DTypeSpec_V2, generic DTypeConfig_V2[str, None],
+    and inherited TypedDict classes.
+    """
+    result = check_type({"name": "gzip", "object_codec_id": None}, typ)
+    assert result.success
+
+
+@pytest.mark.parametrize("typ", [DTypeConfig_V2[StructuredName_V2, None], StructuredJSON_V2])
+def test_typeddict_recursive(typ: type) -> None:
+    """
+    Test that check_type correctly handles recursive/nested TypedDict structures.
+
+    Verifies that complex nested structures like structured dtypes work properly.
+    Tests with DTypeConfig_V2 containing StructuredName_V2 and StructuredJSON_V2,
+    which contain nested field definitions.
+    """
+    result = check_type(
+        {"name": [["field1", ">i4"], ["field2", ">f8"]], "object_codec_id": None}, typ
+    )
+    assert result.success
+
+
+def test_datetime_valid() -> None:
+    """
+    Test that check_type correctly validates datetime configuration structures.
+
+    Verifies that complex NamedConfig structures work with specific literal types.
+    Tests DateTime64JSON_V3, which is a NamedConfig whose name is the literal
+    "numpy.datetime64" and whose configuration is a TimeConfig.
+    """
+    DateTime64JSON_V3 = NamedConfig[Literal["numpy.datetime64"], TimeConfig]
+    data: DateTime64JSON_V3 = {
+        "name": "numpy.datetime64",
+        "configuration": {"unit": "ns", "scale_factor": 10},
+    }
+    result = check_type(data, DateTime64JSON_V3)
+    assert result.success
+
+
+@pytest.mark.parametrize(
+    "optionals",
+    [{}, {"attributes": {}}, {"storage_transformers": ()}, {"dimension_names": ("a", "b")}],
+)
+def test_zarr_v3_metadata(optionals: dict[str, object]) -> None:
+    """
+    Test that check_type correctly validates ArrayMetadataJSON_V3 with optional fields.
+
+    Verifies that Zarr v3 array metadata validation works with different combinations
+    of optional fields. Tests the base required fields plus various optional field
+    combinations like attributes, storage_transformers, and dimension_names.
+    """
+    meta: ArrayMetadataJSON_V3 = {
+        "zarr_format": 3,
+        "node_type": "array",
+        "chunk_key_encoding": {"name": "default", "configuration": {"separator": "."}},
+        "shape": (10, 10),
+        "fill_value": 0,
+        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (5, 5)}},
+        "codecs": ("bytes",),
+        "attributes": {"a": 1, "b": 2},
+        "data_type": "uint8",
+    } | optionals  # type: ignore[assignment]
+    result = check_type(meta, ArrayMetadataJSON_V3)
+    assert result.success
+
+
+def test_external_generic_typeddict() -> None:
+    """
+    Test that check_type correctly validates external generic TypedDict structures.
+
+    Verifies that generic TypedDict classes from external modules work properly.
+    Tests NamedConfig with a specific literal type and Mapping configuration,
+    ensuring generic type parameters are resolved correctly.
+    """
+    x: NamedConfig[Literal["default"], Mapping[str, object]] = {
+        "name": "default",
+        "configuration": {"foo": "bar"},
+    }
+    result = check_type(x, NamedConfig[Literal["default"], Mapping[str, object]])
+    assert result.success
+
+
+def test_typeddict_extra_keys_allowed() -> None:
+    """
+    Test that check_type allows extra keys in TypedDict structures.
+
+    Verifies that TypedDict validation is flexible about additional keys
+    not defined in the TypedDict schema. Tests a TypedDict with field 'a'
+    but provides both 'a' and 'b', ensuring 'b' is allowed.
+    """
+
+    class X(TypedDict):
+        a: int
+
+    b: X = {"a": 1, "b": 2}  # type: ignore[typeddict-unknown-key]
+    result = check_type(b, X)
+    assert result.success
+
+
+def test_typeddict_readonly_notrequired() -> None:
+    """
+    Test that check_type correctly handles ReadOnly and NotRequired annotations in TypedDict.
+
+    Verifies that complex type annotations like ReadOnly[NotRequired[int]] work properly.
+    Tests various combinations and nesting of ReadOnly and NotRequired annotations,
+    ensuring all optional fields can be omitted.
+    """
+
+    class X(TypedDict):
+        a: ReadOnly[NotRequired[int]]
+        b: NotRequired[ReadOnly[int]]
+        c: Annotated[ReadOnly[NotRequired[int]], 10]
+        d: int
+
+    b: X = {"d": 1}
+    result = check_type(b, X)
+    assert result.success
+
+
+def test_ensure_type_valid() -> None:
+    """
+    Test that ensure_type returns the input value when type validation succeeds.
+
+    Verifies that ensure_type acts as a type-safe identity function,
+    returning the original value when it matches the expected type.
+    Tests with integer 42 against int type.
+    """
+    result = ensure_type(42, int)
+    assert result == 42
+
+
+def test_ensure_type_invalid() -> None:
+    """
+    Test that ensure_type raises TypeError when type validation fails.
+
+    Verifies that ensure_type throws an appropriate TypeError with a descriptive
+    message when the input value doesn't match the expected type.
+    Tests with string "hello" against int type.
+    """
+    with pytest.raises(TypeError, match="Expected an instance of .* but got 'hello'"):
+        ensure_type("hello", int)
+
+
+def test_guard_type_valid() -> None:
+    """
+    Test that guard_type returns True when type validation succeeds.
+
+    Verifies that guard_type acts as a boolean type guard function,
+    returning True when the input matches the expected type for use
+    in conditional type narrowing. Tests with integer 42 against int type.
+    """
+    assert guard_type(42, int) is True
+
+
+def test_guard_type_invalid() -> None:
+    """
+    Test that guard_type returns False when type validation fails.
+
+    Verifies that guard_type correctly identifies type mismatches by returning
+    False, allowing for safe type narrowing in conditional blocks.
+    Tests with string "hello" against int type.
+    """
+    assert guard_type("hello", int) is False
+
+
+def test_check_type_none_with_none_type() -> None:
+    """
+    Test that check_type correctly validates None against type(None) annotation.
+
+    Verifies that explicit None type checking works using type(None) syntax
+    as opposed to just None. Tests both a valid None value and an invalid
+    string "not none" against type(None).
+    """
+    result = check_type(None, type(None))
+    assert result.success
+
+    result = check_type("not none", type(None))
+    assert not result.success
+
+
+def test_check_type_fallback_isinstance() -> None:
+    """
+    Test that check_type falls back to isinstance() for custom class types.
+
+    Verifies that when no specific type checking logic applies, the function
+    falls back to using isinstance() for validation. Tests with a custom
+    class to ensure both valid instances and invalid values are handled correctly.
+    """
+
+    class CustomClass:
+        pass
+
+    obj = CustomClass()
+    result = check_type(obj, CustomClass)
+    assert result.success
+
+    result = check_type("not custom", CustomClass)
+    assert not result.success
+
+
+def test_check_type_fallback_type_error() -> None:
+    """
+    Test that check_type handles a TypeError from the fallback isinstance() gracefully.
+
+    Verifies that when isinstance() raises a TypeError (e.g., with ForwardRef),
+    the function catches the exception and returns an appropriate error message.
+    Tests with a problematic ForwardRef type.
+    """
+    # Create a problematic type that can't be used with isinstance
+    problematic_type = ForwardRef("NonExistentType")
+    result = check_type("anything", problematic_type)
+    assert not result.success
+    assert "cannot be checked against" in result.errors[0]
+
+
+def test_sequence_type_string_bytes_excluded() -> None:
+    """
+    Test that check_type excludes strings and bytes from sequence type validation.
+
+    Verifies that str and bytes are not treated as sequences even though they
+    technically implement the sequence protocol. This prevents strings from
+    being validated as list[str], where each character would be checked.
+ """ + result = check_type("string", list[str]) + assert not result.success + assert "expected sequence" in result.errors[0] + + result = check_type(b"bytes", list[str]) + assert not result.success + assert "expected sequence" in result.errors[0] + + +def test_typeddict_non_dict() -> None: + """ + Test that check_type correctly rejects non-dict objects for TypedDict validation. + + Verifies that TypedDict validation first ensures the input is a dictionary + before checking field types. Tests list against User TypedDict, + which should fail at the dict requirement level. + """ + result = check_type([], User) + assert not result.success + assert "expected dict for TypedDict" in result.errors[0] + + +def test_typeddict_type_parameter_mismatch() -> None: + """ + Test that check_type correctly detects type parameter count mismatches in generic TypedDict. + + Verifies that generic TypedDict validation enforces correct parameterization. + Tests a generic TypedDict with TypeVar T, ensuring that type parameter + counting validation works properly and reports mismatches. + """ + T = TypeVar("T") + + class GenericTD(TypedDict): + value: T # type: ignore[valid-type] + + # This will trigger a type parameter count mismatch because + # Generic TypedDict validation is strict about parameter counts + result = check_type({"value": 42}, GenericTD[int]) # type: ignore[misc] + # This actually fails due to type parameter mismatch validation + assert not result.success + assert "type parameter count mismatch" in result.errors[0] + + +def test_complex_nested_unions() -> None: + """ + Test that check_type correctly validates complex nested structures with union types. + + Verifies that deeply nested type validation works with combinations of + dictionaries and union types. Tests dict[str, int | str | None] with + valid data containing different union member types and invalid data. + """ + ComplexType = dict[str, int | str | None] + + test_data: dict[str, Any] = {"int_val": 42, "str_val": "hello", "none_val": None} + + result = check_type(test_data, ComplexType) + assert result.success + + bad_data: dict[str, list[Any]] = { + "bad_val": [] # list is not in the union + } + + result = check_type(bad_data, ComplexType) + assert not result.success + + +def test_type_check_result_dataclass() -> None: + """ + Test that TypeCheckResult dataclass works correctly as a return type. + + Verifies that the TypeCheckResult dataclass properly stores success status + and error messages. Tests both successful validation (empty errors) and + failed validation (with multiple errors). + """ + result = TypeCheckResult(True, []) + assert result.success + assert result.errors == [] + + result = TypeCheckResult(False, ["error1", "error2"]) + assert not result.success + assert len(result.errors) == 2 + + +def test_sequence_with_collections_abc() -> None: + """ + Test that check_type correctly validates custom sequence implementations. + + Verifies that sequence type checking works with custom classes that + implement collections.abc.Sequence protocol. Tests a CustomSequence + class against list[int] to ensure protocol-based validation works. 
+    """
+
+    # Test with a custom sequence
+    class CustomSequence(Sequence[Any]):
+        def __init__(self, items: Sequence[Any]) -> None:
+            self._items = items
+
+        def __getitem__(self, index: Any) -> Any:
+            return self._items[index]
+
+        def __len__(self) -> int:
+            return len(self._items)
+
+    custom_seq = CustomSequence([1, 2, 3])
+    result = check_type(custom_seq, list[int])
+    assert result.success
+
+
+@pytest.mark.parametrize(
+    ("typ", "expected"),
+    [
+        (str, "str"),
+        (list, "list"),
+        (list[int], "list[int]"),
+        (str | int, "str | int"),
+    ],
+)
+def test_type_name(typ: Any, expected: str) -> None:
+    assert _type_name(typ) == expected
+
+
+def test_typevar_self_reference_edge_case() -> None:
+    """
+    Test a TypeVar that maps to itself in type resolution.
+
+    Tests the edge case where a TypeVar in the type_map resolves to itself,
+    triggering the self-reference detection in _resolve_type_impl.
+    """
+    T = TypeVar("T")
+    # Create a type_map where T maps to itself
+    type_map = {T: T}
+
+    # This should trigger the self-reference detection
+    result = _resolve_type(T, type_map=type_map)
+    assert result == T  # Should return the original TypeVar
+
+
+def test_non_typeddict_fallback_error() -> None:
+    """
+    Test the error produced when a non-TypedDict is passed to check_typeddict.
+
+    Tests the fallback error case when _get_typeddict_metadata returns None,
+    meaning the type is not actually a TypedDict.
+    """
+
+    # Pass a regular class that's not a TypedDict
+    class NotATypedDict:
+        pass
+
+    result = check_typeddict({"key": "value"}, NotATypedDict, "test_path")
+    assert not result.success
+    assert "expected a TypedDict but got" in result.errors[0]
+
+
+def test_get_typeddict_metadata_fallback() -> None:
+    """
+    Test that _get_typeddict_metadata returns all-None metadata for non-TypedDict types.
+
+    Tests the fallback case where _get_typeddict_metadata cannot extract
+    valid metadata from the provided type.
+    """
+
+    # Test with a type that's not a TypedDict at all
+    result = _get_typeddict_metadata(int)
+    assert result == (None, None, None, None)
+
+    # Test with a complex type that looks like it might be a TypedDict but isn't
+    result = _get_typeddict_metadata(dict[str, int])
+    assert result == (None, None, None, None)
+
+
+@pytest.mark.parametrize(("typ_str", "typ_expected"), [("int", int), ("str", str), ("list", list)])
+def test_complex_forwardref_scenarios(typ_str: str, typ_expected: type) -> None:
+    """
+    Test that string annotations are evaluated during internal type resolution.
+
+    String annotations are not handled in the main check_type path; ForwardRef
+    evaluation happens inside _resolve_type, so that function is tested directly
+    here with names that must be resolved against the given global namespace.
+    """
+    result = _resolve_type(typ_str, globalns=globals())
+    assert result is typ_expected
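For reference, a minimal sketch of the error reporting these tests exercise, based on the message formats defined in zarr.core.type_check (the Address class mirrors the test fixture above):

    from typing_extensions import TypedDict

    from zarr.core.type_check import check_type


    class Address(TypedDict):
        street: str
        zipcode: int


    result = check_type({"street": "Main St", "zipcode": "oops"}, Address, "address")
    print(result.success)
    # False
    print(result.errors)
    # ["address['zipcode'] expected int but got 'oops' with type <class 'str'>"]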