Skip to content

Commit fff6a2c

Browse files
authored
Merge branch 'v3' into hierarchy_api
2 parents b98c06c + 661acb3 commit fff6a2c

File tree

25 files changed, with 438 additions and 252 deletions.

src/zarr/abc/codec.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,10 @@
1313
if TYPE_CHECKING:
1414
from typing_extensions import Self
1515

16-
from zarr.common import ArraySpec
16+
from zarr.array_spec import ArraySpec
1717
from zarr.indexing import SelectorTuple
1818
from zarr.metadata import ArrayMetadata
1919

20-
2120
CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
2221
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)
2322

src/zarr/abc/store.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from collections.abc import AsyncGenerator
33
from typing import Protocol, runtime_checkable
44

5-
from zarr.buffer import Buffer
5+
from zarr.buffer import Buffer, BufferPrototype
66
from zarr.common import BytesLike, OpenMode
77

88

@@ -30,7 +30,10 @@ def _check_writable(self) -> None:
3030

3131
@abstractmethod
3232
async def get(
33-
self, key: str, byte_range: tuple[int | None, int | None] | None = None
33+
self,
34+
key: str,
35+
prototype: BufferPrototype,
36+
byte_range: tuple[int | None, int | None] | None = None,
3437
) -> Buffer | None:
3538
"""Retrieve the value associated with a given key.
3639
@@ -47,7 +50,9 @@ async def get(
4750

4851
@abstractmethod
4952
async def get_partial_values(
50-
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
53+
self,
54+
prototype: BufferPrototype,
55+
key_ranges: list[tuple[str, tuple[int | None, int | None]]],
5156
) -> list[Buffer | None]:
5257
"""Retrieve possibly partial values from given key_ranges.
5358
@@ -175,12 +180,16 @@ def close(self) -> None: # noqa: B027
175180

176181
@runtime_checkable
177182
class ByteGetter(Protocol):
178-
async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
183+
async def get(
184+
self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
185+
) -> Buffer | None: ...
179186

180187

181188
@runtime_checkable
182189
class ByteSetter(Protocol):
183-
async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
190+
async def get(
191+
self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
192+
) -> Buffer | None: ...
184193

185194
async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: ...
186195

src/zarr/array.py

