Skip to content

Commit 24094d6

Browse files
authored
Merge pull request #77 from d-v-b/fill-value-fix
- Ensure that the JSON form of an array's fill value is used for the `fill_value` fields for v2 and v3 `ArraySpec` classes. - Add the parametric equality method `like` to v3 models - Refactor `to_zarr` for v2 and v3 models to use identical logic. Note that this is a breaking change, as `to_zarr` for an `ArraySpec` longer takes `**kwargs`. In its place, the `config` keyword argument has been added, that specifies the runtime configuration of the array. - Ensure that `str` is allowed in the union of possible `fill_value` types. This supports string dtypes - Run tests against zarr-python 3.0.10 and 3.1.0
2 parents 815f5b7 + 016b386 commit 24094d6

File tree

9 files changed

+388
-113
lines changed

9 files changed

+388
-113
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ jobs:
3838
pip install hatch
3939
- name: Run Tests
4040
run: |
41+
hatch run test.py${{ matrix.python-version }}-${{ matrix.zarr-version }}:list-env
4142
hatch run test.py${{ matrix.python-version }}-${{ matrix.zarr-version }}:test-cov
4243
- name: Upload coverage
4344
uses: codecov/codecov-action@v5

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ print(spec.model_dump())
3939
'order': 'C',
4040
'filters': None,
4141
'dimension_separator': '.',
42-
'compressor': {'id': 'zstd', 'level': 0, 'checksum': False},
42+
'compressor': {'id': 'zstd', 'level': 0},
4343
}
4444
},
4545
}

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,18 @@ build.hooks.vcs.version-file = "src/pydantic_zarr/_version.py"
4444

4545
[tool.hatch.envs.test]
4646
features = ["test"]
47+
dependencies = [
48+
"zarr~={matrix:zarr}",
49+
]
4750

4851
[tool.hatch.envs.test.scripts]
4952
test = "pytest tests/test_pydantic_zarr/"
5053
test-cov = "pytest --cov-config=pyproject.toml --cov=pkg --cov-report html --cov=src tests/test_pydantic_zarr"
54+
list-env = "pip list"
5155

5256
[[tool.hatch.envs.test.matrix]]
5357
python = ["3.11", "3.12", "3.13"]
54-
zarr_python = ["3.0.10", "3.1.0"]
58+
zarr = ["3.0.10", "3.1.0"]
5559

5660
[tool.hatch.envs.docs]
5761
features = ['docs']

src/pydantic_zarr/core.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,38 @@
11
from __future__ import annotations
22

3-
from collections.abc import Mapping
3+
from collections.abc import Mapping, Sequence
44
from typing import (
55
TYPE_CHECKING,
66
Any,
77
Literal,
88
TypeAlias,
9+
overload,
910
)
1011

1112
import numpy as np
1213
import numpy.typing as npt
13-
import zarr
1414
from pydantic import BaseModel, ConfigDict
15+
from zarr.core.sync import sync
1516
from zarr.core.sync_group import get_node
17+
from zarr.storage._common import make_store_path
1618

1719
if TYPE_CHECKING:
18-
from zarr.abc.store import Store
20+
import zarr
21+
from zarr.storage._common import StoreLike
1922

2023
IncEx: TypeAlias = set[int] | set[str] | dict[int, Any] | dict[str, Any] | None
2124

2225
AccessMode: TypeAlias = Literal["w", "w+", "r", "a"]
2326

2427

