
Commit b6bf2dd

Merge branch 'feat/batch-creation' of github.com:d-v-b/zarr-python into feat/batch-creation

2 parents: 97b768f + 986d68b

File tree

3 files changed: +234 −20 lines

src/zarr/core/group.py

Lines changed: 35 additions & 18 deletions
@@ -6,8 +6,9 @@
 import logging
 import warnings
 from collections import defaultdict
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Awaitable
 from dataclasses import asdict, dataclass, field, fields, replace
+from functools import partial
 from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload

 import numpy as np
@@ -55,7 +56,7 @@
 from zarr.storage._common import ensure_no_existing_node

 if TYPE_CHECKING:
-    from collections.abc import AsyncGenerator, Generator, Iterable, Iterator
+    from collections.abc import AsyncGenerator, Callable, Generator, Iterable, Iterator
     from typing import Any

     from zarr.core.array_spec import ArrayConfig, ArrayConfigLike
@@ -1266,7 +1267,7 @@ async def require_array(

     async def create_nodes(
         self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]
-    ) -> tuple[tuple[str, AsyncGroup | AsyncArray]]:
+    ) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]]:
         """
         Create a set of arrays or groups rooted at this group.
         """
@@ -2817,23 +2818,36 @@ def array(
     )


-async def _save_metadata_return_node(
+async def _with_semaphore(
+    func: Callable[[Any], Awaitable[T]], semaphore: asyncio.Semaphore | None = None
+) -> T:
+    if semaphore is None:
+        return await func(None)
+    async with semaphore:
+        return await func(None)
+
+
+async def _save_metadata(
     node: AsyncArray[Any] | AsyncGroup,
 ) -> AsyncArray[Any] | AsyncGroup:
-    if isinstance(node, AsyncArray):
-        await node._save_metadata(node.metadata, ensure_parents=False)
-    else:
-        await node._save_metadata(ensure_parents=False)
+    """
+    Save the metadata for an array or group, and return the array or group
+    """
+    match node:
+        case AsyncArray():
+            await node._save_metadata(node.metadata, ensure_parents=False)
+        case AsyncGroup():
+            await node._save_metadata(ensure_parents=False)
+        case _:
+            raise ValueError(f"Unexpected node type {type(node)}")
     return node


-async def create_nodes_v2(
-    *, store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata]
-) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata]]]: ...
-
-
 async def create_nodes(
-    *, store_path: StorePath, nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata]
+    *,
+    store_path: StorePath,
+    nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata],
+    semaphore: asyncio.Semaphore | None = None,
 ) -> AsyncIterator[AsyncGroup | AsyncArray[Any]]:
     """
     Create a collection of arrays and groups concurrently and atomically. To ensure atomicity,
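
For readers unfamiliar with the pattern, `_with_semaphore` is a small concurrency-limiting combinator: when a semaphore is supplied, each awaitable acquires a slot before running. Below is a minimal, self-contained sketch of the same idea; the names (`work`, `gated`) are hypothetical and not part of this commit.

```python
import asyncio


async def work(i: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for a metadata write
    return i


async def gated(i: int, sem: asyncio.Semaphore | None) -> int:
    # Mirrors the shape of _with_semaphore: no semaphore means unbounded concurrency
    if sem is None:
        return await work(i)
    async with sem:  # with Semaphore(2), at most 2 tasks run work() at once
        return await work(i)


async def main() -> None:
    sem = asyncio.Semaphore(2)
    print(await asyncio.gather(*(gated(i, sem) for i in range(10))))


asyncio.run(main())
```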
@@ -2850,15 +2864,18 @@ async def create_nodes(
                 node = AsyncGroup(value, store_path=new_store_path)
             case _:
                 raise ValueError(f"Unexpected metadata type {type(value)}")
-        create_tasks.append(_save_metadata_return_node(node))
+        partial_func = partial(_save_metadata, node)
+        fut = _with_semaphore(partial_func, semaphore)
+        create_tasks.append(fut)
+
     for coro in asyncio.as_completed(create_tasks):
         yield await coro


 T = TypeVar("T")


-def _tuplize_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]:
+def _split_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]:
     """
     Given a dict of {string: T} pairs, where the keys are strings separated by some separator,
     return the result of splitting each key with the separator.
@@ -2875,10 +2892,10 @@ def _tuplize_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T

     Examples
     --------
-    >>> _tuplize_tree({"a": 1}, separator='/')
+    >>> _split_keys({"a": 1}, separator='/')
     {("a",): 1}

-    >>> _tuplize_tree({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/')
+    >>> _split_keys({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/')
     {("a", "b"): 1, ("a", "b", "c"): 2, ("c",): 3}
     """
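
A hedged usage sketch of the new `create_nodes` signature from this branch: nodes are yielded in completion order, and the `semaphore` parameter caps concurrent metadata writes. The store, path, and node names here are illustrative assumptions, not taken from the commit.

```python
import asyncio

from zarr.core.group import GroupMetadata, create_nodes
from zarr.storage import MemoryStore, StorePath


async def main() -> None:
    store_path = StorePath(MemoryStore(), path="root")
    nodes = {"a": GroupMetadata(), "a/b": GroupMetadata()}
    sem = asyncio.Semaphore(4)  # limit concurrent metadata writes to 4
    async for node in create_nodes(store_path=store_path, nodes=nodes, semaphore=sem):
        print(node)  # yielded as each node's metadata write completes


asyncio.run(main())
```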

tests/conftest.py

Lines changed: 188 additions & 0 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import pathlib
+from collections.abc import Iterable
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING

@@ -10,7 +11,14 @@
 from hypothesis import HealthCheck, Verbosity, settings

 from zarr import AsyncGroup, config
+from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec
 from zarr.abc.store import Store
+from zarr.codecs.bytes import BytesCodec
+from zarr.codecs.sharding import ShardingCodec
+from zarr.core.chunk_grids import _guess_chunks
+from zarr.core.chunk_key_encodings import ChunkKeyEncoding
+from zarr.core.metadata.v2 import ArrayV2Metadata
+from zarr.core.metadata.v3 import ArrayV3Metadata
 from zarr.core.sync import sync
 from zarr.storage import FsspecStore, LocalStore, MemoryStore, StorePath, ZipStore

@@ -159,3 +167,183 @@ def zarr_format(request: pytest.FixtureRequest) -> ZarrFormat:
     suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow],
     verbosity=Verbosity.verbose,
 )
+import numcodecs
+
+
+def meta_from_array_v2(
+    array: np.ndarray[Any, Any],
+    chunks: ChunkCoords | Literal["auto"] = "auto",
+    compressor: numcodecs.abc.Codec | Literal["auto"] | None = "auto",
+    filters: Iterable[numcodecs.abc.Codec] | Literal["auto"] = "auto",
+    fill_value: Any = "auto",
+    order: MemoryOrder | Literal["auto"] = "auto",
+    dimension_separator: Literal[".", "/", "auto"] = "auto",
+    attributes: dict[str, Any] | None = None,
+) -> ArrayV2Metadata:
+    """
+    Create a v2 metadata object from a numpy array
+    """
+
+    _chunks = auto_chunks(chunks, array.shape, array.dtype)
+    _compressor = auto_compressor(compressor)
+    _filters = auto_filters(filters)
+    _fill_value = auto_fill_value(fill_value)
+    _order = auto_order(order)
+    _dimension_separator = auto_dimension_separator(dimension_separator)
+    return ArrayV2Metadata(
+        shape=array.shape,
+        dtype=array.dtype,
+        chunks=_chunks,
+        compressor=_compressor,
+        filters=_filters,
+        fill_value=_fill_value,
+        order=_order,
+        dimension_separator=_dimension_separator,
+        attributes=attributes,
+    )
+
+
+from typing import TypedDict
+
+
+class ChunkEncoding(TypedDict):
+    filters: tuple[ArrayArrayCodec]
+    compressors: tuple[BytesBytesCodec]
+    serializer: ArrayBytesCodec
+
+
+class ChunkingSpec(TypedDict):
+    shard_shape: tuple[int, ...]
+    chunk_shape: tuple[int, ...] | None
+    chunk_key_encoding: ChunkKeyEncoding
+
+
+def meta_from_array_v3(
+    array: np.ndarray[Any, Any],
+    shard_shape: tuple[int, ...] | Literal["auto"] | None,
+    chunk_shape: tuple[int, ...] | Literal["auto"],
+    serializer: ArrayBytesCodec | Literal["auto"] = "auto",
+    compressors: Iterable[BytesBytesCodec] | Literal["auto"] = "auto",
+    filters: Iterable[ArrayArrayCodec] | Literal["auto"] = "auto",
+    fill_value: Any = "auto",
+    chunk_key_encoding: ChunkKeyEncoding | Literal["auto"] = "auto",
+    dimension_names: Iterable[str] | None = None,
+    attributes: dict[str, Any] | None = None,
+) -> ArrayV3Metadata:
+    _write_chunks, _read_chunks = auto_chunks_v3(
+        shard_shape=shard_shape, chunk_shape=chunk_shape, array_shape=array.shape, dtype=array.dtype
+    )
+    _codecs = auto_codecs(serializer=serializer, compressors=compressors, filters=filters)
+    if _read_chunks is not None:
+        _codecs = (ShardingCodec(codecs=_codecs, chunk_shape=_read_chunks),)
+
+    _fill_value = auto_fill_value(fill_value)
+    _chunk_key_encoding = auto_chunk_key_encoding(chunk_key_encoding)
+    return ArrayV3Metadata(
+        shape=array.shape,
+        dtype=array.dtype,
+        codecs=_codecs,
+        chunk_key_encoding=_chunk_key_encoding,
+        fill_value=fill_value,
+        chunk_grid={"name": "regular", "config": {"chunk_shape": shard_shape}},
+        attributes=attributes,
+        dimension_names=dimension_names,
+    )
+
+
+from zarr.abc.codec import Codec
+from zarr.codecs import ZstdCodec
+
+
+def auto_codecs(
+    *,
+    filters: Iterable[ArrayArrayCodec] | Literal["auto"] = "auto",
+    compressors: Iterable[BytesBytesCodec] | Literal["auto"] = "auto",
+    serializer: ArrayBytesCodec | Literal["auto"] = "auto",
+) -> tuple[Codec, ...]:
+    """
+    Heuristically generate a tuple of codecs
+    """
+    _compressors: tuple[BytesBytesCodec, ...]
+    _filters: tuple[ArrayArrayCodec, ...]
+    _serializer: ArrayBytesCodec
+    if filters == "auto":
+        _filters = ()
+    else:
+        _filters = tuple(filters)
+
+    if compressors == "auto":
+        _compressors = (ZstdCodec(level=3),)
+    else:
+        _compressors = tuple(compressors)
+
+    if serializer == "auto":
+        _serializer = BytesCodec()
+    else:
+        _serializer = serializer
+    return (*_filters, _serializer, *_compressors)
+
+
+def auto_dimension_separator(dimension_separator: Literal[".", "/", "auto"]) -> Literal[".", "/"]:
+    if dimension_separator == "auto":
+        return "/"
+    return dimension_separator
+
+
+def auto_order(order: MemoryOrder | Literal["auto"]) -> MemoryOrder:
+    if order == "auto":
+        return "C"
+    return order
+
+
+def auto_fill_value(fill_value: Any) -> Any:
+    if fill_value == "auto":
+        return 0
+    return fill_value
+
+
+def auto_compressor(
+    compressor: numcodecs.abc.Codec | Literal["auto"] | None,
+) -> numcodecs.abc.Codec | None:
+    if compressor == "auto":
+        return numcodecs.Zstd(level=3)
+    return compressor
+
+
+def auto_filters(
+    filters: Iterable[numcodecs.abc.Codec] | Literal["auto"],
+) -> tuple[numcodecs.abc.Codec, ...]:
+    if filters == "auto":
+        return ()
+    return tuple(filters)
+
+
+def auto_chunks(
+    chunks: tuple[int, ...] | Literal["auto"], shape: tuple[int, ...], dtype: npt.DTypeLike
+) -> tuple[int, ...]:
+    if chunks == "auto":
+        return _guess_chunks(shape, np.dtype(dtype).itemsize)
+    return chunks
+
+
+def auto_chunks_v3(
+    *,
+    shard_shape: tuple[int, ...] | Literal["auto"],
+    chunk_shape: tuple[int, ...] | Literal["auto"] | None,
+    array_shape: tuple[int, ...],
+    dtype: npt.DTypeLike,
+) -> tuple[tuple[int, ...], tuple[int, ...] | None]:
+    match (shard_shape, chunk_shape):
+        case ("auto", "auto"):
+            # stupid default but easy to think about
+            return ((256,) * len(array_shape), (64,) * len(array_shape))
+        case ("auto", None):
+            return (_guess_chunks(array_shape, np.dtype(dtype).itemsize), None)
+        case ("auto", _):
+            return (chunk_shape, chunk_shape)
+        case (_, None):
+            return (shard_shape, None)
+        case (_, "auto"):
+            return (shard_shape, shard_shape)
+        case _:
+            return (shard_shape, chunk_shape)
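
A hedged sketch of how the v2 helper could be exercised from a test in this suite. Note that `meta_from_array_v3` calls `auto_chunk_key_encoding`, which does not appear in this diff, so the sketch sticks to the v2 path; the assertions reflect the "auto" defaults defined above and are assumptions, not part of the commit.

```python
import numpy as np

from .conftest import meta_from_array_v2


def test_meta_from_array_v2_defaults() -> None:
    arr = np.zeros((100, 100), dtype="uint8")
    meta = meta_from_array_v2(arr)  # every "auto" resolves to a concrete default

    assert meta.shape == (100, 100)
    # per the helpers above: "auto" order -> "C", "auto" separator -> "/"
    assert meta.order == "C"
    assert meta.dimension_separator == "/"
```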

tests/test_group.py

Lines changed: 11 additions & 2 deletions
@@ -18,12 +18,12 @@
 from zarr.abc.store import Store
 from zarr.core._info import GroupInfo
 from zarr.core.buffer import default_buffer_prototype
-from zarr.core.group import ConsolidatedMetadata, GroupMetadata
+from zarr.core.group import ConsolidatedMetadata, GroupMetadata, create_nodes
 from zarr.core.sync import sync
 from zarr.errors import ContainsArrayError, ContainsGroupError
 from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore, make_store_path

-from .conftest import parse_store
+from .conftest import meta_from_array_v2, parse_store

 if TYPE_CHECKING:
     from _pytest.compat import LEGACY_PATH
@@ -1440,6 +1440,15 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None
     g1["0/0"]


+@pytest.mark.parametrize("store", ["memory"], indirect=True)
+async def test_create_nodes(store: Store) -> None:
+    """
+    Ensure that create_nodes works.
+    """
+    arrays = {str(idx): meta_from_array_v2(np.arange(idx)) for idx in range(1, 5)}
+    spath = await make_store_path(store, path="foo")
+    results = [a async for a in create_nodes(store_path=spath, nodes=arrays)]
+
 @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"])
 def test_deprecated_compressor(store: Store) -> None:
     g = zarr.group(store=store, zarr_format=2)
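
As written, `test_create_nodes` collects `results` but never asserts on it. A minimal sketch of checks one might append; the expected node type and count are assumptions based on the `create_nodes` contract, not part of the commit.

```python
from zarr.core.array import AsyncArray

# Hypothetical follow-up assertions for test_create_nodes:
assert len(results) == len(arrays)
assert all(isinstance(node, AsyncArray) for node in results)
```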
