Skip to content

Commit 1a9158a

Browse files
committed
Refactor DelayedBuffer + implement efficient __getitem__
1 parent 09bc21d commit 1a9158a

File tree

3 files changed

+92
-64
lines changed

3 files changed

+92
-64
lines changed

src/zarr/core/buffer/core.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,89 @@ class BufferPrototype(NamedTuple):
502502
nd_buffer: type[NDBuffer]
503503

504504

505+
class DelayedBuffer(Buffer):
    """
    A Buffer that is the virtual concatenation of other buffers.

    The pieces are kept in ``_data_list`` and are only joined into a single
    array when ``_data`` is read, so ``+`` and slicing operate on the list of
    pieces instead of on a materialized copy.

    Subclasses are expected to provide ``_BufferImpl`` and ``_concatenate``;
    this class only declares them.
    """

    # NOTE(review): not referenced inside this class; presumably the concrete
    # Buffer class the subclass pairs with -- confirm against the subclasses.
    _BufferImpl: type
    # Invoked as ``_concatenate(list_of_pieces)`` by the ``_data`` property.
    # NOTE(review): ``callable`` is the builtin predicate, not a type; a
    # ``Callable[..., Any]`` annotation would be the precise spelling.
    _concatenate: callable

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        """
        Store the piece(s) making up this buffer.

        Parameters
        ----------
        array
            ``None`` for an empty buffer, a single array-like, or a list of
            array-likes (the list is copied, the arrays are not).

        Raises
        ------
        ValueError
            If any piece is not 1-dimensional or not of byte dtype.
        """
        if array is None:
            self._data_list = []
        elif isinstance(array, list):
            self._data_list = list(array)
        else:
            self._data_list = [array]
        for piece in self._data_list:
            if piece.ndim != 1:
                raise ValueError("array: only 1-dim allowed")
            if piece.dtype != np.dtype("b"):
                raise ValueError("array: only byte dtype allowed")

    @property
    def _data(self) -> npt.NDArray[Any]:
        """Join all pieces into one array using the subclass's ``_concatenate``."""
        # Looked up on the type rather than the instance so the callable is
        # never treated as a bound method.
        return type(self)._concatenate(self._data_list)

    @classmethod
    def from_buffer(cls, buffer: Buffer) -> Self:
        """
        Build a delayed buffer from another buffer.

        An instance of ``cls`` contributes its list of pieces as-is; any other
        buffer contributes its ``_data`` as a single piece.
        """
        if isinstance(buffer, cls):
            return cls(buffer._data_list)
        return cls(buffer._data)

    def __add__(self, other: Buffer) -> Self:
        """Return a new buffer with ``other`` appended after this buffer's pieces."""
        if isinstance(other, self.__class__):
            return self.__class__(self._data_list + other._data_list)
        return self.__class__(self._data_list + [other._data])

    def __radd__(self, other: Buffer) -> Self:
        """Return a new buffer with ``other`` placed before this buffer's pieces."""
        if isinstance(other, self.__class__):
            return self.__class__(other._data_list + self._data_list)
        return self.__class__([other._data] + self._data_list)

    def __len__(self) -> int:
        """Total number of elements across all pieces."""
        return sum(map(len, self._data_list))

    def __getitem__(self, key: slice) -> Self:
        """
        Return the sub-range ``key`` without concatenating the pieces.

        Parameters
        ----------
        key
            A slice accepted by ``check_item_key_is_1d_contiguous``
            (presumably a contiguous, step-1 slice -- confirm against that
            helper). ``None``, negative and out-of-range bounds are
            normalized the way ``slice.indices`` does for a sequence of
            ``len(self)`` elements.

        Returns
        -------
        A new buffer of the same class whose pieces are the overlapping parts
        of this buffer's pieces. Pieces that are fully covered are reused
        unsliced.
        """
        check_item_key_is_1d_contiguous(key)
        # Only the bounds are used; the step was vetted by the check above.
        start, stop, _ = slice(key.start, key.stop).indices(len(self))
        stop = max(start, stop)  # an inverted range selects nothing

        new_list = []
        offset = 0  # absolute position of the current piece's first element
        for chunk in self._data_list:
            chunk_size = len(chunk)
            chunk_end = offset + chunk_size
            # Pieces ending at or before ``start`` contribute nothing; testing
            # this first is what keeps an empty slice on a piece boundary
            # (e.g. ``buf[4:4]`` over pieces of 4 + 4) from picking up data.
            if chunk_end > start:
                lo = max(start - offset, 0)
                hi = min(stop, chunk_end) - offset
                whole = lo == 0 and hi == chunk_size
                new_list.append(chunk if whole else chunk[lo:hi])
                if chunk_end >= stop:
                    break
            offset = chunk_end

        if not new_list and self._data_list:
            # Selection starts at the very end: keep one zero-length piece so
            # the result is not left with an empty list to concatenate.
            new_list.append(self._data_list[-1][:0])
        return self.__class__(new_list)
