Skip to content

Commit 7a59d84

Browse files
committed
Fix JSON encoding of complex fill values
We were not replacing NaNs and Infs with the string versions.
1 parent a9d6d74 commit 7a59d84

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/zarr/core/metadata/v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def default(self, o: object) -> Any:
149149
if isinstance(out, complex):
150150
# python complex types are not JSON serializable, so we use the
151151
# serialization defined in the zarr v3 spec
152-
return [out.real, out.imag]
152+
return _replace_special_floats([out.real, out.imag])
153153
elif np.isnan(out):
154154
return "NaN"
155155
elif np.isinf(out):

tests/test_array.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
import math
13
import pickle
24
from itertools import accumulate
35
from typing import Any, Literal
@@ -9,6 +11,7 @@
911
from zarr import Array, AsyncArray, Group
1012
from zarr.codecs import BytesCodec, VLenBytesCodec
1113
from zarr.core.array import chunks_initialized
14+
from zarr.core.buffer import default_buffer_prototype
1215
from zarr.core.buffer.cpu import NDBuffer
1316
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
1417
from zarr.core.group import AsyncGroup
@@ -436,3 +439,23 @@ def test_array_create_order(
436439
assert vals.flags.f_contiguous
437440
else:
438441
raise AssertionError
442+
443+
444+
@pytest.mark.parametrize(
445+
("fill_value", "expected"),
446+
[
447+
(np.nan * 1j, ["NaN", "NaN"]),
448+
(np.nan, ["NaN", 0.0]),
449+
(np.inf, ["Infinity", 0.0]),
450+
(np.inf * 1j, ["NaN", "Infinity"]),
451+
(-np.inf, ["-Infinity", 0.0]),
452+
(math.inf, ["Infinity", 0.0]),
453+
],
454+
)
455+
async def test_special_complex_fill_values_roundtrip(fill_value: Any, expected: list[Any]) -> None:
456+
store = MemoryStore({}, mode="w")
457+
Array.create(store=store, shape=(1,), dtype=np.complex64, fill_value=fill_value)
458+
content = await store.get("zarr.json", prototype=default_buffer_prototype())
459+
assert content is not None
460+
actual = json.loads(content.to_bytes())
461+
assert actual["fill_value"] == expected

0 commit comments

Comments
 (0)