Skip to content

Commit cae7055

Browse files
committed
more progress on typing; still not passing mypy
1 parent 3aeea1e commit cae7055

File tree

3 files changed

+98
-26
lines changed

3 files changed

+98
-26
lines changed

src/zarr/core/metadata/v3.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -335,30 +335,16 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
335335
# enum Literals can't be used in typing, so we have to restate all of the V3 dtypes as types
336336
# https://github.com/python/typing/issues/781
337337

338-
BOOL = np.bool_
339-
# BOOL_DTYPE = np.dtypes.BoolDType
340338
BOOL_DTYPE = Literal["bool"]
339+
BOOL = np.bool_
341340
INTEGER_DTYPE = Literal["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
342-
# INTEGER_DTYPE = (
343-
# np.dtypes.Int8DType
344-
# | np.dtypes.Int16DType
345-
# | np.dtypes.Int32DType
346-
# | np.dtypes.Int64DType
347-
# | np.dtypes.UInt8DType
348-
# | np.dtypes.UInt16DType
349-
# | np.dtypes.UInt32DType
350-
# | np.dtypes.UInt64DType
351-
# )
352341
INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64
353-
# FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType
354342
FLOAT_DTYPE = Literal["float16", "float32", "float64"]
355343
FLOAT = np.float16 | np.float32 | np.float64
356-
# COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType
357344
COMPLEX_DTYPE = Literal["complex64", "complex128"]
358345
COMPLEX = np.complex64 | np.complex128
359346
STRING_DTYPE = Literal["string"]
360347
STRING = np.str_
361-
# BYTES_DTYPE = np.dtypes.BytesDType
362348
BYTES_DTYPE = Literal["bytes"]
363349
BYTES = np.bytes_
364350

@@ -565,7 +551,7 @@ def to_numpy_shortname(self) -> str:
565551
}
566552
return data_type_to_numpy[self]
567553

568-
def to_numpy(self) -> np.dtype[np.generic]:
554+
def to_numpy(self) -> np.dtypes.StringDType | np.dtypes.ObjectDType | np.dtype[np.generic]:
569555
# note: it is not possible to round trip DataType <-> np.dtype
570556
# due to the fact that DataType.string and DataType.bytes both
571557
# generally return np.dtype("O") from this function, even though

src/zarr/strings.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,87 @@
1-
from typing import Any
1+
"""This module contains utilities for working with string arrays across
2+
different versions of Numpy.
3+
"""
4+
5+
from typing import Any, cast
26
from warnings import warn
37

48
import numpy as np
59

10+
# STRING_DTYPE is the in-memory datatype that will be used for V3 string arrays
11+
# when reading data back from Zarr.
12+
# Any valid string-like datatype should be fine for *setting* data.
13+
14+
STRING_DTYPE: np.dtypes.StringDType | np.dtypes.ObjectDType
15+
NUMPY_SUPPORTS_VLEN_STRING: bool
16+
17+
18+
def cast_array(
19+
data: np.ndarray[Any, np.dtype[Any]],
20+
) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]:
21+
raise NotImplementedError
22+
23+
624
try:
7-
STRING_DTYPE = np.dtype("T")
25+
# this new vlen string dtype was added in NumPy 2.0
26+
STRING_DTYPE = np.dtypes.StringDType()
827
NUMPY_SUPPORTS_VLEN_STRING = True
9-
except TypeError:
10-
STRING_DTYPE = np.dtype("object")
28+
29+
def cast_array(
30+
data: np.ndarray[Any, np.dtype[Any]],
31+
) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]:
32+
out = data.astype(STRING_DTYPE, copy=False)
33+
return cast(np.ndarray[Any, np.dtypes.StringDType], out)
34+
35+
except AttributeError:
36+
# if not available, we fall back on an object array of strings, as in Zarr < 3
37+
STRING_DTYPE = np.dtypes.ObjectDType()
1138
NUMPY_SUPPORTS_VLEN_STRING = False
1239

40+
def cast_array(
41+
data: np.ndarray[Any, np.dtype[Any]],
42+
) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]:
43+
out = data.astype(STRING_DTYPE, copy=False)
44+
return cast(np.ndarray[Any, np.dtypes.ObjectDType], out)
45+
1346

