Skip to content

Commit 7c58f7a

Browse files
committed
rename fixed-length string dtypes, and be strict about the numpy object dtype (i.e., refuse to match it)
1 parent 0fc653f commit 7c58f7a

File tree

11 files changed

+102
-65
lines changed

11 files changed

+102
-65
lines changed

src/zarr/api/asynchronous.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
_warn_order_kwarg,
3232
_warn_write_empty_chunks_kwarg,
3333
)
34-
from zarr.core.dtype import get_data_type_from_native_dtype
34+
from zarr.core.dtype import ZDTypeLike, get_data_type_from_native_dtype, parse_data_type
3535
from zarr.core.group import (
3636
AsyncGroup,
3737
ConsolidatedMetadata,
@@ -843,7 +843,7 @@ async def create(
843843
shape: ChunkCoords | int,
844844
*, # Note: this is a change from v2
845845
chunks: ChunkCoords | int | None = None, # TODO: v2 allowed chunks=True
846-
dtype: npt.DTypeLike | None = None,
846+
dtype: ZDTypeLike | None = None,
847847
compressor: CompressorLike = "auto",
848848
fill_value: Any | None = 0, # TODO: need type
849849
order: MemoryOrder | None = None,
@@ -990,11 +990,11 @@ async def create(
990990
_handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format)
991991
or _default_zarr_format()
992992
)
993-
dtype_wrapped = get_data_type_from_native_dtype(dtype)
993+
zdtype = parse_data_type(dtype, zarr_format=zarr_format)
994994
if zarr_format == 2:
995995
if chunks is None:
996996
chunks = shape
997-
default_filters, default_compressor = _get_default_chunk_encoding_v2(dtype_wrapped)
997+
default_filters, default_compressor = _get_default_chunk_encoding_v2(zdtype)
998998
if not filters:
999999
filters = default_filters # type: ignore[assignment]
10001000
if compressor == "auto":
@@ -1056,7 +1056,7 @@ async def create(
10561056
store_path,
10571057
shape=shape,
10581058
chunks=chunks,
1059-
dtype=dtype_wrapped,
1059+
dtype=zdtype,
10601060
compressor=compressor,
10611061
fill_value=fill_value,
10621062
overwrite=overwrite,

src/zarr/api/synchronous.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def create(
599599
shape: ChunkCoords | int,
600600
*, # Note: this is a change from v2
601601
chunks: ChunkCoords | int | bool | None = None,
602-
dtype: npt.DTypeLike | None = None,
602+
dtype: ZDTypeLike | None = None,
603603
compressor: CompressorLike = "auto",
604604
fill_value: Any | None = None, # TODO: need type
605605
order: MemoryOrder | None = None,

src/zarr/core/dtype/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from zarr.core.dtype.npy.float import Float16, Float32, Float64
99
from zarr.core.dtype.npy.int import Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64
1010
from zarr.core.dtype.npy.sized import (
11-
FixedLengthAscii,
11+
FixedLengthASCII,
1212
FixedLengthBytes,
13-
FixedLengthUnicode,
13+
FixedLengthUTF32,
1414
Structured,
1515
)
1616
from zarr.core.dtype.npy.time import DateTime64, TimeDelta64
@@ -36,9 +36,9 @@
3636
"DataTypeRegistry",
3737
"DataTypeValidationError",
3838
"DateTime64",
39-
"FixedLengthAscii",
39+
"FixedLengthASCII",
4040
"FixedLengthBytes",
41-
"FixedLengthUnicode",
41+
"FixedLengthUTF32",
4242
"Float16",
4343
"Float32",
4444
"Float64",
@@ -72,8 +72,8 @@
7272
ComplexFloatDType = Complex64 | Complex128
7373
COMPLEX_FLOAT_DTYPE: Final = Complex64, Complex128
7474

75-
StringDType = FixedLengthUnicode | VariableLengthString | FixedLengthAscii
76-
STRING_DTYPE: Final = FixedLengthUnicode, VariableLengthString, FixedLengthAscii
75+
StringDType = FixedLengthUTF32 | VariableLengthString | FixedLengthASCII
76+
STRING_DTYPE: Final = FixedLengthUTF32, VariableLengthString, FixedLengthASCII
7777

7878
TimeDType = DateTime64 | TimeDelta64
7979
TIME_DTYPE: Final = DateTime64, TimeDelta64

src/zarr/core/dtype/npy/sized.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
@dataclass(frozen=True, kw_only=True)
23-
class FixedLengthAscii(ZDType[np.dtypes.BytesDType[int], np.bytes_], HasLength, HasItemSize):
23+
class FixedLengthASCII(ZDType[np.dtypes.BytesDType[int], np.bytes_], HasLength, HasItemSize):
2424
dtype_cls = np.dtypes.BytesDType
2525
_zarr_v3_name = "numpy.fixed_length_ascii"
2626

@@ -185,12 +185,12 @@ def item_size(self) -> int:
185185

186186

187187
@dataclass(frozen=True, kw_only=True)
188-
class FixedLengthUnicode(
188+
class FixedLengthUTF32(
189189
ZDType[np.dtypes.StrDType[int], np.str_], HasEndianness, HasLength, HasItemSize
190190
):
191191
dtype_cls = np.dtypes.StrDType
192-
_zarr_v3_name = "numpy.fixed_length_ucs4"
193-
code_point_bytes: ClassVar[int] = 4 # UCS4 is 4 bytes per code point
192+
_zarr_v3_name = "numpy.fixed_length_utf32"
193+
code_point_bytes: ClassVar[int] = 4 # UTF-32 uses 4 bytes per code point
194194

195195
@classmethod
196196
def _from_dtype_unsafe(cls, dtype: TBaseDType) -> Self:

src/zarr/core/dtype/registry.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass, field
44
from typing import TYPE_CHECKING, Self
55

6+
import numpy as np
7+
68
from zarr.core.dtype.common import DataTypeValidationError
79

810
if TYPE_CHECKING:
@@ -38,6 +40,17 @@ def get(self, key: str) -> type[ZDType[TBaseDType, TBaseScalar]]:
3840

3941
def match_dtype(self, dtype: TBaseDType) -> ZDType[TBaseDType, TBaseScalar]:
4042
self.lazy_load()
43+
if dtype == np.dtype("O"):
44+
msg = (
45+
"Data type resolution failed. "
46+
'Attempted to resolve a zarr data type from a numpy "Object" data type, which is '
47+
'ambiguous, as multiple zarr data types can be represented by the numpy "Object" '
48+
"data type. "
49+
"In this case you should construct your array by providing a specific Zarr data "
50+
'type. For a list of Zarr data types that are compatible with the numpy "Object" '
51+
"data type, see xxxxxxxxxxx"
52+
)
53+
raise ValueError(msg)
4154
for val in self.contents.values():
4255
try:
4356
return val.from_dtype(dtype)

src/zarr/core/metadata/dtype.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike
4040
from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike
4141
from zarr.core.common import ChunkCoords, MemoryOrder, ShapeLike, ZarrFormat
42+
from zarr.core.dtype.wrapper import ZDType
4243

4344

4445
async def parse_store(
@@ -417,3 +418,12 @@ def meta_from_array(
417418
chunk_key_encoding=chunk_key_encoding,
418419
dimension_names=dimension_names,
419420
)
421+
422+
423+
def skip_object_dtype(dtype: ZDType[Any, Any]) -> None:
424+
if dtype.dtype_cls is type(np.dtype("O")):
425+
msg = (
426+
f"{dtype} uses the numpy object data type, which is not a valid target for data "
427+
"type resolution"
428+
)
429+
pytest.skip(msg)

tests/test_array.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import zarr.api.asynchronous
2020
import zarr.api.synchronous as sync_api
21+
from tests.conftest import skip_object_dtype
2122
from zarr import Array, AsyncArray, Group
2223
from zarr.abc.store import Store
2324
from zarr.codecs import (
@@ -43,8 +44,8 @@
4344
from zarr.core.dtype import get_data_type_from_native_dtype
4445
from zarr.core.dtype.common import Endianness
4546
from zarr.core.dtype.npy.common import endianness_from_numpy_str
46-
from zarr.core.dtype.npy.float import Float64
47-
from zarr.core.dtype.npy.int import Int16
47+
from zarr.core.dtype.npy.float import Float32, Float64
48+
from zarr.core.dtype.npy.int import Int16, UInt8
4849
from zarr.core.dtype.npy.sized import (
4950
Structured,
5051
)
@@ -1009,9 +1010,11 @@ def test_dtype_forms(dtype: ZDType[Any, Any], store: Store, zarr_format: ZarrFor
10091010
"""
10101011
Test that the same array is produced from a ZDType instance, a numpy dtype, or a numpy string
10111012
"""
1013+
skip_object_dtype(dtype)
10121014
a = zarr.create_array(
10131015
store, name="a", shape=(5,), chunks=(5,), dtype=dtype, zarr_format=zarr_format
10141016
)
1017+
10151018
b = zarr.create_array(
10161019
store,
10171020
name="b",
@@ -1054,12 +1057,13 @@ def test_dtype_roundtrip(
10541057
"""
10551058
Test that creating an array, then opening it, gets the same array.
10561059
"""
1060+
skip_object_dtype(dtype)
10571061
a = zarr.create_array(store, shape=(5,), chunks=(5,), dtype=dtype, zarr_format=zarr_format)
10581062
b = zarr.open_array(store)
10591063
assert a.dtype == b.dtype
10601064

10611065
@staticmethod
1062-
@pytest.mark.parametrize("dtype", ["uint8", "float32", "str", "U3", "S4", "V1"])
1066+
@pytest.mark.parametrize("dtype", ["uint8", "float32", "U3", "S4", "V1"])
10631067
@pytest.mark.parametrize(
10641068
"compressors",
10651069
[
@@ -1244,7 +1248,7 @@ async def test_invalid_v3_arguments(
12441248
zarr.create(store=store, dtype="uint8", shape=(10,), zarr_format=3, **kwargs)
12451249

12461250
@staticmethod
1247-
@pytest.mark.parametrize("dtype", ["uint8", "float32", "str"])
1251+
@pytest.mark.parametrize("dtype", ["uint8", "float32"])
12481252
@pytest.mark.parametrize(
12491253
"compressors",
12501254
[
@@ -1284,17 +1288,17 @@ async def test_v2_chunk_encoding(
12841288
assert arr.filters == filters_expected
12851289

12861290
@staticmethod
1287-
@pytest.mark.parametrize("dtype_str", ["uint8", "float32", "str"])
1291+
@pytest.mark.parametrize("dtype", [UInt8(), Float32(), VariableLengthString()])
12881292
async def test_default_filters_compressors(
1289-
store: MemoryStore, dtype_str: str, zarr_format: ZarrFormat
1293+
store: MemoryStore, dtype: UInt8 | Float32 | VariableLengthString, zarr_format: ZarrFormat
12901294
) -> None:
12911295
"""
12921296
Test that the default ``filters`` and ``compressors`` are used when ``create_array`` is invoked with ``filters`` and ``compressors`` unspecified.
12931297
"""
1294-
zdtype = get_data_type_from_native_dtype(dtype_str)
1298+
12951299
arr = await create_array(
12961300
store=store,
1297-
dtype=dtype_str,
1301+
dtype=dtype,
12981302
shape=(10,),
12991303
zarr_format=zarr_format,
13001304
)
@@ -1306,14 +1310,14 @@ async def test_default_filters_compressors(
13061310
compressors=sig.parameters["compressors"].default,
13071311
filters=sig.parameters["filters"].default,
13081312
serializer=sig.parameters["serializer"].default,
1309-
dtype=zdtype,
1313+
dtype=dtype,
13101314
)
13111315

13121316
elif zarr_format == 2:
13131317
default_filters, default_compressors = _parse_chunk_encoding_v2(
13141318
compressor=sig.parameters["compressors"].default,
13151319
filters=sig.parameters["filters"].default,
1316-
dtype=zdtype,
1320+
dtype=dtype,
13171321
)
13181322
if default_filters is None:
13191323
expected_filters = ()

tests/test_dtype/test_npy/test_sized.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from zarr.core.dtype.npy.float import Float16, Float64
99
from zarr.core.dtype.npy.int import Int32, Int64
1010
from zarr.core.dtype.npy.sized import (
11-
FixedLengthAscii,
11+
FixedLengthASCII,
1212
FixedLengthBytes,
13-
FixedLengthUnicode,
13+
FixedLengthUTF32,
1414
Structured,
1515
)
1616

1717

1818
class TestFixedLengthAscii(_TestZDType):
19-
test_cls = FixedLengthAscii
19+
test_cls = FixedLengthASCII
2020
valid_dtype = (np.dtype("|S10"), np.dtype("|S4"))
2121
invalid_dtype = (
2222
np.dtype(np.int8),
@@ -36,24 +36,24 @@ class TestFixedLengthAscii(_TestZDType):
3636
)
3737

3838
scalar_v2_params = (
39-
(FixedLengthAscii(length=0), ""),
40-
(FixedLengthAscii(length=2), "YWI="),
41-
(FixedLengthAscii(length=4), "YWJjZA=="),
39+
(FixedLengthASCII(length=0), ""),
40+
(FixedLengthASCII(length=2), "YWI="),
41+
(FixedLengthASCII(length=4), "YWJjZA=="),
4242
)
4343
scalar_v3_params = (
44-
(FixedLengthAscii(length=0), ""),
45-
(FixedLengthAscii(length=2), "YWI="),
46-
(FixedLengthAscii(length=4), "YWJjZA=="),
44+
(FixedLengthASCII(length=0), ""),
45+
(FixedLengthASCII(length=2), "YWI="),
46+
(FixedLengthASCII(length=4), "YWJjZA=="),
4747
)
4848
cast_value_params = (
49-
(FixedLengthAscii(length=0), "", np.bytes_("")),
50-
(FixedLengthAscii(length=2), "ab", np.bytes_("ab")),
51-
(FixedLengthAscii(length=4), "abcd", np.bytes_("abcd")),
49+
(FixedLengthASCII(length=0), "", np.bytes_("")),
50+
(FixedLengthASCII(length=2), "ab", np.bytes_("ab")),
51+
(FixedLengthASCII(length=4), "abcd", np.bytes_("abcd")),
5252
)
5353
item_size_params = (
54-
FixedLengthAscii(length=0),
55-
FixedLengthAscii(length=4),
56-
FixedLengthAscii(length=10),
54+
FixedLengthASCII(length=0),
55+
FixedLengthASCII(length=4),
56+
FixedLengthASCII(length=10),
5757
)
5858

5959

@@ -103,42 +103,42 @@ class TestFixedLengthBytes(_TestZDType):
103103
)
104104

105105

106-
class TestFixedLengthUnicode(_TestZDType):
107-
test_cls = FixedLengthUnicode
106+
class TestFixedLengthUTF32(_TestZDType):
107+
test_cls = FixedLengthUTF32
108108
valid_dtype = (np.dtype(">U10"), np.dtype("<U10"))
109109
invalid_dtype = (
110110
np.dtype(np.int8),
111111
np.dtype(np.float64),
112112
np.dtype("|S10"),
113113
)
114114
valid_json_v2 = (">U10", "<U10")
115-
valid_json_v3 = ({"name": "numpy.fixed_length_ucs4", "configuration": {"length_bytes": 320}},)
115+
valid_json_v3 = ({"name": "numpy.fixed_length_utf32", "configuration": {"length_bytes": 320}},)
116116
invalid_json_v2 = (
117117
"|U",
118118
"|S10",
119119
"|f8",
120120
)
121121
invalid_json_v3 = (
122-
{"name": "numpy.fixed_length_ucs4", "configuration": {"length_bits": 0}},
123-
{"name": "numpy.fixed_length_ucs4", "configuration": {"length_bits": "invalid"}},
122+
{"name": "numpy.fixed_length_utf32", "configuration": {"length_bits": 0}},
123+
{"name": "numpy.fixed_length_utf32", "configuration": {"length_bits": "invalid"}},
124124
)
125125

126-
scalar_v2_params = ((FixedLengthUnicode(length=0), ""), (FixedLengthUnicode(length=2), "hi"))
126+
scalar_v2_params = ((FixedLengthUTF32(length=0), ""), (FixedLengthUTF32(length=2), "hi"))
127127
scalar_v3_params = (
128-
(FixedLengthUnicode(length=0), ""),
129-
(FixedLengthUnicode(length=2), "hi"),
130-
(FixedLengthUnicode(length=4), "hihi"),
128+
(FixedLengthUTF32(length=0), ""),
129+
(FixedLengthUTF32(length=2), "hi"),
130+
(FixedLengthUTF32(length=4), "hihi"),
131131
)
132132

133133
cast_value_params = (
134-
(FixedLengthUnicode(length=0), "", np.str_("")),
135-
(FixedLengthUnicode(length=2), "hi", np.str_("hi")),
136-
(FixedLengthUnicode(length=4), "hihi", np.str_("hihi")),
134+
(FixedLengthUTF32(length=0), "", np.str_("")),
135+
(FixedLengthUTF32(length=2), "hi", np.str_("hi")),
136+
(FixedLengthUTF32(length=4), "hihi", np.str_("hihi")),
137137
)
138138
item_size_params = (
139-
FixedLengthUnicode(length=0),
140-
FixedLengthUnicode(length=4),
141-
FixedLengthUnicode(length=10),
139+
FixedLengthUTF32(length=0),
140+
FixedLengthUTF32(length=4),
141+
FixedLengthUTF32(length=10),
142142
)
143143

144144

@@ -180,7 +180,7 @@ class TestStructured(_TestZDType):
180180
),
181181
(
182182
"field2",
183-
{"name": "numpy.fixed_length_ucs4", "configuration": {"length_bytes": 32}},
183+
{"name": "numpy.fixed_length_utf32", "configuration": {"length_bytes": 32}},
184184
),
185185
]
186186
},

0 commit comments

Comments
 (0)