diff --git a/sharding_test.py b/sharding_test.py
new file mode 100644
index 0000000..0c1b3c3
--- /dev/null
+++ b/sharding_test.py
@@ -0,0 +1,88 @@
+import json
+import os
+import shutil
+
+import zarrita
+
+
+shutil.rmtree("sharding_test.zr3", ignore_errors=True)
+h = zarrita.create_hierarchy("sharding_test.zr3")
+a = h.create_array(
+    path="testarray",
+    shape=(20, 3),
+    dtype="float64",
+    chunk_shape=(3, 2),
+    sharding={"chunks_per_shard": (2, 2)},
+)
+
+a[:10, :] = 42
+a[15, 1] = 389
+a[19, 2] = 1
+a[0, 1] = -4.2
+
+assert a.store._chunks_per_shard == (2, 2)
+assert a[15, 1] == 389
+assert a[19, 2] == 1
+assert a[0, 1] == -4.2
+assert a[0, 0] == 42
+
+array_json = a.store["meta/root/testarray.array.json"].decode()
+
+print(array_json)
+# {
+#     "shape": [
+#         20,
+#         3
+#     ],
+#     "data_type": "<f8",
+#     "chunk_grid": {
+#         "type": "regular",
+#         "chunk_shape": [
+#             3,
+#             2
+#         ],
+#         "separator": "/"
+#     },
+#     "chunk_memory_layout": "C",
+#     "fill_value": null,
+#     "extensions": [],
+#     "attributes": {},
+#     "sharding": {
+#         "chunks_per_shard": [
+#             2,
+#             2
+#         ],
+#         "format": "indexed"
+#     }
+# }
+
+print("ONDISK")
+for root, _, files in os.walk("sharding_test.zr3"):
+    if len(files) > 0:
+        print(" ", root.ljust(40), *sorted(files))
+print("UNDERLYING STORE", sorted(i.rsplit("c")[-1] for i in a.store._store if i.startswith("data")))
+print("STORE", sorted(i.rsplit("c")[-1] for i in a.store if i.startswith("data")))
+# ONDISK
+#   sharding_test.zr3                        zarr.json
+#   sharding_test.zr3/data/root/testarray/c0 0
+#   sharding_test.zr3/data/root/testarray/c1 0
+#   sharding_test.zr3/data/root/testarray/c2 0
+#   sharding_test.zr3/data/root/testarray/c3 0
+#   sharding_test.zr3/meta/root              testarray.array.json
+# UNDERLYING STORE ['0/0', '1/0', '2/0', '3/0']
+# STORE ['0/0', '0/1', '1/0', '1/1', '2/0', '2/1', '3/0', '3/1', '5/0', '6/1']
+
+index_bytes = a.store._store["data/root/testarray/c0/0"][-2*2*16:]
+print("INDEX 0.0", [int.from_bytes(index_bytes[i:i+8], byteorder="little") for i in range(0, len(index_bytes), 8)])
+# INDEX 0.0 [0, 48, 48, 48, 96, 48, 144, 48]
+
+
+a_reopened = zarrita.get_hierarchy("sharding_test.zr3").get_array("testarray")
+assert a_reopened.store._chunks_per_shard == (2, 2)
+assert a_reopened[15, 1] == 389
+assert a_reopened[19, 2] == 1
+assert a_reopened[0, 1] == -4.2
+assert a_reopened[0, 0] == 42
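
For reference, a chunk can also be decoded by hand from a shard file written by the test above. The following is a minimal standalone sketch, not part of the patch; it assumes the test has run and produced the uncompressed float64 array with chunk_shape (3, 2) and chunks_per_shard (2, 2) shown in the metadata, and the names CHUNKS_PER_SHARD and CHUNK_SHAPE are illustrative:

    import numpy as np

    CHUNKS_PER_SHARD = (2, 2)
    CHUNK_SHAPE = (3, 2)

    with open("sharding_test.zr3/data/root/testarray/c0/0", "rb") as f:
        shard = f.read()

    # The index sits at the end of the shard: one (offset, length) uint64 pair per chunk.
    num_chunks = CHUNKS_PER_SHARD[0] * CHUNKS_PER_SHARD[1]
    index = np.frombuffer(shard[-16 * num_chunks:], dtype="<u8").reshape(*CHUNKS_PER_SHARD, 2)

    # Chunk (0, 0) sits at local position (0, 0) in this shard; cut its bytes out and decode:
    offset, length = index[0, 0]
    chunk = np.frombuffer(shard[offset:offset + length], dtype="<f8").reshape(CHUNK_SHAPE)
    print(chunk)  # 42.0 everywhere except chunk[0, 1], which holds -4.2

This matches the "INDEX 0.0 [0, 48, 48, 48, 96, 48, 144, 48]" output of the test: four chunks of 48 bytes each, packed back to back.
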
diff --git a/zarrita.py b/zarrita.py
index a10214b..37ed72b 100644
--- a/zarrita.py
+++ b/zarrita.py
@@ -3,10 +3,11 @@
 import json
 import numbers
 import itertools
+import functools
 import math
 import re
 from collections.abc import Mapping, MutableMapping
-from typing import Iterator, Union, Optional, Tuple, Any, List, Dict, NamedTuple
+from typing import Iterator, Union, Optional, Tuple, Any, List, Dict, NamedTuple, Iterable, Type
 
 
 # third-party dependencies
@@ -170,6 +171,18 @@ def _check_compressor(compressor: Optional[Codec]) -> None:
     assert compressor is None or isinstance(compressor, Codec)
 
 
+def _check_sharding(sharding: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    if sharding is None:
+        return None
+    if "format" not in sharding:
+        sharding["format"] = "indexed"
+    assert sharding["format"] in SHARDED_STORES, (
+        f"Shard format {sharding['format']} is not supported, "
+        + f"use one of {list(SHARDED_STORES)}"
+    )
+    return sharding
+
+
 def _encode_codec_metadata(codec: Codec) -> Optional[Mapping]:
     if codec is None:
         return None
@@ -265,7 +278,8 @@ def create_array(self,
                      chunk_separator: str = "/",
                      compressor: Optional[Codec] = None,
                      fill_value: Any = None,
-                     attrs: Optional[Mapping] = None) -> Array:
+                     attrs: Optional[Mapping] = None,
+                     sharding: Optional[Dict[str, Any]] = None) -> Array:
 
         # sanity checks
         path = _check_path(path)
@@ -274,6 +288,7 @@ def create_array(self,
         chunk_shape = _check_chunk_shape(chunk_shape, shape)
         _check_compressor(compressor)
         attrs = _check_attrs(attrs)
+        sharding = _check_sharding(sharding)
 
         # encode data type
         if dtype == np.bool_:
@@ -297,6 +312,8 @@ def create_array(self,
         )
         if compressor is not None:
             meta["compressor"] = _encode_codec_metadata(compressor)
+        if sharding is not None:
+            meta["sharding"] = sharding
 
         # serialise and store metadata document
         meta_doc = _json_encode_object(meta)
@@ -307,7 +324,8 @@ def create_array(self,
         array = Array(store=self.store, path=path, owner=self,
                       shape=shape, dtype=dtype, chunk_shape=chunk_shape,
                       chunk_separator=chunk_separator, compressor=compressor,
-                      fill_value=fill_value, attrs=attrs)
+                      fill_value=fill_value, attrs=attrs,
+                      sharding=sharding)
 
         return array
 
@@ -341,12 +359,13 @@ def get_array(self, path: str) -> Array:
             if spec["must_understand"]:
                 raise NotImplementedError(spec)
         attrs = meta["attributes"]
+        sharding = meta.get("sharding", None)
 
         # instantiate array
         a = Array(store=self.store, path=path, owner=self,
                   shape=shape, dtype=dtype, chunk_shape=chunk_shape,
                   chunk_separator=chunk_separator, compressor=compressor,
-                  fill_value=fill_value, attrs=attrs)
+                  fill_value=fill_value, attrs=attrs, sharding=sharding)
 
         return a
 
@@ -587,7 +606,15 @@ def __init__(self,
                  chunk_separator: str,
                  compressor: Optional[Codec],
                  fill_value: Any = None,
-                 attrs: Optional[Mapping] = None):
+                 attrs: Optional[Mapping] = None,
+                 sharding: Optional[Dict[str, Any]] = None,
+                 ):
+        if sharding is not None:
+            store = SHARDED_STORES[sharding["format"]](  # type: ignore
+                store=store,
+                chunk_separator=chunk_separator,
+                **sharding,
+            )
         super().__init__(store=store, path=path, owner=owner)
         self.shape = shape
         self.dtype = dtype
@@ -771,7 +798,7 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value):
             encoded_chunk_data = self._encode_chunk(chunk)
 
         # store
-        self.store[chunk_key] = encoded_chunk_data
+        self.store[chunk_key] = encoded_chunk_data.tobytes()
 
 
     def _encode_chunk(self, chunk):
@@ -1146,3 +1173,176 @@ def __repr__(self) -> str:
         if isinstance(protocol, tuple):
             protocol = protocol[-1]
         return f"{protocol}://{self.root}"
+
+
+MAX_UINT_64 = 2 ** 64 - 1
+
+
+def _is_data_key(key: str) -> bool:
+    return key.startswith("data/root")
+
+
+class _ShardIndex(NamedTuple):
+    store: "IndexedShardedStore"
+    offsets_and_lengths: np.ndarray  # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
+
+    def __localize_chunk__(self, chunk: Tuple[int, ...]) -> Tuple[int, ...]:
+        return tuple(chunk_i % shard_i for chunk_i, shard_i in zip(chunk, self.store._chunks_per_shard))
+
+    def get_chunk_slice(self, chunk: Tuple[int, ...]) -> Optional[slice]:
+        localized_chunk = self.__localize_chunk__(chunk)
+        chunk_start, chunk_len = self.offsets_and_lengths[localized_chunk]
+        if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
+            return None
+        else:
+            return slice(chunk_start, chunk_start + chunk_len)
+
+    def set_chunk_slice(self, chunk: Tuple[int, ...], chunk_slice: Optional[slice]) -> None:
+        localized_chunk = self.__localize_chunk__(chunk)
+        if chunk_slice is None:
+            self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64)
+        else:
+            self.offsets_and_lengths[localized_chunk] = (
+                chunk_slice.start,
+                chunk_slice.stop - chunk_slice.start
+            )
+
+    def to_bytes(self) -> bytes:
+        return self.offsets_and_lengths.tobytes(order='C')
+
+    @classmethod
+    def from_bytes(
+        cls, buffer: Union[bytes, bytearray], store: "IndexedShardedStore"
+    ) -> "_ShardIndex":
+        return cls(
+            store=store,
+            offsets_and_lengths=np.frombuffer(
+                bytearray(buffer), dtype="<u8"
+            ).reshape(*store._chunks_per_shard, 2, order="C"),
+        )
+
+    @classmethod
+    def create_empty(cls, store: "IndexedShardedStore"):
+        # Entries of unwritten chunks carry the sentinel 2**64 - 1 for both
+        # offset and length:
+        return cls(
+            store=store,
+            offsets_and_lengths=np.full(
+                store._chunks_per_shard + (2,), MAX_UINT_64, dtype="<u8"
+            ),
+        )
+
+
+class IndexedShardedStore(Store):
+    """Store wrapper that combines multiple chunks into shards on the
+    underlying store, locating chunks via an index appended to each shard."""
+
+    def __init__(
+        self,
+        store: Store,
+        chunk_separator: str,
+        chunks_per_shard: Iterable[int],
+        format: str = "indexed",  # consumed when the "sharding" metadata dict is splatted in
+    ) -> None:
+        self._store = store
+        self._num_chunks_per_shard = functools.reduce(lambda x, y: x*y, chunks_per_shard, 1)
+        self._chunk_separator = chunk_separator
+        assert all(isinstance(s, int) for s in chunks_per_shard)
+        self._chunks_per_shard = tuple(chunks_per_shard)
+
+    def _key_to_shard(
+        self, chunk_key: str
+    ) -> Tuple[str, Tuple[int, ...]]:
+        # Maps a chunk key to the key of the shard containing it,
+        # plus the global chunk coordinates:
+        prefix, _, chunk_string = chunk_key.rpartition("c")
+        chunk_subkeys = tuple(map(int, chunk_string.split(self._chunk_separator)))
+        shard_key_tuple = (
+            subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._chunks_per_shard)
+        )
+        shard_key = prefix + "c" + self._chunk_separator.join(map(str, shard_key_tuple))
+        return shard_key, chunk_subkeys
+
+    def _get_index(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
+        # At the end of each shard, 2 * 64 bit per chunk (offset and length) form the index:
+        return _ShardIndex.from_bytes(buffer[-16 * self._num_chunks_per_shard:], self)
+
+    def _get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
+        _, _, chunk_string = shard_key.rpartition("c")
+        shard_key_tuple = tuple(map(int, chunk_string.split(self._chunk_separator)))
+        for chunk_offset in itertools.product(*(range(i) for i in self._chunks_per_shard)):
+            yield tuple(
+                shard_key_i * shards_i + offset_i
+                for shard_key_i, offset_i, shards_i
+                in zip(shard_key_tuple, chunk_offset, self._chunks_per_shard)
+            )
+
+    def __getitem__(self, key: str, default: Optional[bytes] = None) -> bytes:
+        if _is_data_key(key):
+            shard_key, chunk_subkeys = self._key_to_shard(key)
+            full_shard_value = self._store[shard_key]
+            index = self._get_index(full_shard_value)
+            chunk_slice = index.get_chunk_slice(chunk_subkeys)
+            if chunk_slice is not None:
+                return full_shard_value[chunk_slice]
+            else:
+                if default is not None:
+                    return default
+                raise KeyError(key)
+        else:
+            return self._store.__getitem__(key, default)
+
+    def __setitem__(self, key: str, value: bytes) -> None:
+        if _is_data_key(key):
+            shard_key, chunk_subkeys = self._key_to_shard(key)
+            chunks_to_read = set(self._get_chunks_in_shard(shard_key))
+            chunks_to_read.remove(chunk_subkeys)
+            new_content = {chunk_subkeys: value}
+            try:
+                full_shard_value = self._store[shard_key]
+            except KeyError:
+                index = _ShardIndex.create_empty(self)
+            else:
+                # Keep all other chunks that already live in this shard:
+                index = self._get_index(full_shard_value)
+                for chunk_to_read in chunks_to_read:
+                    chunk_slice = index.get_chunk_slice(chunk_to_read)
+                    if chunk_slice is not None:
+                        new_content[chunk_to_read] = full_shard_value[chunk_slice]
+
+            # Rewrite the complete shard, packing the chunks back to back:
+            shard_content = b""
+            for chunk_subkeys, chunk_content in new_content.items():
+                chunk_slice = slice(len(shard_content), len(shard_content) + len(chunk_content))
+                index.set_chunk_slice(chunk_subkeys, chunk_slice)
+                shard_content += chunk_content
+            # Appending the index at the end of the shard:
+            shard_content += index.to_bytes()
+            self._store[shard_key] = shard_content
+        else:
+            self._store[key] = value
+
+    def _shard_key_to_original_keys(self, key: str) -> Iterator[str]:
+        if not _is_data_key(key):
+            # Special keys such as meta-keys are passed on as-is
+            yield key
+        else:
+            index = self._get_index(self._store[key])
+            prefix, _, _ = key.rpartition("c")
+            for chunk_tuple in self._get_chunks_in_shard(key):
+                if index.get_chunk_slice(chunk_tuple) is not None:
+                    yield prefix + "c" + self._chunk_separator.join(map(str, chunk_tuple))
+
+    def __iter__(self) -> Iterator[str]:
+        for key in self._store:
+            yield from self._shard_key_to_original_keys(key)
+
+    def list_prefix(self, prefix: str) -> List[str]:
+        if _is_data_key(prefix):
+            # Needs translation of the prefix to shard_key
+            raise NotImplementedError
+        return self._store.list_prefix(prefix)
+
+    def list_dir(self, prefix: str) -> ListDirResult:
+        if _is_data_key(prefix):
+            # Needs translation of the prefix to shard_key
+            raise NotImplementedError
+        return self._store.list_dir(prefix)
+
+
+SHARDED_STORES: Dict[str, Type[Store]] = {
+    "indexed": IndexedShardedStore,
+}
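
The chunk-to-shard key arithmetic implemented by _key_to_shard and __localize_chunk__ above can be checked in isolation. A quick standalone illustration (not part of the patch), using the chunks_per_shard = (2, 2) configuration of the test:

    chunks_per_shard = (2, 2)

    chunk_key = "data/root/testarray/c5/0"  # global chunk (5, 0), written by a[15, 1] = 389
    prefix, _, chunk_string = chunk_key.rpartition("c")
    chunk_subkeys = tuple(map(int, chunk_string.split("/")))  # (5, 0)

    # Shard key: integer-divide each chunk coordinate by the chunks per shard along that axis:
    shard_key = prefix + "c" + "/".join(
        str(c // s) for c, s in zip(chunk_subkeys, chunks_per_shard)
    )
    print(shard_key)  # data/root/testarray/c2/0, i.e. '2/0' in the UNDERLYING STORE listing above

    # Position of the chunk within its shard's index: the remainder along each axis:
    print(tuple(c % s for c, s in zip(chunk_subkeys, chunks_per_shard)))  # (1, 0)
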