
Commit 4529eae

Merge branch 'v3' of https://github.com/zarr-developers/zarr-python into refactor/rename-v2-metadata-fields

2 parents 02832b9 + aa46b45

20 files changed (+692, -126 lines)


src/zarr/abc/store.py

Lines changed: 4 additions & 7 deletions
@@ -69,13 +69,10 @@ def __exit__(
     async def _open(self) -> None:
         if self._is_open:
             raise ValueError("store is already open")
-        if not await self.empty():
-            if self.mode.update or self.mode.readonly:
-                pass
-            elif self.mode.overwrite:
-                await self.clear()
-            else:
-                raise FileExistsError("Store already exists")
+        if self.mode.str == "w":
+            await self.clear()
+        elif self.mode.str == "w-" and not await self.empty():
+            raise FileExistsError("Store already exists")
         self._is_open = True

     async def _ensure_open(self) -> None:
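
The rewritten branch makes the open-mode semantics explicit: "w" always clears the store, "w-" only succeeds if the store is empty, and every other mode leaves existing contents alone. A minimal standalone restatement of that decision logic (a sketch mirroring the diff above, not the zarr API):

    def open_behavior(mode: str, store_is_empty: bool) -> str:
        """Standalone restatement of the new Store._open() branching (not zarr API)."""
        if mode == "w":
            return "clear existing contents"            # the real method awaits self.clear()
        elif mode == "w-" and not store_is_empty:
            raise FileExistsError("Store already exists")
        return "open without touching existing data"    # r, r+, a, or w- on an empty store

    assert open_behavior("w", store_is_empty=False) == "clear existing contents"
    assert open_behavior("w-", store_is_empty=True) == "open without touching existing data"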

src/zarr/codecs/__init__.py

Lines changed: 21 additions & 0 deletions
@@ -1,13 +1,20 @@
 from __future__ import annotations

+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    import numpy as np
+
 from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle
 from zarr.codecs.bytes import BytesCodec, Endian
 from zarr.codecs.crc32c_ import Crc32cCodec
 from zarr.codecs.gzip import GzipCodec
 from zarr.codecs.pipeline import BatchedCodecPipeline
 from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
 from zarr.codecs.transpose import TransposeCodec
+from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
 from zarr.codecs.zstd import ZstdCodec
+from zarr.core.metadata.v3 import DataType

 __all__ = [
     "BatchedCodecPipeline",

@@ -21,5 +28,19 @@
     "ShardingCodec",
     "ShardingCodecIndexLocation",
     "TransposeCodec",
+    "VLenUTF8Codec",
+    "VLenBytesCodec",
     "ZstdCodec",
 ]
+
+
+def _get_default_array_bytes_codec(
+    np_dtype: np.dtype[Any],
+) -> BytesCodec | VLenUTF8Codec | VLenBytesCodec:
+    dtype = DataType.from_numpy(np_dtype)
+    if dtype == DataType.string:
+        return VLenUTF8Codec()
+    elif dtype == DataType.bytes:
+        return VLenBytesCodec()
+    else:
+        return BytesCodec()
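
The new helper selects the array-to-bytes codec from the dtype: variable-length strings get vlen-utf8, variable-length bytes get vlen-bytes, and everything else keeps the fixed-size BytesCodec. A quick illustration (hedged: exactly which NumPy dtypes DataType.from_numpy maps to DataType.string / DataType.bytes is an assumption here):

    import numpy as np

    from zarr.codecs import _get_default_array_bytes_codec

    for dt in (np.dtype("float64"), np.dtype(str), np.dtype(bytes)):
        codec = _get_default_array_bytes_codec(dt)
        print(dt, "->", type(codec).__name__)
    # expected (assumption): float64 -> BytesCodec, <U0 -> VLenUTF8Codec, |S0 -> VLenBytesCodec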

src/zarr/codecs/vlen_utf8.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@ (new file)
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
from numcodecs.vlen import VLenBytes, VLenUTF8

from zarr.abc.codec import ArrayBytesCodec
from zarr.core.buffer import Buffer, NDBuffer
from zarr.core.common import JSON, parse_named_configuration
from zarr.core.strings import cast_to_string_dtype
from zarr.registry import register_codec

if TYPE_CHECKING:
    from typing import Self

    from zarr.core.array_spec import ArraySpec


# can use a global because there are no parameters
_vlen_utf8_codec = VLenUTF8()
_vlen_bytes_codec = VLenBytes()


@dataclass(frozen=True)
class VLenUTF8Codec(ArrayBytesCodec):
    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        _, configuration_parsed = parse_named_configuration(
            data, "vlen-utf8", require_configuration=False
        )
        configuration_parsed = configuration_parsed or {}
        return cls(**configuration_parsed)

    def to_dict(self) -> dict[str, JSON]:
        return {"name": "vlen-utf8", "configuration": {}}

    def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
        return self

    async def _decode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        assert isinstance(chunk_bytes, Buffer)

        raw_bytes = chunk_bytes.as_array_like()
        decoded = _vlen_utf8_codec.decode(raw_bytes)
        assert decoded.dtype == np.object_
        decoded.shape = chunk_spec.shape
        # coming out of the code, we know this is safe, so don't issue a warning
        as_string_dtype = cast_to_string_dtype(decoded, safe=True)
        return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype)

    async def _encode_single(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> Buffer | None:
        assert isinstance(chunk_array, NDBuffer)
        return chunk_spec.prototype.buffer.from_bytes(
            _vlen_utf8_codec.encode(chunk_array.as_numpy_array())
        )

    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
        # what is input_byte_length for an object dtype?
        raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs")


@dataclass(frozen=True)
class VLenBytesCodec(ArrayBytesCodec):
    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        _, configuration_parsed = parse_named_configuration(
            data, "vlen-bytes", require_configuration=False
        )
        configuration_parsed = configuration_parsed or {}
        return cls(**configuration_parsed)

    def to_dict(self) -> dict[str, JSON]:
        return {"name": "vlen-bytes", "configuration": {}}

    def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
        return self

    async def _decode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        assert isinstance(chunk_bytes, Buffer)

        raw_bytes = chunk_bytes.as_array_like()
        decoded = _vlen_bytes_codec.decode(raw_bytes)
        assert decoded.dtype == np.object_
        decoded.shape = chunk_spec.shape
        return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded)

    async def _encode_single(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> Buffer | None:
        assert isinstance(chunk_array, NDBuffer)
        return chunk_spec.prototype.buffer.from_bytes(
            _vlen_bytes_codec.encode(chunk_array.as_numpy_array())
        )

    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
        # what is input_byte_length for an object dtype?
        raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs")


register_codec("vlen-utf8", VLenUTF8Codec)
register_codec("vlen-bytes", VLenBytesCodec)

src/zarr/core/array.py

Lines changed: 6 additions & 2 deletions
@@ -11,7 +11,7 @@

 from zarr._compat import _deprecate_positional_args
 from zarr.abc.store import Store, set_or_delete
-from zarr.codecs import BytesCodec
+from zarr.codecs import _get_default_array_bytes_codec
 from zarr.codecs._v2 import V2Compressor, V2Filters
 from zarr.core.attributes import Attributes
 from zarr.core.buffer import (

@@ -480,7 +480,11 @@ async def _create_v3(
         await ensure_no_existing_node(store_path, zarr_format=3)

         shape = parse_shapelike(shape)
-        codecs = list(codecs) if codecs is not None else [BytesCodec()]
+        codecs = (
+            list(codecs)
+            if codecs is not None
+            else [_get_default_array_bytes_codec(np.dtype(dtype))]
+        )

         if chunk_key_encoding is None:
             chunk_key_encoding = ("default", "/")
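
With this change, creating a v3 array without an explicit codecs argument derives the array-to-bytes codec from the dtype instead of always using BytesCodec. A hedged end-to-end sketch (zarr.create, MemoryStore, and the metadata.codecs attribute are assumptions about the surrounding v3 API, not shown in this diff):

    import zarr
    from zarr.store import MemoryStore

    z = zarr.create(shape=(3,), dtype=str, store=MemoryStore(mode="w"), zarr_format=3)
    z[:] = ["a", "bb", "ccc"]
    print([c.to_dict()["name"] for c in z.metadata.codecs])  # expected: ['vlen-utf8']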

src/zarr/core/buffer/core.py

Lines changed: 3 additions & 3 deletions
@@ -313,8 +313,6 @@ class NDBuffer:
     """

     def __init__(self, array: NDArrayLike) -> None:
-        # assert array.ndim > 0
-        assert array.dtype != object
         self._data = array

     @classmethod

@@ -467,9 +465,11 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
             # Handle None fill_value for Zarr V2
             return False
         # use array_equal to obtain equal_nan=True functionality
+        # Since fill-value is a scalar, isn't there a faster path than allocating a new array for fill value
+        # every single time we have to write data?
         _data, other = np.broadcast_arrays(self._data, other)
         return np.array_equal(
-            self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "US" else False
+            self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "USTO" else False
         )

     def fill(self, value: Any) -> None:
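
The "US" -> "USTO" change extends the equal_nan opt-out to NumPy's variable-width string kind ("T") and to object arrays ("O"), which the vlen codecs now produce; NaN-aware comparison is only meaningful for dtypes that can actually hold NaN. A short illustration (the np.array_equal behaviour shown is as of recent NumPy releases):

    import numpy as np

    a = np.array(["a", "bb"], dtype=object)
    print(np.array_equal(a, a))  # True
    try:
        np.array_equal(a, a, equal_nan=True)  # np.isnan rejects object/string inputs
    except TypeError as exc:
        print("equal_nan unsupported for this dtype:", exc)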

src/zarr/core/config.py

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,8 @@ def reset(self) -> None:
                 "crc32c": "zarr.codecs.crc32c_.Crc32cCodec",
                 "sharding_indexed": "zarr.codecs.sharding.ShardingCodec",
                 "transpose": "zarr.codecs.transpose.TransposeCodec",
+                "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec",
+                "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec",
             },
             "buffer": "zarr.core.buffer.cpu.Buffer",
             "ndbuffer": "zarr.core.buffer.cpu.NDBuffer",

src/zarr/core/group.py

Lines changed: 45 additions & 1 deletion
@@ -4,7 +4,7 @@
 import json
 import logging
 from dataclasses import asdict, dataclass, field, fields, replace
-from typing import TYPE_CHECKING, Literal, cast, overload
+from typing import TYPE_CHECKING, Literal, TypeVar, cast, overload

 import numpy as np
 import numpy.typing as npt

@@ -44,6 +44,8 @@

 logger = logging.getLogger("zarr.group")

+DefaultT = TypeVar("DefaultT")
+

 def parse_zarr_format(data: Any) -> ZarrFormat:
     if data in (2, 3):

@@ -294,6 +296,28 @@ async def delitem(self, key: str) -> None:
         else:
             raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")

+    async def get(
+        self, key: str, default: DefaultT | None = None
+    ) -> AsyncArray | AsyncGroup | DefaultT | None:
+        """Obtain a group member, returning default if not found.
+
+        Parameters
+        ----------
+        key : string
+            Group member name.
+        default : object
+            Default value to return if key is not found (default: None).
+
+        Returns
+        -------
+        object
+            Group member (AsyncArray or AsyncGroup) or default if not found.
+        """
+        try:
+            return await self.getitem(key)
+        except KeyError:
+            return default
+
     async def _save_metadata(self, ensure_parents: bool = False) -> None:
         to_save = self.metadata.to_buffer_dict(default_buffer_prototype())
         awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]

@@ -856,6 +880,26 @@ def __getitem__(self, path: str) -> Array | Group:
         else:
             return Group(obj)

+    def get(self, path: str, default: DefaultT | None = None) -> Array | Group | DefaultT | None:
+        """Obtain a group member, returning default if not found.
+
+        Parameters
+        ----------
+        key : string
+            Group member name.
+        default : object
+            Default value to return if key is not found (default: None).
+
+        Returns
+        -------
+        object
+            Group member (Array or Group) or default if not found.
+        """
+        try:
+            return self[path]
+        except KeyError:
+            return default
+
     def __delitem__(self, key: str) -> None:
         self._sync(self._async_group.delitem(key))
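
The new get mirrors dict.get for group members, in both the async and sync APIs. A hedged usage sketch for the synchronous Group (zarr.group() and create_group are assumed from the surrounding API, not shown in this diff):

    import zarr

    grp = zarr.group()              # in-memory group
    grp.create_group("foo")

    print(grp.get("foo"))           # the "foo" subgroup
    print(grp.get("missing"))       # None
    print(grp.get("missing", 42))   # 42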
