diff --git a/src/zarr/core/buffer/cpu.py b/src/zarr/core/buffer/cpu.py index 9da0059d0b..d1f3148da6 100644 --- a/src/zarr/core/buffer/cpu.py +++ b/src/zarr/core/buffer/cpu.py @@ -1,5 +1,6 @@ from __future__ import annotations +import numbers from typing import ( TYPE_CHECKING, Any, @@ -155,7 +156,10 @@ def create( fill_value: Any | None = None, ) -> Self: # np.zeros is much faster than np.full, and therefore using it when possible is better. - if fill_value is None or (isinstance(fill_value, int) and fill_value == 0): + # See https://numpy.org/doc/stable/reference/generated/numpy.isscalar.html#numpy-isscalar + # notes for why we use `numbers.Number`. + # Tehcnically `numbers.Number` need not support __eq__ hence the `ignore`. + if fill_value is None or (isinstance(fill_value, numbers.Number) and fill_value == 0): # type: ignore[comparison-overlap] return cls(np.zeros(shape=tuple(shape), dtype=dtype, order=order)) else: return cls(np.full(shape=tuple(shape), fill_value=fill_value, dtype=dtype, order=order)) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index b50e5abb67..d3fc8413b2 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Literal +from unittest import mock import numpy as np import pytest @@ -243,3 +244,29 @@ def test_empty( assert result.flags.c_contiguous # type: ignore[attr-defined] else: assert result.flags.f_contiguous # type: ignore[attr-defined] + + +@pytest.mark.parametrize("dtype", [np.int8, np.uint16, np.float32, int, float]) +@pytest.mark.parametrize("fill_value", [None, 0, 1]) +def test_no_full_with_zeros( + dtype: type[np.number[np.typing.NBitBase] | float], + fill_value: None | float, +) -> None: + """Ensure that fill value of 0 (or None with a numeric dtype) does not trigger np.full, and instead triggers np.zeros""" + # full never called with fill 0 + if fill_value == 0: + with mock.patch("numpy.full", side_effect=RuntimeError): + cpu.buffer_prototype.nd_buffer.create( + shape=(10,), dtype=dtype, fill_value=dtype(fill_value) + ) + # full or zeros called appropriately based on fill value + with mock.patch( + "numpy.zeros" if fill_value == 0 or fill_value is None else "numpy.full", + side_effect=RuntimeError("called"), + ): + with pytest.raises(RuntimeError, match=r"called"): + cpu.buffer_prototype.nd_buffer.create( + shape=(10,), + dtype=dtype, + fill_value=dtype(fill_value) if fill_value is not None else fill_value, + )