Lines changed: 79 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from zarr.abc.codec import Codec
2121
from zarr.abc.store import set_or_delete
2222
from zarr.attributes import Attributes
23-
from zarr.buffer import Factory, NDArrayLike, NDBuffer
23+
from zarr.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
2424
from zarr.chunk_grids import RegularChunkGrid
2525
from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding
2626
from zarr.codecs import BytesCodec
@@ -414,8 +414,8 @@ async def _get_selection(
414414
self,
415415
indexer: Indexer,
416416
*,
417+
prototype: BufferPrototype,
417418
out: NDBuffer | None = None,
418-
factory: Factory.Create = NDBuffer.create,
419419
fields: Fields | None = None,
420420
) -> NDArrayLike:
421421
# check fields are sensible
@@ -432,7 +432,7 @@ async def _get_selection(
432432
f"shape of out argument doesn't match. Expected {indexer.shape}, got {out.shape}"
433433
)
434434
else:
435-
out_buffer = factory(
435+
out_buffer = prototype.nd_buffer.create(
436436
shape=indexer.shape,
437437
dtype=out_dtype,
438438
order=self.order,
@@ -444,7 +444,7 @@ async def _get_selection(
444444
[
445445
(
446446
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
447-
self.metadata.get_chunk_spec(chunk_coords, self.order),
447+
self.metadata.get_chunk_spec(chunk_coords, self.order, prototype=prototype),
448448
chunk_selection,
449449
out_selection,
450450
)
@@ -456,14 +456,14 @@ async def _get_selection(
456456
return out_buffer.as_ndarray_like()
457457

458458
async def getitem(
459-
self, selection: Selection, *, factory: Factory.Create = NDBuffer.create
459+
self, selection: Selection, *, prototype: BufferPrototype = default_buffer_prototype
460460
) -> NDArrayLike:
461461
indexer = BasicIndexer(
462462
selection,
463463
shape=self.metadata.shape,
464464
chunk_grid=self.metadata.chunk_grid,
465465
)
466-
return await self._get_selection(indexer, factory=factory)
466+
return await self._get_selection(indexer, prototype=prototype)
467467

468468
async def _save_metadata(self, metadata: ArrayMetadata) -> None:
469469
to_save = metadata.to_buffer_dict()
@@ -475,7 +475,7 @@ async def _set_selection(
475475
indexer: Indexer,
476476
value: NDArrayLike,
477477
*,
478-
factory: Factory.NDArrayLike = NDBuffer.from_ndarray_like,
478+
prototype: BufferPrototype,
479479
fields: Fields | None = None,
480480
) -> None:
481481
# check fields are sensible
@@ -497,14 +497,14 @@ async def _set_selection(
497497
# We accept any ndarray like object from the user and convert it
498498
# to a NDBuffer (or subclass). From this point onwards, we only pass
499499
# Buffer and NDBuffer between components.
500-
value_buffer = factory(value)
500+
value_buffer = prototype.nd_buffer.from_ndarray_like(value)
501501

502502
# merging with existing data and encoding chunks
503503
await self.metadata.codec_pipeline.write(
504504
[
505505
(
506506
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
507-
self.metadata.get_chunk_spec(chunk_coords, self.order),
507+
self.metadata.get_chunk_spec(chunk_coords, self.order, prototype),
508508
chunk_selection,
509509
out_selection,
510510
)
@@ -518,14 +518,14 @@ async def setitem(
518518
self,
519519
selection: Selection,
520520
value: NDArrayLike,
521-
factory: Factory.NDArrayLike = NDBuffer.from_ndarray_like,
521+
prototype: BufferPrototype = default_buffer_prototype,
522522
) -> None:
523523
indexer = BasicIndexer(
524524
selection,
525525
shape=self.metadata.shape,
526526
chunk_grid=self.metadata.chunk_grid,
527527
)
528-
return await self._set_selection(indexer, value, factory=factory)
528+
return await self._set_selection(indexer, value, prototype=prototype)
529529

530530
async def resize(
531531
self, new_shape: ChunkCoords, delete_outside_chunks: bool = True
@@ -712,7 +712,9 @@ def __setitem__(self, selection: Selection, value: NDArrayLike) -> None:
712712
def get_basic_selection(
713713
self,
714714
selection: BasicSelection = Ellipsis,
715+
*,
715716
out: NDBuffer | None = None,
717+
prototype: BufferPrototype = default_buffer_prototype,
716718
fields: Fields | None = None,
717719
) -> NDArrayLike:
718720
if self.shape == ():
@@ -723,57 +725,101 @@ def get_basic_selection(
723725
BasicIndexer(selection, self.shape, self.metadata.chunk_grid),
724726
out=out,
725727
fields=fields,
728+
prototype=prototype,
726729
)
727730
)
728731

729732
def set_basic_selection(
730-
self, selection: BasicSelection, value: NDArrayLike, fields: Fields | None = None
733+
self,
734+
selection: BasicSelection,
735+
value: NDArrayLike,
736+
*,
737+
fields: Fields | None = None,
738+
prototype: BufferPrototype = default_buffer_prototype,
731739
) -> None:
732740
indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid)
733-
sync(self._async_array._set_selection(indexer, value, fields=fields))
741+
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
734742

735743
def get_orthogonal_selection(
736744
self,
737745
selection: OrthogonalSelection,
746+
*,
738747
out: NDBuffer | None = None,
739748
fields: Fields | None = None,
749+
prototype: BufferPrototype = default_buffer_prototype,
740750
) -> NDArrayLike:
741751
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
742-
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
752+
return sync(
753+
self._async_array._get_selection(
754+
indexer=indexer, out=out, fields=fields, prototype=prototype
755+
)
756+
)
743757

744758
def set_orthogonal_selection(
745-
self, selection: OrthogonalSelection, value: NDArrayLike, fields: Fields | None = None
759+
self,
760+
selection: OrthogonalSelection,
761+
value: NDArrayLike,
762+
*,
763+
fields: Fields | None = None,
764+
prototype: BufferPrototype = default_buffer_prototype,
746765
) -> None:
747766
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
748-
return sync(self._async_array._set_selection(indexer, value, fields=fields))
767+
return sync(
768+
self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
769+
)
749770

750771
def get_mask_selection(
751-
self, mask: MaskSelection, out: NDBuffer | None = None, fields: Fields | None = None
772+
self,
773+
mask: MaskSelection,
774+
*,
775+
out: NDBuffer | None = None,
776+
fields: Fields | None = None,
777+
prototype: BufferPrototype = default_buffer_prototype,
752778
) -> NDArrayLike:
753779
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
754-
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
780+
return sync(
781+
self._async_array._get_selection(
782+
indexer=indexer, out=out, fields=fields, prototype=prototype
783+
)
784+
)
755785

756786
def set_mask_selection(
757-
self, mask: MaskSelection, value: NDArrayLike, fields: Fields | None = None
787+
self,
788+
mask: MaskSelection,
789+
value: NDArrayLike,
790+
*,
791+
fields: Fields | None = None,
792+
prototype: BufferPrototype = default_buffer_prototype,
758793
) -> None:
759794
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
760-
sync(self._async_array._set_selection(indexer, value, fields=fields))
795+
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
761796

762797
def get_coordinate_selection(
763798
self,
764799
selection: CoordinateSelection,
800+
*,
765801
out: NDBuffer | None = None,
766802
fields: Fields | None = None,
803+
prototype: BufferPrototype = default_buffer_prototype,
767804
) -> NDArrayLike:
768805
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
769-
out_array = sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
806+
out_array = sync(
807+
self._async_array._get_selection(
808+
indexer=indexer, out=out, fields=fields, prototype=prototype
809+
)
810+
)
770811

771812
# restore shape
772813
out_array = out_array.reshape(indexer.sel_shape)
773814
return out_array
774815

775816
def set_coordinate_selection(
776-
self, selection: CoordinateSelection, value: NDArrayLike, fields: Fields | None = None
817+
self,
818+
selection: CoordinateSelection,
819+
value: NDArrayLike,
820+
*,
821+
fields: Fields | None = None,
822+
prototype: BufferPrototype = default_buffer_prototype,
777823
) -> None:
778824
# setup indexer
779825
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
@@ -790,25 +836,33 @@ def set_coordinate_selection(
790836
if hasattr(value, "shape") and len(value.shape) > 1:
791837
value = value.reshape(-1)
792838

793-
sync(self._async_array._set_selection(indexer, value, fields=fields))
839+
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
794840

795841
def get_block_selection(
796842
self,
797843
selection: BlockSelection,
844+
*,
798845
out: NDBuffer | None = None,
799846
fields: Fields | None = None,
847+
prototype: BufferPrototype = default_buffer_prototype,
800848
) -> NDArrayLike:
801849
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
802-
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
850+
return sync(
851+
self._async_array._get_selection(
852+
indexer=indexer, out=out, fields=fields, prototype=prototype
853+
)
854+
)
803855

804856
def set_block_selection(
805857
self,
806858
selection: BlockSelection,
807859
value: NDArrayLike,
860+
*,
808861
fields: Fields | None = None,
862+
prototype: BufferPrototype = default_buffer_prototype,
809863
) -> None:
810864
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
811-
sync(self._async_array._set_selection(indexer, value, fields=fields))
865+
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
812866

813867
@property
814868
def vindex(self) -> VIndex:

src/zarr/array_spec.py

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Literal
5+
6+
import numpy as np
7+
8+
from zarr.buffer import BufferPrototype
9+
from zarr.common import ChunkCoords, parse_dtype, parse_fill_value, parse_order, parse_shapelike
10+
11+
12+
@dataclass(frozen=True)
13+
class ArraySpec:
14+
shape: ChunkCoords
15+
dtype: np.dtype[Any]
16+
fill_value: Any
17+
order: Literal["C", "F"]
18+
prototype: BufferPrototype
19+
20+
def __init__(
21+
self,
22+
shape: ChunkCoords,
23+
dtype: np.dtype[Any],
24+
fill_value: Any,
25+
order: Literal["C", "F"],
26+
prototype: BufferPrototype,
27+
) -> None:
28+
shape_parsed = parse_shapelike(shape)
29+
dtype_parsed = parse_dtype(dtype)
30+
fill_value_parsed = parse_fill_value(fill_value)
31+
order_parsed = parse_order(order)
32+
33+
object.__setattr__(self, "shape", shape_parsed)
34+
object.__setattr__(self, "dtype", dtype_parsed)
35+
object.__setattr__(self, "fill_value", fill_value_parsed)
36+
object.__setattr__(self, "order", order_parsed)
37+
object.__setattr__(self, "prototype", prototype)
38+
39+
@property
40+
def ndim(self) -> int:
41+
return len(self.shape)

0 commit comments

Comments
 (0)