Skip to content

Commit 4f51d23

Browse files
committed
add oindex method to AsyncArray
1 parent 481550a commit 4f51d23

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

src/zarr/core/array.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
MaskIndexer,
8080
MaskSelection,
8181
OIndex,
82+
AsyncOIndex,
8283
OrthogonalIndexer,
8384
OrthogonalSelection,
8485
Selection,
@@ -1358,6 +1359,21 @@ async def getitem(
13581359
)
13591360
return await self._get_selection(indexer, prototype=prototype)
13601361

1362+
async def get_orthogonal_selection(
1363+
self,
1364+
selection: OrthogonalSelection,
1365+
*,
1366+
out: NDBuffer | None = None,
1367+
fields: Fields | None = None,
1368+
prototype: BufferPrototype | None = None,
1369+
) -> NDArrayLike:
1370+
if prototype is None:
1371+
prototype = default_buffer_prototype()
1372+
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
1373+
return await self._async_array._get_selection(
1374+
indexer=indexer, out=out, fields=fields, prototype=prototype
1375+
)
1376+
13611377
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
13621378
"""
13631379
Asynchronously save the array metadata.
@@ -1488,6 +1504,12 @@ async def setitem(
14881504
)
14891505
return await self._set_selection(indexer, value, prototype=prototype)
14901506

1507+
@property
1508+
def oindex(self) -> AsyncOIndex:
1509+
"""Shortcut for orthogonal (outer) indexing, see :func:`get_orthogonal_selection` and
1510+
:func:`set_orthogonal_selection` for documentation and examples."""
1511+
return AsyncOIndex(self)
1512+
14911513
async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None:
14921514
"""
14931515
Asynchronously resize the array to a new shape.

src/zarr/core/indexing.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from zarr.core.common import product
2929

3030
if TYPE_CHECKING:
31-
from zarr.core.array import Array
31+
from zarr.core.array import Array, AsyncArray
3232
from zarr.core.buffer import NDArrayLikeOrScalar
3333
from zarr.core.chunk_grids import ChunkGrid
3434
from zarr.core.common import ChunkCoords
@@ -950,7 +950,7 @@ def __getitem__(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrSc
950950
return self.array.get_orthogonal_selection(
951951
cast(OrthogonalSelection, new_selection), fields=fields
952952
)
953-
953+
954954
def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> None:
955955
fields, new_selection = pop_fields(selection)
956956
new_selection = ensure_tuple(new_selection)
@@ -960,6 +960,25 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N
960960
)
961961

962962

963+
@dataclass(frozen=True)
964+
class AsyncOIndex:
965+
array: AsyncArray
966+
967+
async def getitem(self, selection: OrthogonalSelection | Array) -> NDArrayLike:
968+
from zarr.core.array import Array
969+
970+
# if input is a Zarr array, we materialize it now.
971+
if isinstance(selection, Array):
972+
selection = _zarr_array_to_int_or_bool_array(selection)
973+
974+
fields, new_selection = pop_fields(selection)
975+
new_selection = ensure_tuple(new_selection)
976+
new_selection = replace_lists(new_selection)
977+
return await self.array.get_orthogonal_selection(
978+
cast(OrthogonalSelection, new_selection), fields=fields
979+
)
980+
981+
963982
@dataclass(frozen=True)
964983
class BlockIndexer(Indexer):
965984
dim_indexers: list[SliceDimIndexer]

0 commit comments

Comments
 (0)