Skip to content

Commit 703e0e1

Browse files
committed
use wrap / unwrap instead of to_dtype / from_dtype; push into v2 codebase
1 parent 24930b3 commit 703e0e1

File tree

15 files changed

+147
-148
lines changed

15 files changed

+147
-148
lines changed

src/zarr/api/asynchronous.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata
3333
from zarr.core.metadata.v2 import _default_compressor, _default_filters
3434
from zarr.errors import NodeTypeValidationError
35+
from zarr.registry import get_data_type_from_numpy
3536
from zarr.storage._common import make_store_path
3637

3738
if TYPE_CHECKING:
@@ -428,11 +429,12 @@ async def save_array(
428429
shape = arr.shape
429430
chunks = getattr(arr, "chunks", None) # for array-likes with chunks attribute
430431
overwrite = kwargs.pop("overwrite", None) or _infer_overwrite(mode)
432+
zarr_dtype = get_data_type_from_numpy(arr.dtype)
431433
new = await AsyncArray._create(
432434
store_path,
433435
zarr_format=zarr_format,
434436
shape=shape,
435-
dtype=arr.dtype,
437+
dtype=zarr_dtype,
436438
chunks=chunks,
437439
overwrite=overwrite,
438440
**kwargs,
@@ -978,15 +980,14 @@ async def create(
978980
_handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format)
979981
or _default_zarr_format()
980982
)
981-
983+
dtype_wrapped = parse_dtype(dtype, zarr_format=zarr_format)
982984
if zarr_format == 2:
983985
if chunks is None:
984986
chunks = shape
985-
dtype = parse_dtype(dtype, zarr_format=zarr_format)
986987
if not filters:
987-
filters = _default_filters(dtype)
988+
filters = _default_filters(dtype_wrapped)
988989
if not compressor:
989-
compressor = _default_compressor(dtype)
990+
compressor = _default_compressor(dtype_wrapped)
990991
elif zarr_format == 3 and chunk_shape is None: # type: ignore[redundant-expr]
991992
if chunks is not None:
992993
chunk_shape = chunks
@@ -1051,7 +1052,7 @@ async def create(
10511052
store_path,
10521053
shape=shape,
10531054
chunks=chunks,
1054-
dtype=dtype,
1055+
dtype=dtype_wrapped,
10551056
compressor=compressor,
10561057
fill_value=fill_value,
10571058
overwrite=overwrite,

src/zarr/codecs/_v2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ async def _decode_single(
4848
# segfaults and other bad things happening
4949
if chunk_spec.dtype != object:
5050
try:
51-
chunk = chunk.view(chunk_spec.dtype)
51+
chunk = chunk.view(chunk_spec.dtype.unwrap())
5252
except TypeError:
5353
# this will happen if the dtype of the chunk
5454
# does not match the dtype of the array spec i.g. if
5555
# the dtype of the chunk_spec is a string dtype, but the chunk
5656
# is an object array. In this case, we need to convert the object
5757
# array to the correct dtype.
5858

59-
chunk = np.array(chunk).astype(chunk_spec.dtype)
59+
chunk = np.array(chunk).astype(chunk_spec.dtype.unwrap())
6060

6161
elif chunk.dtype != object:
6262
# If we end up here, someone must have hacked around with the filters.
@@ -80,7 +80,7 @@ async def _encode_single(
8080
chunk = chunk_array.as_ndarray_like()
8181

8282
# ensure contiguous and correct order
83-
chunk = chunk.astype(chunk_spec.dtype, order=chunk_spec.order, copy=False)
83+
chunk = chunk.astype(chunk_spec.dtype.unwrap(), order=chunk_spec.order, copy=False)
8484

8585
# apply filters
8686
if self.filters:

src/zarr/codecs/bytes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def to_dict(self) -> dict[str, JSON]:
5656
return {"name": "bytes", "configuration": {"endian": self.endian.value}}
5757

5858
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
59-
if array_spec.dtype.itemsize == 0:
59+
if array_spec.dtype.unwrap().itemsize == 0:
6060
if self.endian is not None:
6161
return replace(self, endian=None)
6262
elif self.endian is None:

src/zarr/core/array.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
ArrayV3MetadataDict,
9999
T_ArrayMetadata,
100100
)
101-
from zarr.core.metadata.dtype import DTypeWrapper
101+
from zarr.core.metadata.dtype import DTypeWrapper, VariableLengthString
102102
from zarr.core.metadata.v2 import (
103103
_default_compressor,
104104
_default_filters,
@@ -549,7 +549,7 @@ async def _create(
549549
*,
550550
# v2 and v3
551551
shape: ShapeLike,
552-
dtype: npt.DTypeLike,
552+
dtype: npt.DTypeLike[Any],
553553
zarr_format: ZarrFormat = 3,
554554
fill_value: Any | None = None,
555555
attributes: dict[str, JSON] | None = None,
@@ -578,18 +578,22 @@ async def _create(
578578
See :func:`AsyncArray.create` for more details.
579579
Deprecated in favor of :func:`zarr.api.asynchronous.create_array`.
580580
"""
581+
# TODO: delete this and be more strict about where parsing occurs
582+
if not isinstance(dtype, DTypeWrapper):
583+
dtype_parsed = get_data_type_from_numpy(np.dtype(dtype))
584+
else:
585+
dtype_parsed = dtype
581586
store_path = await make_store_path(store)
582587

583-
dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
584588
shape = parse_shapelike(shape)
585589

586590
if chunks is not None and chunk_shape is not None:
587591
raise ValueError("Only one of chunk_shape or chunks can be provided.")
588592

589593
if chunks:
590-
_chunks = normalize_chunks(chunks, shape, dtype_parsed.itemsize)
594+
_chunks = normalize_chunks(chunks, shape, dtype_parsed.unwrap().itemsize)
591595
else:
592-
_chunks = normalize_chunks(chunk_shape, shape, dtype_parsed.itemsize)
596+
_chunks = normalize_chunks(chunk_shape, shape, dtype_parsed.unwrap().itemsize)
593597
config_parsed = parse_array_config(config)
594598

595599
result: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
@@ -666,7 +670,7 @@ async def _create(
666670
@staticmethod
667671
def _create_metadata_v3(
668672
shape: ShapeLike,
669-
dtype: np.dtype[Any],
673+
dtype: DTypeWrapper[Any, Any],
670674
chunk_shape: ChunkCoords,
671675
fill_value: Any | None = None,
672676
chunk_key_encoding: ChunkKeyEncodingLike | None = None,
@@ -694,19 +698,16 @@ def _create_metadata_v3(
694698
stacklevel=2,
695699
)
696700

697-
# resolve the numpy dtype into zarr v3 datatype
698-
zarr_data_type = get_data_type_from_numpy(dtype)
699-
700701
if fill_value is None:
701702
# v3 spec will not allow a null fill value
702-
fill_value_parsed = zarr_data_type.default_value
703+
fill_value_parsed = dtype.default_value
703704
else:
704705
fill_value_parsed = fill_value
705706

706707
chunk_grid_parsed = RegularChunkGrid(chunk_shape=chunk_shape)
707708
return ArrayV3Metadata(
708709
shape=shape,
709-
data_type=zarr_data_type,
710+
data_type=dtype,
710711
chunk_grid=chunk_grid_parsed,
711712
chunk_key_encoding=chunk_key_encoding_parsed,
712713
fill_value=fill_value_parsed,
@@ -769,7 +770,7 @@ async def _create_v3(
769770
@staticmethod
770771
def _create_metadata_v2(
771772
shape: ChunkCoords,
772-
dtype: np.dtype[Any],
773+
dtype: DTypeWrapper[Any, Any],
773774
chunks: ChunkCoords,
774775
order: MemoryOrder,
775776
dimension_separator: Literal[".", "/"] | None = None,
@@ -781,10 +782,8 @@ def _create_metadata_v2(
781782
if dimension_separator is None:
782783
dimension_separator = "."
783784

784-
dtype = parse_dtype(dtype, zarr_format=2)
785-
786785
# inject VLenUTF8 for str dtype if not already present
787-
if np.issubdtype(dtype, np.str_):
786+
if isinstance(dtype, VariableLengthString):
788787
filters = filters or []
789788
from numcodecs.vlen import VLenUTF8
790789

@@ -793,7 +792,7 @@ def _create_metadata_v2(
793792

794793
return ArrayV2Metadata(
795794
shape=shape,
796-
dtype=np.dtype(dtype),
795+
dtype=dtype,
797796
chunks=chunks,
798797
order=order,
799798
dimension_separator=dimension_separator,
@@ -2046,7 +2045,7 @@ def dtype(self) -> np.dtype[Any]:
20462045
np.dtype
20472046
The NumPy data type.
20482047
"""
2049-
return self._async_array.dtype
2048+
return self._async_array.dtype.unwrap()
20502049

20512050
@property
20522051
def attrs(self) -> Attributes:
@@ -3919,7 +3918,7 @@ async def init_array(
39193918

39203919
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
39213920

3922-
dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
3921+
dtype_wrapped = parse_dtype(dtype, zarr_format=zarr_format)
39233922
shape_parsed = parse_shapelike(shape)
39243923
chunk_key_encoding_parsed = _parse_chunk_key_encoding(
39253924
chunk_key_encoding, zarr_format=zarr_format
@@ -3934,7 +3933,10 @@ async def init_array(
39343933
await ensure_no_existing_node(store_path, zarr_format=zarr_format)
39353934

39363935
shard_shape_parsed, chunk_shape_parsed = _auto_partition(
3937-
array_shape=shape_parsed, shard_shape=shards, chunk_shape=chunks, dtype=dtype_parsed
3936+
array_shape=shape_parsed,
3937+
shard_shape=shards,
3938+
chunk_shape=chunks,
3939+
item_size=dtype_wrapped.unwrap().itemsize,
39383940
)
39393941
chunks_out: tuple[int, ...]
39403942
meta: ArrayV2Metadata | ArrayV3Metadata
@@ -3950,9 +3952,8 @@ async def init_array(
39503952
raise ValueError("Zarr format 2 arrays do not support `serializer`.")
39513953

39523954
filters_parsed, compressor_parsed = _parse_chunk_encoding_v2(
3953-
compressor=compressors, filters=filters, dtype=np.dtype(dtype)
3955+
compressor=compressors, filters=filters, dtype=dtype_wrapped
39543956
)
3955-
39563957
if dimension_names is not None:
39573958
raise ValueError("Zarr format 2 arrays do not support dimension names.")
39583959
if order is None:
@@ -3962,7 +3963,7 @@ async def init_array(
39623963

39633964
meta = AsyncArray._create_metadata_v2(
39643965
shape=shape_parsed,
3965-
dtype=dtype_parsed,
3966+
dtype=dtype_wrapped,
39663967
chunks=chunk_shape_parsed,
39673968
dimension_separator=chunk_key_encoding_parsed.separator,
39683969
fill_value=fill_value,
@@ -3976,7 +3977,7 @@ async def init_array(
39763977
compressors=compressors,
39773978
filters=filters,
39783979
serializer=serializer,
3979-
dtype=dtype_parsed,
3980+
dtype=dtype_wrapped,
39803981
)
39813982
sub_codecs = cast(tuple[Codec, ...], (*array_array, array_bytes, *bytes_bytes))
39823983
codecs_out: tuple[Codec, ...]
@@ -3991,7 +3992,7 @@ async def init_array(
39913992
)
39923993
sharding_codec.validate(
39933994
shape=chunk_shape_parsed,
3994-
dtype=dtype_parsed,
3995+
dtype=dtype_wrapped,
39953996
chunk_grid=RegularChunkGrid(chunk_shape=shard_shape_parsed),
39963997
)
39973998
codecs_out = (sharding_codec,)
@@ -4002,7 +4003,7 @@ async def init_array(
40024003

40034004
meta = AsyncArray._create_metadata_v3(
40044005
shape=shape_parsed,
4005-
dtype=dtype_parsed,
4006+
dtype=dtype_wrapped,
40064007
fill_value=fill_value,
40074008
chunk_shape=chunks_out,
40084009
chunk_key_encoding=chunk_key_encoding_parsed,
@@ -4210,12 +4211,11 @@ def _parse_chunk_key_encoding(
42104211

42114212

42124213
def _get_default_chunk_encoding_v3(
4213-
np_dtype: np.dtype[Any],
4214+
dtype: DTypeWrapper[Any, Any],
42144215
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
42154216
"""
42164217
Get the default ArrayArrayCodecs, ArrayBytesCodec, and BytesBytesCodec for a given dtype.
42174218
"""
4218-
dtype = get_data_type_from_numpy(np_dtype)
42194219

42204220
default_filters = zarr_config.get("array.v3_default_filters").get(dtype.kind)
42214221
default_serializer = zarr_config.get("array.v3_default_serializer").get(dtype.kind)
@@ -4229,14 +4229,14 @@ def _get_default_chunk_encoding_v3(
42294229

42304230

42314231
def _get_default_chunk_encoding_v2(
4232-
np_dtype: np.dtype[Any],
4232+
dtype: DTypeWrapper[Any, Any],
42334233
) -> tuple[tuple[numcodecs.abc.Codec, ...] | None, numcodecs.abc.Codec | None]:
42344234
"""
42354235
Get the default chunk encoding for Zarr format 2 arrays, given a dtype
42364236
"""
42374237

4238-
compressor_dict = _default_compressor(np_dtype)
4239-
filter_dicts = _default_filters(np_dtype)
4238+
compressor_dict = _default_compressor(dtype)
4239+
filter_dicts = _default_filters(dtype)
42404240

42414241
compressor = None
42424242
if compressor_dict is not None:
@@ -4253,13 +4253,12 @@ def _parse_chunk_encoding_v2(
42534253
*,
42544254
compressor: CompressorsLike,
42554255
filters: FiltersLike,
4256-
dtype: np.dtype[Any],
4256+
dtype: DTypeWrapper[Any, Any],
42574257
) -> tuple[tuple[numcodecs.abc.Codec, ...] | None, numcodecs.abc.Codec | None]:
42584258
"""
42594259
Generate chunk encoding classes for Zarr format 2 arrays with optional defaults.
42604260
"""
42614261
default_filters, default_compressor = _get_default_chunk_encoding_v2(dtype)
4262-
42634262
_filters: tuple[numcodecs.abc.Codec, ...] | None
42644263
_compressor: numcodecs.abc.Codec | None
42654264

src/zarr/core/array_spec.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from dataclasses import dataclass, fields
44
from typing import TYPE_CHECKING, Any, Literal, Self, TypedDict, cast
55

6-
import numpy as np
7-
86
from zarr.core.common import (
97
MemoryOrder,
108
parse_bool,
@@ -13,10 +11,14 @@
1311
parse_shapelike,
1412
)
1513
from zarr.core.config import config as zarr_config
14+
from zarr.core.metadata.dtype import DTypeWrapper
15+
from zarr.registry import get_data_type_from_numpy
1616

1717
if TYPE_CHECKING:
1818
from typing import NotRequired
1919

20+
import numpy.typing as npt
21+
2022
from zarr.core.buffer import BufferPrototype
2123
from zarr.core.common import ChunkCoords
2224

@@ -90,21 +92,25 @@ def parse_array_config(data: ArrayConfigLike | None) -> ArrayConfig:
9092
@dataclass(frozen=True)
9193
class ArraySpec:
9294
shape: ChunkCoords
93-
dtype: np.dtype[Any]
95+
dtype: DTypeWrapper[Any, Any]
9496
fill_value: Any
9597
config: ArrayConfig
9698
prototype: BufferPrototype
9799

98100
def __init__(
99101
self,
100102
shape: ChunkCoords,
101-
dtype: np.dtype[Any],
103+
dtype: npt.DtypeLike | DTypeWrapper[Any, Any],
102104
fill_value: Any,
103105
config: ArrayConfig,
104106
prototype: BufferPrototype,
105107
) -> None:
106108
shape_parsed = parse_shapelike(shape)
107-
dtype_parsed = np.dtype(dtype)
109+
if not isinstance(dtype, DTypeWrapper):
110+
dtype_parsed = get_data_type_from_numpy(dtype)
111+
else:
112+
dtype_parsed = dtype
113+
108114
fill_value_parsed = parse_fill_value(fill_value)
109115

110116
object.__setattr__(self, "shape", shape_parsed)

src/zarr/core/buffer/cpu.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy.typing as npt
1111

1212
from zarr.core.buffer import core
13+
from zarr.core.metadata.dtype import DTypeWrapper
1314
from zarr.registry import (
1415
register_buffer,
1516
register_ndbuffer,
@@ -150,14 +151,18 @@ def create(
150151
cls,
151152
*,
152153
shape: Iterable[int],
153-
dtype: npt.DTypeLike,
154+
dtype: DTypeWrapper[Any, Any],
154155
order: Literal["C", "F"] = "C",
155156
fill_value: Any | None = None,
156157
) -> Self:
157158
if fill_value is None:
158159
return cls(np.zeros(shape=tuple(shape), dtype=dtype, order=order))
159160
else:
160-
return cls(np.full(shape=tuple(shape), fill_value=fill_value, dtype=dtype, order=order))
161+
return cls(
162+
np.full(
163+
shape=tuple(shape), fill_value=fill_value, dtype=dtype.unwrap(), order=order
164+
)
165+
)
161166

162167
@classmethod
163168
def from_numpy_array(cls, array_like: npt.ArrayLike) -> Self:

0 commit comments

Comments
 (0)