Skip to content

Commit 4082888

Browse files
committed
Sharding: do not merge new/old if old is empty + avoid sequential buffer concatenation
1 parent 50abf3d commit 4082888

File tree

3 files changed

+95
-9
lines changed

3 files changed

+95
-9
lines changed

src/zarr/codecs/sharding.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def create_empty(
251251
if buffer_prototype is None:
252252
buffer_prototype = default_buffer_prototype()
253253
obj = cls()
254-
obj.buf = buffer_prototype.buffer.create_zero_length()
254+
obj.buf = buffer_prototype.buffer.Delayed.create_zero_length()
255255
obj.index = _ShardIndex.create_empty(chunks_per_shard)
256256
return obj
257257

@@ -585,15 +585,16 @@ async def _encode_partial_single(
585585
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
586586
chunk_spec = self._get_chunk_spec(shard_spec)
587587

588-
shard_dict = _MergingShardBuilder(
589-
await self._load_full_shard_maybe(
590-
byte_getter=byte_setter,
591-
prototype=chunk_spec.prototype,
592-
chunks_per_shard=chunks_per_shard,
593-
)
594-
or _ShardReader.create_empty(chunks_per_shard),
595-
_ShardBuilder.create_empty(chunks_per_shard),
588+
shard_read = await self._load_full_shard_maybe(
589+
byte_getter=byte_setter,
590+
prototype=chunk_spec.prototype,
591+
chunks_per_shard=chunks_per_shard,
596592
)
593+
shard_build = _ShardBuilder.create_empty(chunks_per_shard)
594+
if shard_read:
595+
shard_dict = _MergingShardBuilder(shard_read, shard_build)
596+
else:
597+
shard_dict = shard_build
597598

598599
indexer = list(
599600
get_indexer(

src/zarr/core/buffer/cpu.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,48 @@ def __setitem__(self, key: Any, value: Any) -> None:
185185
self._data.__setitem__(key, value)
186186

187187

188+
class DelayedBuffer(Buffer):
    """
    A Buffer that is the virtual concatenation of other buffers.

    The constituent arrays are stored in ``_data_list`` and are only joined
    into a single array when ``_data`` is read, so ``+`` merely extends the
    list instead of copying the bytes accumulated so far.

    Parameters
    ----------
    array
        ``None`` for an empty buffer, a single array, or a list of arrays.
        Every array must be 1-dimensional and of byte dtype.

    Raises
    ------
    ValueError
        If any array is not 1-dimensional or does not have byte dtype.
    """

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        if array is None:
            self._data_list = []
        elif isinstance(array, list):
            # Copy so that later changes to the caller's list do not leak in.
            self._data_list = list(array)
        else:
            self._data_list = [array]
        # Use a dedicated loop name; do not rebind the ``array`` parameter.
        for part in self._data_list:
            if part.ndim != 1:
                raise ValueError("array: only 1-dim allowed")
            if part.dtype != np.dtype("b"):
                raise ValueError("array: only byte dtype allowed")

    @property
    def _data(self) -> npt.NDArray[Any]:
        """Return all parts joined into one array (a new array on each access)."""
        if not self._data_list:
            # np.concatenate raises ValueError on an empty sequence, so an
            # empty buffer (built from ``None`` or ``[]``) is materialized
            # as a zero-length byte array instead.
            return np.empty(0, dtype="b")
        return np.concatenate(self._data_list)

    @classmethod
    def from_buffer(cls, buffer: core.Buffer) -> Self:
        """Build a delayed buffer from ``buffer``, reusing its parts if it is already delayed."""
        if isinstance(buffer, cls):
            return cls(buffer._data_list)
        else:
            return cls(buffer._data)

    def __add__(self, other: core.Buffer) -> Self:
        """Return a new delayed buffer holding this buffer's parts followed by ``other``."""
        if isinstance(other, self.__class__):
            return self.__class__(self._data_list + other._data_list)
        else:
            return self.__class__(self._data_list + [other._data])

    def __len__(self) -> int:
        """Return the total number of elements across all parts."""
        return sum(map(len, self._data_list))
225+
226+
227+
# Expose the delayed variant as the ``Delayed`` attribute of ``Buffer`` so it
# can be reached from the buffer class itself (``<buffer class>.Delayed``).
Buffer.Delayed = DelayedBuffer
228+
229+
188230
def as_numpy_array_wrapper(
189231
func: Callable[[npt.NDArray[Any]], bytes], buf: core.Buffer, prototype: core.BufferPrototype
190232
) -> core.Buffer:

src/zarr/core/buffer/gpu.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,49 @@ def __setitem__(self, key: Any, value: Any) -> None:
218218
self._data.__setitem__(key, value)
219219

220220

221+
class DelayedBuffer(Buffer):
    """
    A Buffer that is the virtual concatenation of other buffers.

    The constituent arrays are stored in ``_data_list`` (each converted with
    ``cp.asarray``) and are only joined into a single array when ``_data`` is
    read, so ``+`` merely extends the list instead of copying the bytes
    accumulated so far.

    Parameters
    ----------
    array
        ``None`` for an empty buffer, a single array, or a list of arrays.
        Every array must be 1-dimensional and of byte dtype.

    Raises
    ------
    ValueError
        If any array is not 1-dimensional or does not have byte dtype.
    """

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        if array is None:
            self._data_list = []
        elif isinstance(array, list):
            # Copy so that later changes to the caller's list do not leak in.
            self._data_list = list(array)
        else:
            self._data_list = [array]
        # Use a dedicated loop name; do not rebind the ``array`` parameter.
        # Validation runs on the inputs as given, before conversion.
        for part in self._data_list:
            if part.ndim != 1:
                raise ValueError("array: only 1-dim allowed")
            if part.dtype != np.dtype("b"):
                raise ValueError("array: only byte dtype allowed")
        self._data_list = list(map(cp.asarray, self._data_list))

    @property
    def _data(self) -> npt.NDArray[Any]:
        """Return all parts joined into one array (a new array on each access)."""
        if not self._data_list:
            # cp.concatenate rejects an empty sequence, so an empty buffer
            # (built from ``None`` or ``[]``) is materialized as a
            # zero-length byte array instead.
            return cp.empty(0, dtype="b")
        return cp.concatenate(self._data_list)

    @classmethod
    def from_buffer(cls, buffer: core.Buffer) -> Self:
        """Build a delayed buffer from ``buffer``, reusing its parts if it is already delayed."""
        if isinstance(buffer, cls):
            return cls(buffer._data_list)
        else:
            return cls(buffer._data)

    def __add__(self, other: core.Buffer) -> Self:
        """Return a new delayed buffer holding this buffer's parts followed by ``other``."""
        if isinstance(other, self.__class__):
            return self.__class__(self._data_list + other._data_list)
        else:
            return self.__class__(self._data_list + [other._data])

    def __len__(self) -> int:
        """Return the total number of elements across all parts."""
        return sum(map(len, self._data_list))
259+
260+
261+
# Expose the delayed variant as the ``Delayed`` attribute of ``Buffer`` so it
# can be reached from the buffer class itself (``<buffer class>.Delayed``).
Buffer.Delayed = DelayedBuffer
262+
263+
221264
buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)
222265

223266
register_buffer(Buffer)

0 commit comments

Comments
 (0)