Commit c75d27d

Restore optimizations for NDBuffer.all_equal
Zarr 3.x has some performance regressions for certain write workloads (writing large chunks with a floating point dtype). This change modifies the implementation of `NDBuffer.all_equal` to use the same logic as Zarr 2.x's `zarr.util.all_equals`, which contains a number of important optimizations. A few mechanical changes were made to accommodate the fact that the subroutine is now a method of `NDBuffer` rather than a free function.

This change is most impactful when writing large floating point chunks, as the implementation of

```python
np.all(np.isnan(self._data))
```

is significantly more efficient than calling

```python
_data, other = np.broadcast_arrays(self._data, np.nan)
np.array_equal(_data, other, equal_nan=True)
```

since `np.broadcast_arrays` potentially requires a large allocation (the size of `self._data`), and `np.array_equal` then needs to fetch double the number of cache lines.

On EC2 r7i.2xlarge:

```
In [20]: data = np.random.rand(512, 512, 8)

In [21]: %timeit np.all(np.isnan(data))
596 μs ± 179 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [22]: %%timeit
    ...: data_, other = np.broadcast_arrays(data, np.nan)
    ...: np.array_equal(data_, other, equal_nan=True)
    ...:
    ...:
2.66 ms ± 953 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

(Both numbers are faster on an M3 Max, but the relative slowdown is similar.)

With low-latency stores (e.g. local SSD), this results in double-digit percentage speed-ups for the workload referenced in the Zarr V3 blog post:

```
import numpy as np
import zarr

za = zarr.create_array(
    "/tmp/foo.zarr",
    shape=(512, 512, 512),
    chunks=(512, 512, 8),
    dtype=np.float64,
    overwrite=True,
)
arr = np.random.rand(512, 512, 512)
za[:] = arr
```

For higher-latency stores, the improvement is still significant (10%+) when chunks have high compression ratios (e.g. `np.ones`). For arrays larger than 1 GB, the improvement is even more pronounced.
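As a rough illustration of the comparison described in the message, the sketch below wraps both NaN checks in small helper functions and confirms that they agree on the result. The helper names `all_nan_fast` and `all_nan_broadcast` are hypothetical and are not part of Zarr; the sketch only checks that the two approaches return the same answer and makes no claim about timing.

```python
import numpy as np


def all_nan_fast(data: np.ndarray) -> bool:
    # The check used by the new implementation for a NaN fill value.
    return bool(np.all(np.isnan(data)))


def all_nan_broadcast(data: np.ndarray) -> bool:
    # The check used by the previous implementation: broadcast the scalar
    # against the chunk, then compare with equal_nan=True.
    data_, other = np.broadcast_arrays(data, np.nan)
    return bool(np.array_equal(data_, other, equal_nan=True))


# A chunk that is entirely NaN, and one with a single non-NaN element.
all_nan = np.full((4, 4), np.nan)
mixed = all_nan.copy()
mixed[0, 0] = 1.0

for chunk in (all_nan, mixed):
    # Both approaches should agree on every input.
    assert all_nan_fast(chunk) == all_nan_broadcast(chunk)
```

Timing the two helpers with `%timeit` on a chunk-sized array, as in the session above, is one way to reproduce the comparison on other hardware.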
1 parent 31d377b commit c75d27d

1 file changed: +25 -10 lines changed

src/zarr/core/buffer/core.py

Lines changed: 25 additions & 10 deletions
@@ -460,18 +460,33 @@ def __len__(self) -> int:
     def __repr__(self) -> str:
         return f"<NDBuffer shape={self.shape} dtype={self.dtype} {self._data!r}>"

-    def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
-        """Compare to `other` using np.array_equal."""
-        if other is None:
+    def all_equal(self, value: Any, equal_nan: bool = True) -> bool:
+        if value is None:
             # Handle None fill_value for Zarr V2
             return False
-        # use array_equal to obtain equal_nan=True functionality
-        # Since fill-value is a scalar, isn't there a faster path than allocating a new array for fill value
-        # every single time we have to write data?
-        _data, other = np.broadcast_arrays(self._data, other)
-        return np.array_equal(
-            self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "USTO" else False
-        )
+
+        if not value:
+            # If `value` is falsey, then just 1 truthy value in `array`
+            # is sufficient to return False. We assume here that np.any is
+            # optimized to return on the first truthy value in `array`.
+            try:
+                return not np.any(self._data)
+            except (TypeError, ValueError):  # pragma: no cover
+                pass
+
+        if np.issubdtype(self._data.dtype, np.object_):
+            # We have to flatten the result of np.equal to handle outputs like
+            # [np.array([True,True]), True, True]
+            return all(np.equal(value, self._data, dtype=self._data.dtype).flatten())
+        else:
+            # Numpy errors if you call np.isnan on custom dtypes, so ensure
+            # we are working with floats before calling isnan
+            if np.issubdtype(self._data.dtype, np.floating) and np.isnan(value):
+                return np.all(np.isnan(self._data))
+            else:
+                # Using == raises warnings from numpy deprecated pattern, but
+                # using np.equal() raises type errors for structured dtypes...
+                return np.all(value == self._data)

     def fill(self, value: Any) -> None:
         self._data.fill(value)
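To make the branch order in the new `all_equal` easier to follow, here is a minimal standalone sketch that applies the same checks to a plain NumPy array. The function name `all_equal_sketch` is hypothetical, and the sketch deliberately omits the object-dtype branch and the `try`/`except` fallback, so it is an approximation of the numeric paths rather than the method from the diff.

```python
from typing import Any

import numpy as np


def all_equal_sketch(data: np.ndarray, value: Any) -> bool:
    """Return True if every element of `data` equals the scalar `value`.

    Simplified sketch of the numeric paths only; object and structured
    dtypes are not handled here.
    """
    if value is None:
        # A None fill value never matches.
        return False

    if not value:
        # Falsey value (0, 0.0, False): any truthy element means "not equal".
        return not bool(np.any(data))

    if np.issubdtype(data.dtype, np.floating) and np.isnan(value):
        # NaN fill value: NaN != NaN, so compare via isnan instead of ==.
        return bool(np.all(np.isnan(data)))

    # General case: plain elementwise comparison against the scalar.
    return bool(np.all(value == data))


zeros = np.zeros((8, 8))
nans = np.full((8, 8), np.nan)
ones = np.ones((8, 8))

print(all_equal_sketch(zeros, 0.0))
print(all_equal_sketch(nans, np.nan))
print(all_equal_sketch(ones, 0.0))
```

In this sketch, NaN is truthy in Python, so a NaN fill value skips the falsey branch and reaches the `np.isnan` path. Note also that, in the lines shown in the diff, the `equal_nan` parameter remains in the signature but is not consulted by the new body.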
