
Commit afb8fff (parent 90ab6e1)

Update for improved GPU support

File tree: 1 file changed (+19, -13 lines)


src/zarr/core/codec_pipeline.py

Lines changed: 19 additions & 13 deletions
@@ -5,6 +5,8 @@
 from typing import TYPE_CHECKING, Any, TypeVar
 from warnings import warn
 
+import numpy as np
+
 from zarr.abc.codec import (
     ArrayArrayCodec,
     ArrayBytesCodec,
@@ -19,14 +21,15 @@
 from zarr.core.indexing import SelectorTuple, is_scalar
 from zarr.errors import ZarrUserWarning
 from zarr.registry import register_pipeline
+from zarr.core.buffer import NDBuffer
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator
     from typing import Self
 
     from zarr.abc.store import ByteGetter, ByteSetter
     from zarr.core.array_spec import ArraySpec
-    from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer
+    from zarr.core.buffer import Buffer, BufferPrototype
     from zarr.core.chunk_grids import ChunkGrid
     from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
 
@@ -413,18 +416,21 @@ async def _read_key(
             if chunk_array is None:
                 chunk_array_batch.append(None)  # type: ignore[unreachable]
             else:
-                # The operation array_equal operation below effectively will force the array
-                # into memory.
-                # if the result is useful, we want to avoid reading it twice
-                # from a potentially lazy operation. So we cache it here.
-                # If the result is not useful, we leave it for the garbage collector.
-                chunk_array._data = chunk_array.as_numpy_array()
-                if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
-                    fill_value_or_default(chunk_spec)
-                ):
-                    chunk_array_batch.append(None)
-                else:
-                    chunk_array_batch.append(chunk_array)
+                if not chunk_spec.config.write_empty_chunks:
+                    # The array_equal operation below will effectively force the array
+                    # into memory.
+                    # If the result is useful, we want to avoid reading it twice
+                    # from a potentially lazy operation, so we cache it here.
+                    # If the result is not useful, we leave it for the garbage collector.
+                    # We optimize this operation for the case that the array is already on the GPU.
+                    if not hasattr(chunk_array._data, '__cuda_array_interface__'):
+                        chunk_array = NDBuffer(np.asarray(chunk_array._data))
+
+                    if chunk_array.all_equal(
+                        fill_value_or_default(chunk_spec)
+                    ):
+                        chunk_array = None
+                chunk_array_batch.append(chunk_array)
 
             chunk_bytes_batch = await self.encode_batch(
                 [
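
For context, a minimal, self-contained sketch of the empty-chunk test that the updated branch performs. The helper name is_empty_chunk and the sample arrays are illustrative assumptions only; the real pipeline operates on zarr NDBuffer objects and their all_equal method rather than raw arrays.

import numpy as np

def is_empty_chunk(data, fill_value):
    # Mirrors the updated logic: objects exposing __cuda_array_interface__
    # (e.g. CuPy arrays) are compared as-is, so no device-to-host copy is
    # forced; anything else is materialized as a NumPy array first.
    if not hasattr(data, "__cuda_array_interface__"):
        data = np.asarray(data)
    return bool((data == fill_value).all())

chunk = np.zeros((4, 4), dtype="float32")
print(is_empty_chunk(chunk, 0.0))  # True: the chunk matches the fill value
chunk[0, 0] = 1.0
print(is_empty_chunk(chunk, 0.0))  # False: the chunk must be written

With write_empty_chunks disabled, a chunk that compares equal to the fill value is replaced by None and skipped by the encoder; after this change, GPU-resident chunks make that decision without being copied to the host.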