28+
@overload
29+
def tuplify_json(obj: Mapping) -> Mapping: ...
30+
31+
32+
@overload
33+
def tuplify_json(obj: list) -> tuple: ...
34+
35+
2536
def tuplify_json(obj: object) -> object:
2637
"""
2738
Recursively converts lists within a Python object to tuples.
@@ -38,21 +49,38 @@ class StrictBase(BaseModel):
3849
model_config = ConfigDict(frozen=True, extra="forbid")
3950

4051

41-
def stringify_dtype(value: npt.DTypeLike) -> str:
52+
def parse_dtype_v2(value: npt.DTypeLike) -> str | list[tuple[Any, ...]]:
4253
"""
43-
Convert a `numpy.dtype` object into a `str`.
54+
Convert the input to a NumPy dtype and either return the ``str`` attribute of that
55+
object or, if the dtype is a structured dtype, return the fields of that dtype as a list
56+
of tuples.
4457
4558
Parameters
4659
----------
47-
value : `npt.DTypeLike`
48-
Some object that can be coerced to a numpy dtype
60+
value : npt.DTypeLike
61+
A value that can be converted to a NumPy dtype.
4962
5063
Returns
5164
-------
5265
53-
A numpy dtype string representation of `value`.
66+
A Zarr V2-compatible encoding of the dtype.
67+
68+
References
69+
----------
70+
See the [Zarr V2 specification](https://zarr-specs.readthedocs.io/en/latest/v2/v2.0.html#data-type-encoding)
71+
for more details on this encoding of data types.
5472
"""
55-
return np.dtype(value).str
73+
# Assume that a non-string sequence represents a the Zarr V2 JSON form of a structured dtype.
74+
if isinstance(value, Sequence) and not isinstance(value, str):
75+
return [tuple(v) for v in value]
76+
else:
77+
np_dtype = np.dtype(value)
78+
if np_dtype.fields is not None:
79+
# This is a structured dtype, which must be converted to a list of tuples. Note that
80+
# this function recurses, because a structured dtype is parametrized by other dtypes.
81+
return [(k, parse_dtype_v2(v[0])) for k, v in np_dtype.fields.items()]
82+
else:
83+
return np_dtype.str
5684

5785

5886
def ensure_member_name(data: Any) -> str:
@@ -92,15 +120,16 @@ def model_like(a: BaseModel, b: BaseModel, exclude: IncEx = None, include: IncEx
92120

93121
# TODO: expose contains_array and contains_group as public functions in zarr-python
94122
# and replace these custom implementations
95-
def contains_array(store: Store, path: str) -> bool:
96-
try:
97-
return isinstance(get_node(store, path, zarr_format=2), zarr.Array)
98-
except FileNotFoundError:
99-
return False
100-
101-
102-
def contains_group(store: Store, path: str) -> bool:
123+
def maybe_node(
124+
store: StoreLike, path: str, *, zarr_format: Literal[2, 3]
125+
) -> zarr.Array | zarr.Group | None:
126+
"""
127+
Return the array or group found at the store / path, if an array or group exists there.
128+
Otherwise return None.
129+
"""
130+
# convert the storelike store argument to a Zarr store
131+
spath = sync(make_store_path(store, path=path))
103132
try:
104-
return isinstance(get_node(store, path, zarr_format=2), zarr.Group)
133+
return get_node(spath.store, spath.path, zarr_format=zarr_format)
105134
except FileNotFoundError:
106-
return False
135+
return None

src/pydantic_zarr/v2.py

Lines changed: 65 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import json
34
import math
45
from collections.abc import Mapping
6+
from importlib.metadata import version
57
from typing import (
68
TYPE_CHECKING,
79
Annotated,
@@ -17,28 +19,31 @@
1719
overload,
1820
)
1921

20-
import numcodecs
2122
import numpy as np
2223
import numpy.typing as npt
2324
import zarr
2425
from numcodecs.abc import Codec
26+
from packaging.version import Version
2527
from pydantic import AfterValidator, BaseModel, field_validator, model_validator
2628
from pydantic.functional_validators import BeforeValidator
29+
from zarr.core.array import Array, AsyncArray
2730
from zarr.core.metadata import ArrayV2Metadata
31+
from zarr.core.sync import sync
2832
from zarr.errors import ContainsArrayError, ContainsGroupError
33+
from zarr.storage._common import make_store_path
2934

3035
from pydantic_zarr.core import (
3136
IncEx,
3237
StrictBase,
33-
contains_array,
34-
contains_group,
3538
ensure_key_no_path,
39+
maybe_node,
3640
model_like,
37-
stringify_dtype,
41+
parse_dtype_v2,
3842
)
3943

4044
if TYPE_CHECKING:
4145
from zarr.abc.store import Store
46+
from zarr.core.array_spec import ArrayConfigParams
4247

4348
TBaseAttr: TypeAlias = Mapping[str, object] | BaseModel
4449
TBaseItem: TypeAlias = Union["GroupSpec", "ArraySpec"]
@@ -49,7 +54,18 @@
4954
TAttr = TypeVar("TAttr", bound=TBaseAttr)
5055
TItem = TypeVar("TItem", bound=TBaseItem)
5156

52-
DtypeStr = Annotated[str, BeforeValidator(stringify_dtype)]
57+
DtypeStr = Annotated[str, BeforeValidator(parse_dtype_v2)]
58+
59+
BoolFillValue = bool
60+
IntFillValue = int
61+
# todo: introduce a type that represents hexadecimal representations of floats
62+
FloatFillValue = Literal["Infinity", "-Infinity", "NaN"] | float
63+
ComplexFillValue = tuple[FloatFillValue, FloatFillValue]
64+
RawFillValue = tuple[int, ...]
65+
66+
FillValue = (
67+
BoolFillValue | IntFillValue | FloatFillValue | ComplexFillValue | RawFillValue | str | None
68+
)
5369

5470
DimensionSeparator = Literal[".", "/"]
5571
MemoryOrder = Literal["C", "F"]
@@ -155,8 +171,8 @@ class ArraySpec(NodeSpec, Generic[TAttr]):
155171
attributes: TAttr = cast(TAttr, {})
156172
shape: tuple[int, ...]
157173
chunks: tuple[int, ...]
158-
dtype: DtypeStr
159-
fill_value: int | float | None = 0
174+
dtype: DtypeStr | list[tuple[Any, ...]]
175+
fill_value: FillValue = 0
160176
order: MemoryOrder = "C"
161177
filters: list[CodecDict] | None = None
162178
dimension_separator: Annotated[
@@ -285,7 +301,7 @@ def from_array(
285301

286302
return cls(
287303
shape=shape_actual,
288-
dtype=stringify_dtype(dtype_actual),
304+
dtype=parse_dtype_v2(dtype_actual),
289305
chunks=chunks_actual,
290306
attributes=attributes_actual,
291307
fill_value=fill_value_actual,
@@ -322,40 +338,25 @@ def from_zarr(cls, array: zarr.Array) -> Self:
322338
msg = "Array is not a Zarr format 2 array"
323339
raise TypeError(msg)
324340

325-
if len(array.compressors):
326-
compressor = array.compressors[0]
327-
if TYPE_CHECKING:
328-
# TODO: overload array.compressors in zarr-python and remove this type check
329-
assert isinstance(compressor, Codec)
330-
compressor_dict = compressor.get_config()
341+
if Version(version("zarr")) < Version("3.1.0"):
342+
from zarr.core.buffer import default_buffer_prototype
343+
344+
stored_meta = array.metadata.to_buffer_dict(prototype=default_buffer_prototype())
345+
meta_json = json.loads(stored_meta[".zarray"].to_bytes()) | {
346+
"attributes": array.attrs.asdict()
347+
}
331348
else:
332-
compressor_dict = None
349+
meta_json = array.metadata.to_dict()
333350

334-
return cls(
335-
shape=array.shape,
336-
chunks=array.chunks,
337-
dtype=str(array.dtype),
338-
# explicitly cast to numpy type and back to python
339-
# so that int 0 isn't serialized as 0.0
340-
fill_value=(
341-
array.dtype.type(array.fill_value).tolist()
342-
if array.fill_value is not None
343-
else array.fill_value
344-
),
345-
order=array.order,
346-
filters=array.filters,
347-
dimension_separator=array.metadata.dimension_separator,
348-
compressor=compressor_dict,
349-
attributes=array.attrs.asdict(),
350-
)
351+
return cls.model_validate(meta_json)
351352

352353
def to_zarr(
353354
self,
354355
store: Store,
355356
path: str,
356357
*,
357358
overwrite: bool = False,
358-
**kwargs: Any,
359+
config: ArrayConfigParams | None = None,
359360
) -> zarr.Array:
360361
"""
361362
Serialize an `ArraySpec` to a Zarr array at a specific path in a Zarr store. This operation
@@ -369,36 +370,32 @@ def to_zarr(
369370
The location of the array inside the store.
370371
overwrite : bool, default = False
371372
Whether to overwrite existing objects in storage to create the Zarr array.
372-
**kwargs : Any
373-
Additional keyword arguments are passed to `zarr.create`.
373+
config : ArrayConfigParams | None, default = None
374+
An instance of `ArrayConfigParams` that defines the runtime configuration for the array.
374375
375376
Returns
376377
-------
377378
zarr.Array
378379
A Zarr array that is structurally identical to `self`.
379380
"""
380-
spec_dict = self.model_dump()
381-
attrs = spec_dict.pop("attributes")
382-
if self.compressor is not None:
383-
spec_dict["compressor"] = numcodecs.get_codec(spec_dict["compressor"])
384-
if self.filters is not None:
385-
spec_dict["filters"] = [numcodecs.get_codec(f) for f in spec_dict["filters"]]
386-
if contains_array(store, path):
387-
extant_array = zarr.open_array(store, path=path, zarr_format=2)
388-
389-
if not self.like(extant_array):
390-
if not overwrite:
391-
raise ContainsArrayError(store, path)
381+
store_path = sync(make_store_path(store, path=path))
382+
383+
extant_node = maybe_node(store, path, zarr_format=2)
384+
if isinstance(extant_node, zarr.Array):
385+
if not self.like(extant_node) and not overwrite:
386+
raise ContainsArrayError(store, path)
392387
else:
388+
# If there's an existing array that is identical to the model, and overwrite is False,
389+
# we can just return that existing array.
393390
if not overwrite:
394-
# extant_array is read-only, so we make a new array handle that
395-
# takes **kwargs
396-
return zarr.open_array(
397-
store=extant_array.store, path=extant_array.path, zarr_format=2, **kwargs
398-
)
399-
result = zarr.create(store=store, path=path, overwrite=overwrite, **spec_dict, **kwargs)
400-
result.attrs.put(attrs)
401-
return result
391+
return extant_node
392+
if isinstance(extant_node, zarr.Group) and not overwrite:
393+
raise ContainsGroupError(store, path)
394+
395+
meta: ArrayV2Metadata = ArrayV2Metadata.from_dict(self.model_dump())
396+
async_array = AsyncArray(metadata=meta, store_path=store_path, config=config)
397+
sync(async_array._save_metadata(meta))
398+
return Array(_async_array=async_array)
402399

403400
def like(
404401
self,
@@ -568,28 +565,34 @@ def to_zarr(
568565
"""
569566
spec_dict = self.model_dump(exclude={"members": True})
570567
attrs = spec_dict.pop("attributes")
571-
if contains_group(store, path):
572-
extant_group = zarr.group(store, path=path, zarr_format=2)
573-
if not self.like(extant_group):
568+
extant_node = maybe_node(store, path, zarr_format=2)
569+
if isinstance(extant_node, zarr.Group):
570+
if not self.like(extant_node):
574571
if not overwrite:
572+
"""
575573
msg = (
576574
f"A group already exists at path {path}. "
577575
"That group is structurally dissimilar to the group you are trying to store."
578576
"Call to_zarr with overwrite=True to overwrite that group."
579577
)
580-
raise ContainsGroupError(msg)
578+
"""
579+
# Zarr's contains group error uses questionable design and doesn't take a message
580+
raise ContainsGroupError(store, path)
581581
else:
582582
if not overwrite:
583583
# if the extant group is structurally identical to self, and overwrite is false,
584584
# then just return the extant group
585-
return extant_group
585+
return extant_node
586586

587-
elif contains_array(store, path) and not overwrite:
587+
elif isinstance(extant_node, zarr.Array) and not overwrite:
588+
"""
588589
msg = (
589590
f"An array already exists at path {path}. "
590591
"Call to_zarr with overwrite=True to overwrite the array."
591592
)
592-
raise ContainsArrayError(msg)
593+
"""
594+
# Zarr's contains array error uses questionable design and doesn't take a message
595+
raise ContainsArrayError(store, path)
593596
else:
594597
zarr.create_group(store=store, overwrite=overwrite, path=path, zarr_format=2)
595598

0 commit comments

Comments
 (0)