586+
587+
505588
# The default buffer prototype used throughout the Zarr codebase.
506589
def default_buffer_prototype() -> BufferPrototype:
507590
from zarr.registry import (

src/zarr/core/buffer/cpu.py

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -185,43 +185,16 @@ def __setitem__(self, key: Any, value: Any) -> None:
185185
self._data.__setitem__(key, value)
186186

187187

188-
class DelayedBuffer(Buffer):
188+
class DelayedBuffer(core.DelayedBuffer, Buffer):
189189
"""
190190
A Buffer that is the virtual concatenation of other buffers.
191191
"""
192+
_BufferImpl = Buffer
193+
_concatenate = np.concatenate
192194

193195
def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
194-
if array is None:
195-
self._data_list = []
196-
elif isinstance(array, list):
197-
self._data_list = list(array)
198-
else:
199-
self._data_list = [array]
200-
for array in self._data_list:
201-
if array.ndim != 1:
202-
raise ValueError("array: only 1-dim allowed")
203-
if array.dtype != np.dtype("b"):
204-
raise ValueError("array: only byte dtype allowed")
205-
206-
@property
207-
def _data(self) -> npt.NDArray[Any]:
208-
return np.concatenate(self._data_list)
209-
210-
@classmethod
211-
def from_buffer(cls, buffer: core.Buffer) -> Self:
212-
if isinstance(buffer, cls):
213-
return cls(buffer._data_list)
214-
else:
215-
return cls(buffer._data)
216-
217-
def __add__(self, other: core.Buffer) -> Self:
218-
if isinstance(other, self.__class__):
219-
return self.__class__(self._data_list + other._data_list)
220-
else:
221-
return self.__class__(self._data_list + [other._data])
222-
223-
def __len__(self) -> int:
224-
return sum(map(len, self._data_list))
196+
core.DelayedBuffer.__init__(self, array)
197+
self._data_list = list(map(np.asanyarray, self._data_list))
225198

226199

227200
Buffer.Delayed = DelayedBuffer

src/zarr/core/buffer/gpu.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -218,45 +218,17 @@ def __setitem__(self, key: Any, value: Any) -> None:
218218
self._data.__setitem__(key, value)
219219

220220

221-
class DelayedBuffer(Buffer):
221+
class DelayedBuffer(core.DelayedBuffer, Buffer):
222222
"""
223223
A Buffer that is the virtual concatenation of other buffers.
224224
"""
225+
_BufferImpl = Buffer
226+
_concatenate = getattr(cp, 'concatenate', None)
225227

226228
def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
227-
if array is None:
228-
self._data_list = []
229-
elif isinstance(array, list):
230-
self._data_list = list(array)
231-
else:
232-
self._data_list = [array]
233-
for array in self._data_list:
234-
if array.ndim != 1:
235-
raise ValueError("array: only 1-dim allowed")
236-
if array.dtype != np.dtype("b"):
237-
raise ValueError("array: only byte dtype allowed")
229+
core.DelayedBuffer.__init__(self, array)
238230
self._data_list = list(map(cp.asarray, self._data_list))
239231

240-
@property
241-
def _data(self) -> npt.NDArray[Any]:
242-
return cp.concatenate(self._data_list)
243-
244-
@classmethod
245-
def from_buffer(cls, buffer: core.Buffer) -> Self:
246-
if isinstance(buffer, cls):
247-
return cls(buffer._data_list)
248-
else:
249-
return cls(buffer._data)
250-
251-
def __add__(self, other: core.Buffer) -> Self:
252-
if isinstance(other, self.__class__):
253-
return self.__class__(self._data_list + other._data_list)
254-
else:
255-
return self.__class__(self._data_list + [other._data])
256-
257-
def __len__(self) -> int:
258-
return sum(map(len, self._data_list))
259-
260232

261233
Buffer.Delayed = DelayedBuffer
262234

0 commit comments

Comments
 (0)