66 changes: 66 additions & 0 deletions bench/write_shard.py
@@ -0,0 +1,66 @@
import itertools
import os.path
import shutil
import sys
import tempfile
import timeit

import line_profiler
import numpy as np

import zarr
import zarr.codecs
import zarr.codecs.sharding

if __name__ == "__main__":
sys.path.insert(0, "..")

# setup
with tempfile.TemporaryDirectory() as path:

ndim = 3
opt = {
'shape': [1024]*ndim,
'chunks': [128]*ndim,
'shards': [512]*ndim,
'dtype': np.float64,
}

store = zarr.storage.LocalStore(path)
z = zarr.create_array(store, **opt)
print(z)

def cleanup() -> None:
for elem in os.listdir(path):
elem = os.path.join(path, elem)
if not elem.endswith(".json"):
if os.path.isdir(elem):
shutil.rmtree(elem)
else:
os.remove(elem)

def write() -> None:
wchunk = [512]*ndim
nwchunks = [n//s for n, s in zip(opt['shape'], wchunk, strict=True)]
for shard in itertools.product(*(range(n) for n in nwchunks)):
slicer = tuple(
slice(i*n, (i+1)*n)
for i, n in zip(shard, wchunk, strict=True)
)
d = np.random.rand(*wchunk).astype(opt['dtype'])
z[slicer] = d

print("*" * 79)

# time
vars = {"write": write, "cleanup": cleanup, "z": z, "opt": opt}
t = timeit.repeat("write()", "cleanup()", repeat=2, number=1, globals=vars)
print(t)
print(min(t))
print(z)

# profile
# f = zarr.codecs.sharding.ShardingCodec._encode_partial_single
# profile = line_profiler.LineProfiler(f)
# profile.run("write()")
# profile.print_stats()
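Note on the benchmark: each iteration of `write()` assigns one shard-aligned block, so every assignment hands the sharding codec a full shard's worth of chunks. A minimal sketch of the index arithmetic for one hypothetical shard coordinate (values taken from the `opt` dict above):

```python
# Sketch only: reproduces the slicer construction from write() for one coordinate.
wchunk = [512, 512, 512]                 # write block = one shard
shard = (1, 0, 1)                        # hypothetical shard coordinate
slicer = tuple(slice(i * n, (i + 1) * n) for i, n in zip(shard, wchunk))
# -> (slice(512, 1024), slice(0, 512), slice(512, 1024))
# z[slicer] therefore covers exactly one 512**3 shard, i.e. (512 // 128)**3 = 64 chunks.
```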
4 changes: 2 additions & 2 deletions src/zarr/codecs/sharding.py
@@ -203,7 +203,7 @@ def create_empty(
buffer_prototype = default_buffer_prototype()
index = _ShardIndex.create_empty(chunks_per_shard)
obj = cls()
obj.buf = buffer_prototype.buffer.create_zero_length()
obj.buf = buffer_prototype.buffer.Delayed.create_zero_length()
obj.index = index
return obj

@@ -252,7 +252,7 @@ def create_empty(
if buffer_prototype is None:
buffer_prototype = default_buffer_prototype()
obj = cls()
obj.buf = buffer_prototype.buffer.create_zero_length()
obj.buf = buffer_prototype.buffer.Delayed.create_zero_length()
obj.index = _ShardIndex.create_empty(chunks_per_shard)
return obj

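Both `create_empty` call sites now seed the shard with the `Delayed` buffer class hung off the prototype's buffer, so appending encoded chunks accumulates fragments instead of re-copying the growing payload. A minimal sketch of the intended behaviour, assuming the CPU `Buffer.Delayed` added in this PR and the existing `cpu.Buffer` constructor:

```python
import numpy as np

from zarr.core.buffer import cpu

# Appending to a plain Buffer with `+` copies the accumulated payload every time;
# the Delayed variant only records the new fragment in a list.
shard_buf = cpu.Buffer.Delayed.create_zero_length()
for _ in range(3):
    chunk_bytes = cpu.Buffer(np.zeros(16, dtype="b"))  # stand-in for one encoded chunk
    shard_buf = shard_buf + chunk_bytes                # no concatenation yet

payload = shard_buf.as_numpy_array()                   # fragments concatenated here
assert len(payload) == 48
```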
147 changes: 147 additions & 0 deletions src/zarr/core/buffer/core.py
@@ -571,6 +571,153 @@ class BufferPrototype(NamedTuple):
nd_buffer: type[NDBuffer]


class DelayedBuffer(Buffer):
"""
A Buffer that is the virtual concatenation of other buffers.
"""
_BufferImpl: type
_concatenate: callable

def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
if array is None:
self._data_list = []
elif isinstance(array, list):
self._data_list = list(array)
else:
self._data_list = [array]
for array in self._data_list:
if array.ndim != 1:
raise ValueError("array: only 1-dim allowed")
if array.dtype != np.dtype("b"):
raise ValueError("array: only byte dtype allowed")

@property
def _data(self) -> npt.NDArray[Any]:
return type(self)._concatenate(self._data_list)

@classmethod
def from_buffer(cls, buffer: Buffer) -> Self:
if isinstance(buffer, cls):
return cls(buffer._data_list)
else:
return cls(buffer._data)

def __add__(self, other: Buffer) -> Self:
if isinstance(other, self.__class__):
return self.__class__(self._data_list + other._data_list)
else:
return self.__class__(self._data_list + [other._data])

def __radd__(self, other: Buffer) -> Self:
if isinstance(other, self.__class__):
return self.__class__(other._data_list + self._data_list)
else:
return self.__class__([other._data] + self._data_list)

def __len__(self) -> int:
return sum(map(len, self._data_list))

def __getitem__(self, key: slice) -> Self:
check_item_key_is_1d_contiguous(key)
start, stop = key.start, key.stop
this_len = len(self)
if start is None:
start = 0
if start < 0:
start = this_len + start
if stop is None:
stop = this_len
if stop < 0:
stop = this_len + stop
if stop > this_len:
stop = this_len
if stop <= start:
            return self.__class__(np.zeros(0, dtype="b"))  # empty selection

new_list = []
offset = 0
found_last = False
for chunk in self._data_list:
chunk_size = len(chunk)
skip = False
if 0 <= start - offset < chunk_size:
# first chunk
if stop - offset <= chunk_size:
# also last chunk
chunk = chunk[start-offset:stop-offset]
found_last = True
else:
chunk = chunk[start-offset:]
elif 0 <= stop - offset <= chunk_size:
# last chunk
chunk = chunk[:stop-offset]
found_last = True
elif chunk_size <= start - offset:
# before first chunk
skip = True
else:
# middle chunk
pass

if not skip:
new_list.append(chunk)
if found_last:
break
offset += chunk_size
assert sum(map(len, new_list)) == stop - start
return self.__class__(new_list)

def __setitem__(self, key: slice, value: Any) -> None:
# This assumes that `value` is a broadcasted array
check_item_key_is_1d_contiguous(key)
start, stop = key.start, key.stop
if start is None:
start = 0
if start < 0:
start = len(self) + start
if stop is None:
stop = len(self)
if stop < 0:
stop = len(self) + stop
if stop <= start:
return

offset = 0
found_last = False
value = memoryview(np.asanyarray(value))
for chunk in self._data_list:
chunk_size = len(chunk)
skip = False
if 0 <= start - offset < chunk_size:
# first chunk
if stop - offset <= chunk_size:
# also last chunk
chunk = chunk[start-offset:stop-offset]
found_last = True
else:
chunk = chunk[start-offset:]
elif 0 <= stop - offset <= chunk_size:
# last chunk
chunk = chunk[:stop-offset]
found_last = True
elif chunk_size <= start - offset:
# before first chunk
skip = True
else:
# middle chunk
pass

if not skip:
chunk[:] = value[:len(chunk)]
value = value[len(chunk):]
if len(value) == 0:
# nothing left to write
break
if found_last:
break
offset += chunk_size


# The default buffer prototype used throughout the Zarr codebase.
def default_buffer_prototype() -> BufferPrototype:
from zarr.registry import (
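The `__getitem__` above walks the fragment list once: it trims the first and last fragments that overlap the requested range, skips fragments that end before `start`, and keeps whole fragments in between; `__setitem__` follows the same walk but writes through the fragment views. A minimal sketch of the expected slicing behaviour, using the concrete `cpu.DelayedBuffer` added below and two hypothetical 4-byte fragments:

```python
import numpy as np

from zarr.core.buffer import cpu

frag_a = np.frombuffer(b"abcd", dtype="b")
frag_b = np.frombuffer(b"efgh", dtype="b")
buf = cpu.DelayedBuffer([frag_a, frag_b])

# Slice across the fragment boundary: bytes 2..6 span the tail of frag_a
# and the head of frag_b, so the result keeps two (trimmed) fragments.
piece = buf[2:6]
assert piece.as_numpy_array().tobytes() == b"cdef"
assert len(piece) == 4
```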
33 changes: 33 additions & 0 deletions src/zarr/core/buffer/cpu.py
@@ -192,6 +192,39 @@ def __setitem__(self, key: Any, value: Any) -> None:
self._data.__setitem__(key, value)


class DelayedBuffer(core.DelayedBuffer, Buffer):
"""
A Buffer that is the virtual concatenation of other buffers.
"""
_BufferImpl = Buffer
_concatenate = np.concatenate

def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
core.DelayedBuffer.__init__(self, array)
self._data_list = list(map(np.asanyarray, self._data_list))

@classmethod
def create_zero_length(cls) -> Self:
return cls(np.array([], dtype="b"))

@classmethod
def from_buffer(cls, buffer: core.Buffer) -> Self:
if isinstance(buffer, cls):
return cls(buffer._data_list)
else:
return cls(buffer._data)

@classmethod
def from_bytes(cls, bytes_like: BytesLike) -> Self:
return cls(np.asarray(bytes_like, dtype="b"))

def as_numpy_array(self) -> npt.NDArray[Any]:
return np.asanyarray(self._data)


Buffer.Delayed = DelayedBuffer


def as_numpy_array_wrapper(
func: Callable[[npt.NDArray[Any]], bytes], buf: core.Buffer, prototype: core.BufferPrototype
) -> core.Buffer:
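The `Buffer.Delayed = DelayedBuffer` assignment is what lets `sharding.py` write `buffer_prototype.buffer.Delayed.create_zero_length()` without knowing which concrete buffer backend is active. A minimal sketch of that lookup, assuming the default prototype resolves to the CPU buffer:

```python
from zarr.core.buffer import default_buffer_prototype

proto = default_buffer_prototype()

# The sharding codec only sees the prototype; `.Delayed` dispatches to the
# DelayedBuffer class registered on that prototype's buffer class.
shard_builder = proto.buffer.Delayed.create_zero_length()
assert len(shard_builder) == 0
assert isinstance(shard_builder, proto.buffer)  # DelayedBuffer subclasses Buffer
```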
33 changes: 33 additions & 0 deletions src/zarr/core/buffer/gpu.py
@@ -226,6 +226,39 @@ def __setitem__(self, key: Any, value: Any) -> None:
self._data.__setitem__(key, value)


class DelayedBuffer(core.DelayedBuffer, Buffer):
"""
A Buffer that is the virtual concatenation of other buffers.
"""
_BufferImpl = Buffer
_concatenate = getattr(cp, 'concatenate', None)

def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
core.DelayedBuffer.__init__(self, array)
self._data_list = list(map(cp.asarray, self._data_list))

@classmethod
def create_zero_length(cls) -> Self:
return cls(np.array([], dtype="b"))

@classmethod
def from_buffer(cls, buffer: core.Buffer) -> Self:
if isinstance(buffer, cls):
return cls(buffer._data_list)
else:
return cls(buffer._data)

@classmethod
def from_bytes(cls, bytes_like: BytesLike) -> Self:
return cls(np.asarray(bytes_like, dtype="b"))

def as_numpy_array(self) -> npt.NDArray[Any]:
return np.asanyarray(self._data)


Buffer.Delayed = DelayedBuffer


buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)

register_buffer(Buffer, qualname="zarr.buffer.gpu.Buffer")
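The GPU variant mirrors the CPU one, except that fragments are promoted to device arrays via `cp.asarray` and materialised with `cp.concatenate` (looked up with `getattr` so the module still imports when CuPy is absent). A minimal sketch of the same append-then-materialise pattern, assuming CuPy is installed:

```python
import cupy as cp
import numpy as np

from zarr.core.buffer import gpu

buf = gpu.Buffer.Delayed.create_zero_length()
# DelayedBuffer.__init__ moves each fragment to the device with cp.asarray.
buf = buf + gpu.Buffer.Delayed(np.frombuffer(b"abcd", dtype="b"))
buf = buf + gpu.Buffer.Delayed(np.frombuffer(b"efgh", dtype="b"))

device_payload = buf._data  # cp.concatenate over the fragment list
assert isinstance(device_payload, cp.ndarray)
assert len(device_payload) == 8
```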