|
5 | 5 | from typing import TYPE_CHECKING, Any, TypeVar |
6 | 6 | from warnings import warn |
7 | 7 |
|
| 8 | +import numpy as np |
| 9 | + |
8 | 10 | from zarr.abc.codec import ( |
9 | 11 | ArrayArrayCodec, |
10 | 12 | ArrayBytesCodec, |
|
19 | 21 | from zarr.core.indexing import SelectorTuple, is_scalar |
20 | 22 | from zarr.errors import ZarrUserWarning |
21 | 23 | from zarr.registry import register_pipeline |
| 24 | +from zarr.core.buffer import NDBuffer |
22 | 25 |
|
23 | 26 | if TYPE_CHECKING: |
24 | 27 | from collections.abc import Iterable, Iterator |
25 | 28 | from typing import Self |
26 | 29 |
|
27 | 30 | from zarr.abc.store import ByteGetter, ByteSetter |
28 | 31 | from zarr.core.array_spec import ArraySpec |
29 | | - from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer |
| 32 | + from zarr.core.buffer import Buffer, BufferPrototype |
30 | 33 | from zarr.core.chunk_grids import ChunkGrid |
31 | 34 | from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType |
32 | 35 |
|
@@ -413,18 +416,21 @@ async def _read_key( |
413 | 416 | if chunk_array is None: |
414 | 417 | chunk_array_batch.append(None) # type: ignore[unreachable] |
415 | 418 | else: |
416 | | - # The operation array_equal operation below effectively will force the array |
417 | | - # into memory. |
418 | | - # if the result is useful, we want to avoid reading it twice |
419 | | - # from a potentially lazy operation. So we cache it here. |
420 | | - # If the result is not useful, we leave it for the garbage collector. |
421 | | - chunk_array._data = chunk_array.as_numpy_array() |
422 | | - if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal( |
423 | | - fill_value_or_default(chunk_spec) |
424 | | - ): |
425 | | - chunk_array_batch.append(None) |
426 | | - else: |
427 | | - chunk_array_batch.append(chunk_array) |
| 419 | + if not chunk_spec.config.write_empty_chunks: |
|  | 420 | +                # The array_equal operation below will effectively force the array |
|  | 421 | +                # into memory. |
|  | 422 | +                # If the result is useful, we want to avoid reading it twice |
|  | 423 | +                # from a potentially lazy operation, so we cache it here. |
|  | 424 | +                # If the result is not useful, we leave it for the garbage collector. |
|  | 425 | +                # We skip the host copy when the chunk is already on the GPU. |
| 426 | + if not hasattr(chunk_array._data, '__cuda_array_interface__'): |
| 427 | + chunk_array = NDBuffer(np.asarray(chunk_array._data)) |
| 428 | + |
| 429 | + if chunk_array.all_equal( |
| 430 | + fill_value_or_default(chunk_spec) |
| 431 | + ): |
| 432 | + chunk_array = None |
| 433 | + chunk_array_batch.append(chunk_array) |
428 | 434 |
|
429 | 435 | chunk_bytes_batch = await self.encode_batch( |
430 | 436 | [ |
|