Skip to content

Commit 6f25f82

Browse files
committed
add support for async vindex
1 parent 535ebaa commit 6f25f82

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

src/zarr/core/array.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868
)
6969
from zarr.core.config import config as zarr_config
7070
from zarr.core.indexing import (
71+
AsyncOIndex,
72+
AsyncVIndex,
7173
BasicIndexer,
7274
BasicSelection,
7375
BlockIndex,
@@ -79,7 +81,6 @@
7981
MaskIndexer,
8082
MaskSelection,
8183
OIndex,
82-
AsyncOIndex,
8384
OrthogonalIndexer,
8485
OrthogonalSelection,
8586
Selection,
@@ -1374,6 +1375,27 @@ async def get_orthogonal_selection(
13741375
indexer=indexer, out=out, fields=fields, prototype=prototype
13751376
)
13761377

1378+
@_deprecate_positional_args
1379+
async def get_coordinate_selection(
1380+
self,
1381+
selection: CoordinateSelection,
1382+
*,
1383+
out: NDBuffer | None = None,
1384+
fields: Fields | None = None,
1385+
prototype: BufferPrototype | None = None,
1386+
) -> NDArrayLikeOrScalar:
1387+
if prototype is None:
1388+
prototype = default_buffer_prototype()
1389+
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
1390+
out_array = await self._get_selection(
1391+
indexer=indexer, out=out, fields=fields, prototype=prototype
1392+
)
1393+
1394+
if hasattr(out_array, "shape"):
1395+
# restore shape
1396+
out_array = np.array(out_array).reshape(indexer.sel_shape)
1397+
return out_array
1398+
13771399
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
13781400
"""
13791401
Asynchronously save the array metadata.
@@ -1510,6 +1532,10 @@ def oindex(self) -> AsyncOIndex:
15101532
:func:`set_orthogonal_selection` for documentation and examples."""
15111533
return AsyncOIndex(self)
15121534

1535+
@property
1536+
def vindex(self) -> AsyncVIndex:
1537+
return AsyncVIndex(self)
1538+
15131539
async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None:
15141540
"""
15151541
Asynchronously resize the array to a new shape.

src/zarr/core/indexing.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ def __getitem__(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrSc
950950
return self.array.get_orthogonal_selection(
951951
cast(OrthogonalSelection, new_selection), fields=fields
952952
)
953-
953+
954954
def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> None:
955955
fields, new_selection = pop_fields(selection)
956956
new_selection = ensure_tuple(new_selection)
@@ -1287,6 +1287,30 @@ def __setitem__(
12871287
raise VindexInvalidSelectionError(new_selection)
12881288

12891289

1290+
@dataclass(frozen=True)
1291+
class AsyncVIndex:
1292+
array: AsyncArray
1293+
1294+
# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
1295+
async def getitem(
1296+
self, selection: CoordinateSelection | MaskSelection | Array
1297+
) -> NDArrayLikeOrScalar:
1298+
from zarr.core.array import Array
1299+
1300+
# if input is a Zarr array, we materialize it now.
1301+
if isinstance(selection, Array):
1302+
selection = _zarr_array_to_int_or_bool_array(selection)
1303+
fields, new_selection = pop_fields(selection)
1304+
new_selection = ensure_tuple(new_selection)
1305+
new_selection = replace_lists(new_selection)
1306+
if is_coordinate_selection(new_selection, self.array.shape):
1307+
return await self.array.get_coordinate_selection(new_selection, fields=fields)
1308+
elif is_mask_selection(new_selection, self.array.shape):
1309+
return self.array.get_mask_selection(new_selection, fields=fields)
1310+
else:
1311+
raise VindexInvalidSelectionError(new_selection)
1312+
1313+
12901314
def check_fields(fields: Fields | None, dtype: np.dtype[Any]) -> np.dtype[Any]:
12911315
# early out
12921316
if fields is None:

0 commit comments

Comments
 (0)