1447
def cast_to_string_dtype(
1548
data: np.ndarray[Any, np.dtype[Any]], safe: bool = False
16-
) -> np.ndarray[Any, np.dtype[Any]]:
49+
) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]:
50+
"""Take any data and attempt to cast to our preferred string dtype.
51+
52+
data : np.ndarray
53+
The data to cast
54+
55+
safe : bool
56+
If True, do not issue a warning if the data is cast from object to string dtype.
57+
58+
"""
1759
if np.issubdtype(data.dtype, np.str_):
18-
return data
60+
# legacy fixed-width string type (e.g. "<U10")
61+
return cast_array(data)
62+
# out = data.astype(STRING_DTYPE, copy=False)
63+
# return cast(np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType], out)
64+
if NUMPY_SUPPORTS_VLEN_STRING:
65+
if np.issubdtype(data.dtype, STRING_DTYPE):
66+
# already a valid variable-length string dtype
67+
return cast_array(data)
1968
if np.issubdtype(data.dtype, np.object_):
69+
# object arrays require more careful handling
2070
if NUMPY_SUPPORTS_VLEN_STRING:
2171
try:
2272
# cast to variable-length string dtype, fail if object contains non-string data
2373
# mypy says "error: Unexpected keyword argument "coerce" for "StringDType" [call-arg]"
24-
return data.astype(np.dtypes.StringDType(coerce=False), copy=False) # type: ignore[call-arg]
74+
# also: Value of type variable "_ScalarType" of "astype" of "ndarray" cannot be "str" [type-var]
75+
out = data.astype(np.dtypes.StringDType(coerce=False), copy=False) # type: ignore[call-arg,type-var]
76+
return cast_array(out)
2577
except ValueError as e:
2678
raise ValueError("Cannot cast object dtype to string dtype") from e
2779
else:
28-
out = data.astype(np.str_)
2980
if not safe:
3081
warn(
31-
f"Casted object dtype to string dtype {out.dtype}. To avoid this warning, "
82+
"Treating object array as valid string array. To avoid this warning, "
3283
"cast the data to a string dtype before passing to Zarr or upgrade to NumPy >= 2.",
3384
stacklevel=2,
3485
)
35-
return out
86+
return cast_array(data)
3687
raise ValueError(f"Cannot cast dtype {data.dtype} to string dtype")

tests/test_strings.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Tests for the strings module."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING, STRING_DTYPE, cast_to_string_dtype
7+
8+
9+
def test_string_defaults() -> None:
10+
if NUMPY_SUPPORTS_VLEN_STRING:
11+
assert STRING_DTYPE == np.dtypes.StringDType()
12+
else:
13+
assert STRING_DTYPE == np.dtypes.ObjectDType()
14+
15+
16+
def test_cast_to_string_dtype() -> None:
17+
d1 = np.array(["a", "b", "c"])
18+
assert d1.dtype == np.dtype("<U1")
19+
d1s = cast_to_string_dtype(d1)
20+
assert d1s.dtype == STRING_DTYPE
21+
22+
with pytest.raises(ValueError, match="Cannot cast dtype |S1"):
23+
cast_to_string_dtype(d1.astype("|S1"))
24+
25+
if NUMPY_SUPPORTS_VLEN_STRING:
26+
assert cast_to_string_dtype(d1.astype("T")).dtype == STRING_DTYPE
27+
assert cast_to_string_dtype(d1.astype("O")).dtype == STRING_DTYPE
28+
with pytest.raises(ValueError, match="Cannot cast object dtype to string dtype"):
29+
cast_to_string_dtype(np.array([1, "b", "c"], dtype="O"))
30+
else:
31+
with pytest.warns():
32+
assert cast_to_string_dtype(d1.astype("O")).dtype == STRING_DTYPE
33+
with pytest.warns():
34+
assert cast_to_string_dtype(np.array([1, "b", "c"], dtype="O")).dtype == STRING_DTYPE
35+
assert cast_to_string_dtype(d1.astype("O"), safe=True).dtype == STRING_DTYPE

0 commit comments

Comments
 (0)