Skip to content

Commit a2da99a

Browse files
committed
refactor dtypewrapper -> zdtype
1 parent e855e54 commit a2da99a

File tree

27 files changed

+1312
-631
lines changed

27 files changed

+1312
-631
lines changed

src/zarr/abc/codec.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Any, Generic, TypeVar
4+
from typing import TYPE_CHECKING, Generic, TypeVar
55

66
from zarr.abc.metadata import Metadata
77
from zarr.core.buffer import Buffer, NDBuffer
@@ -12,11 +12,10 @@
1212
from collections.abc import Awaitable, Callable, Iterable
1313
from typing import Self
1414

15-
import numpy as np
16-
1715
from zarr.abc.store import ByteGetter, ByteSetter
1816
from zarr.core.array_spec import ArraySpec
1917
from zarr.core.chunk_grids import ChunkGrid
18+
from zarr.core.dtype.wrapper import ZDType, _BaseDType, _BaseScalar
2019
from zarr.core.indexing import SelectorTuple
2120

2221
__all__ = [
@@ -93,7 +92,13 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
9392
"""
9493
return self
9594

96-
def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
95+
def validate(
96+
self,
97+
*,
98+
shape: ChunkCoords,
99+
dtype: ZDType[_BaseDType, _BaseScalar],
100+
chunk_grid: ChunkGrid,
101+
) -> None:
97102
"""Validates that the codec configuration is compatible with the array metadata.
98103
Raises errors when the codec configuration is not compatible.
99104
@@ -285,7 +290,9 @@ def supports_partial_decode(self) -> bool: ...
285290
def supports_partial_encode(self) -> bool: ...
286291

287292
@abstractmethod
288-
def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
293+
def validate(
294+
self, *, shape: ChunkCoords, dtype: ZDType[_BaseDType, _BaseScalar], chunk_grid: ChunkGrid
295+
) -> None:
289296
"""Validates that all codec configurations are compatible with the array metadata.
290297
Raises errors when a codec configuration is not compatible.
291298

src/zarr/api/asynchronous.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
_warn_order_kwarg,
2929
_warn_write_empty_chunks_kwarg,
3030
)
31-
from zarr.core.dtype import get_data_type_from_numpy
31+
from zarr.core.dtype import get_data_type_from_native_dtype
3232
from zarr.core.group import (
3333
AsyncGroup,
3434
ConsolidatedMetadata,
@@ -433,7 +433,7 @@ async def save_array(
433433
shape = arr.shape
434434
chunks = getattr(arr, "chunks", None) # for array-likes with chunks attribute
435435
overwrite = kwargs.pop("overwrite", None) or _infer_overwrite(mode)
436-
zarr_dtype = get_data_type_from_numpy(arr.dtype)
436+
zarr_dtype = get_data_type_from_native_dtype(arr.dtype)
437437
new = await AsyncArray._create(
438438
store_path,
439439
zarr_format=zarr_format,
@@ -984,7 +984,7 @@ async def create(
984984
_handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format)
985985
or _default_zarr_format()
986986
)
987-
dtype_wrapped = get_data_type_from_numpy(dtype)
987+
dtype_wrapped = get_data_type_from_native_dtype(dtype)
988988
if zarr_format == 2:
989989
if chunks is None:
990990
chunks = shape

src/zarr/codecs/bytes.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,21 @@
33
import sys
44
from dataclasses import dataclass, replace
55
from enum import Enum
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, cast
77

88
import numpy as np
99

1010
from zarr.abc.codec import ArrayBytesCodec
1111
from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer
1212
from zarr.core.common import JSON, parse_enum, parse_named_configuration
13-
from zarr.core.dtype.common import endianness_to_numpy_str
13+
from zarr.core.dtype._numpy import endianness_to_numpy_str
1414
from zarr.registry import register_codec
1515

1616
if TYPE_CHECKING:
1717
from typing import Self
1818

1919
from zarr.core.array_spec import ArraySpec
20+
from zarr.core.dtype.common import Endianness
2021

2122

2223
class Endian(Enum):
@@ -73,7 +74,9 @@ async def _decode_single(
7374
) -> NDBuffer:
7475
assert isinstance(chunk_bytes, Buffer)
7576
# TODO: remove endianness enum in favor of literal union
76-
endian_str = self.endian.value if self.endian is not None else None
77+
endian_str = cast(
78+
"Endianness | None", self.endian.value if self.endian is not None else None
79+
)
7780
dtype = chunk_spec.dtype.to_dtype().newbyteorder(endianness_to_numpy_str(endian_str))
7881

7982
as_array_like = chunk_bytes.as_array_like()

src/zarr/codecs/sharding.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
parse_shapelike,
4444
product,
4545
)
46+
from zarr.core.dtype._numpy import UInt64
4647
from zarr.core.indexing import (
4748
BasicIndexer,
4849
SelectorTuple,
@@ -58,7 +59,7 @@
5859
from typing import Self
5960

6061
from zarr.core.common import JSON
61-
from zarr.core.dtype.wrapper import DTypeWrapper
62+
from zarr.core.dtype.wrapper import ZDType, _BaseDType, _BaseScalar
6263

6364
MAX_UINT_64 = 2**64 - 1
6465
ShardMapping = Mapping[ChunkCoords, Buffer]
@@ -405,7 +406,11 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
405406
return self
406407

407408
def validate(
408-
self, *, shape: ChunkCoords, dtype: DTypeWrapper[Any, Any], chunk_grid: ChunkGrid
409+
self,
410+
*,
411+
shape: ChunkCoords,
412+
dtype: ZDType[_BaseDType, _BaseScalar],
413+
chunk_grid: ChunkGrid,
409414
) -> None:
410415
if len(self.chunk_shape) != len(shape):
411416
raise ValueError(
@@ -443,7 +448,10 @@ async def _decode_single(
443448

444449
# setup output array
445450
out = chunk_spec.prototype.nd_buffer.create(
446-
shape=shard_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
451+
shape=shard_shape,
452+
dtype=shard_spec.dtype.to_dtype(),
453+
order=shard_spec.order,
454+
fill_value=0,
447455
)
448456
shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard)
449457

@@ -685,7 +693,7 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int:
685693
def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec:
686694
return ArraySpec(
687695
shape=chunks_per_shard + (2,),
688-
dtype=np.dtype("<u8"),
696+
dtype=UInt64(endianness="little"),
689697
fill_value=MAX_UINT_64,
690698
config=ArrayConfig(
691699
order="C", write_empty_chunks=False

src/zarr/codecs/transpose.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
from zarr.registry import register_codec
1313

1414
if TYPE_CHECKING:
15-
from typing import Any, Self
15+
from typing import Self
1616

1717
from zarr.core.buffer import NDBuffer
1818
from zarr.core.chunk_grids import ChunkGrid
19+
from zarr.core.dtype.wrapper import ZDType, _BaseDType, _BaseScalar
1920

2021

2122
def parse_transpose_order(data: JSON | Iterable[int]) -> tuple[int, ...]:
@@ -45,7 +46,12 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
4546
def to_dict(self) -> dict[str, JSON]:
4647
return {"name": "transpose", "configuration": {"order": tuple(self.order)}}
4748

48-
def validate(self, shape: tuple[int, ...], dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
49+
def validate(
50+
self,
51+
shape: tuple[int, ...],
52+
dtype: ZDType[_BaseDType, _BaseScalar],
53+
chunk_grid: ChunkGrid,
54+
) -> None:
4955
if len(self.order) != len(shape):
5056
raise ValueError(
5157
f"The `order` tuple needs have as many entries as there are dimensions in the array. Got {self.order}."

src/zarr/core/_info.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
import textwrap
3-
from typing import Any, Literal
5+
from typing import TYPE_CHECKING, Any, Literal
46

5-
import numcodecs.abc
6-
import numpy as np
7+
if TYPE_CHECKING:
8+
import numcodecs.abc
9+
import numpy as np
710

8-
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec
9-
from zarr.core.common import ZarrFormat
10-
from zarr.core.dtype.wrapper import DTypeWrapper
11+
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec
12+
from zarr.core.common import ZarrFormat
13+
from zarr.core.dtype.wrapper import ZDType, _BaseDType, _BaseScalar
1114

1215

1316
@dataclasses.dataclass(kw_only=True)
@@ -78,7 +81,7 @@ class ArrayInfo:
7881

7982
_type: Literal["Array"] = "Array"
8083
_zarr_format: ZarrFormat
81-
_data_type: np.dtype[Any] | DTypeWrapper
84+
_data_type: np.dtype[Any] | ZDType[_BaseDType, _BaseScalar]
8285
_shape: tuple[int, ...]
8386
_shard_shape: tuple[int, ...] | None = None
8487
_chunk_shape: tuple[int, ...] | None = None

0 commit comments

Comments
 (0)