Skip to content

Commit 1ae5e63

Browse files
committed
make strings work
1 parent 507161a commit 1ae5e63

File tree

4 files changed

+24
-11
lines changed

4 files changed

+24
-11
lines changed

src/zarr/codecs/legacy_vlen.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from dataclasses import dataclass
44
from typing import TYPE_CHECKING
55

6+
import numpy as np
67
from numcodecs.vlen import VLenUTF8
78

89
from zarr.abc.codec import ArrayBytesCodec
910
from zarr.core.buffer import Buffer, NDBuffer
1011
from zarr.core.common import JSON, parse_named_configuration
1112
from zarr.registry import register_codec
13+
from zarr.strings import cast_to_string_dtype
1214

1315
if TYPE_CHECKING:
1416
from typing import Self
@@ -45,8 +47,11 @@ async def _decode_single(
4547

4648
raw_bytes = chunk_bytes.as_array_like()
4749
decoded = vlen_utf8_codec.decode(raw_bytes)
50+
assert decoded.dtype == np.object_
4851
decoded.shape = chunk_spec.shape
49-
return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded)
52+
# coming out of the code, we know this is safe, so don't issue a warning
53+
as_string_dtype = cast_to_string_dtype(decoded, safe=True)
54+
return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype)
5055

5156
async def _encode_single(
5257
self,

src/zarr/core/buffer/core.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,6 @@ class NDBuffer:
313313
"""
314314

315315
def __init__(self, array: NDArrayLike) -> None:
316-
# assert array.ndim > 0
317-
318-
# Commented this out because string arrays have dtype object
319-
# TODO: decide how to handle strings (e.g. numpy 2.0 StringDtype)
320316
# assert array.dtype != object
321317
self._data = array
322318

@@ -470,9 +466,12 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
470466
# Handle None fill_value for Zarr V2
471467
return False
472468
# use array_equal to obtain equal_nan=True functionality
469+
# Note from Ryan: doesn't this lead to a huge amount of unnecessary memory allocation on every single chunk?
470+
# Since fill-value is a scalar, isn't there a faster path than allocating a new array for fill value
471+
# every single time we have to write data?
473472
_data, other = np.broadcast_arrays(self._data, other)
474473
return np.array_equal(
475-
self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "UST" else False
474+
self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "USTO" else False
476475
)
477476

478477
def fill(self, value: Any) -> None:

src/zarr/core/metadata/v3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from zarr.core.config import config
3030
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
3131
from zarr.registry import get_codec_class
32+
from zarr.strings import STRING_DTYPE
3233

3334

3435
def parse_zarr_format(data: object) -> Literal[3]:
@@ -312,7 +313,6 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
312313
FLOAT = np.float16 | np.float32 | np.float64
313314
COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType
314315
COMPLEX = np.complex64 | np.complex128
315-
STRING = np.str_
316316

317317

318318
@overload
@@ -496,7 +496,7 @@ def to_numpy_shortname(self) -> str:
496496

497497
def to_numpy_dtype(self) -> np.dtype[Any]:
498498
if self == DataType.string:
499-
return np.dtypes.StringDType()
499+
return STRING_DTYPE
500500
else:
501501
return np.dtype(self.to_numpy_shortname())
502502

tests/v3/test_codecs/test_vlen.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,19 @@
88
from zarr.codecs import VLenUTF8Codec
99
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType
1010
from zarr.store.common import StorePath
11+
from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING
12+
13+
numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType]
14+
expected_zarr_string_dtype: np.dtype[Any]
15+
if NUMPY_SUPPORTS_VLEN_STRING:
16+
numpy_str_dtypes.append(np.dtypes.StringDType)
17+
expected_zarr_string_dtype = np.dtypes.StringDType()
18+
else:
19+
expected_zarr_string_dtype = np.dtype("O")
1120

1221

1322
@pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"])
14-
@pytest.mark.parametrize("dtype", [None, np.dtypes.StrDType])
23+
@pytest.mark.parametrize("dtype", numpy_str_dtypes)
1524
async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None:
1625
strings = ["hello", "world", "this", "is", "a", "test"]
1726
data = np.array(strings).reshape((2, 3))
@@ -32,11 +41,11 @@ async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None:
3241
a[:, :] = data
3342
assert np.array_equal(data, a[:, :])
3443
assert a.metadata.data_type == DataType.string
35-
assert a.dtype == np.dtypes.StringDType()
44+
assert a.dtype == expected_zarr_string_dtype
3645

3746
# test round trip
3847
b = Array.open(sp)
3948
assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy
4049
assert np.array_equal(data, b[:, :])
4150
assert b.metadata.data_type == DataType.string
42-
assert b.dtype == np.dtypes.StringDType()
51+
assert a.dtype == expected_zarr_string_dtype

0 commit comments

Comments
 (0)