Skip to content

Commit 8a976d6

Browse files
committed
revert removal of metadata chunk grid attribute
1 parent e4a0372 commit 8a976d6

File tree

5 files changed

+40
-42
lines changed

5 files changed

+40
-42
lines changed

src/zarr/core/array.py

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from asyncio import gather
66
from collections.abc import Iterable
77
from dataclasses import dataclass, field, replace
8-
from functools import cached_property
98
from itertools import starmap
109
from logging import getLogger
1110
from typing import (
@@ -32,7 +31,7 @@
3231
from zarr.codecs._v2 import V2Codec
3332
from zarr.codecs.bytes import BytesCodec
3433
from zarr.core._info import ArrayInfo
35-
from zarr.core.array_spec import ArrayConfig, ArrayConfigLike, ArraySpec, parse_array_config
34+
from zarr.core.array_spec import ArrayConfig, ArrayConfigLike, parse_array_config
3635
from zarr.core.attributes import Attributes
3736
from zarr.core.buffer import (
3837
BufferPrototype,
@@ -42,7 +41,7 @@
4241
default_buffer_prototype,
4342
)
4443
from zarr.core.buffer.cpu import buffer_prototype as cpu_buffer_prototype
45-
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid, _auto_partition, normalize_chunks
44+
from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks
4645
from zarr.core.chunk_key_encodings import (
4746
ChunkKeyEncoding,
4847
ChunkKeyEncodingLike,
@@ -951,13 +950,6 @@ def chunks(self) -> ChunkCoords:
951950
"""
952951
return self.metadata.chunks
953952

954-
@cached_property
955-
def chunk_grid(self) -> ChunkGrid:
956-
if self.metadata.zarr_format == 2:
957-
return RegularChunkGrid(chunk_shape=self.chunks)
958-
else:
959-
return self.metadata.chunk_grid
960-
961953
@property
962954
def shards(self) -> ChunkCoords | None:
963955
"""Returns the shard shape of the Array.
@@ -1281,20 +1273,6 @@ def nbytes(self) -> int:
12811273
"""
12821274
return self.size * self.dtype.itemsize
12831275

1284-
def get_chunk_spec(
1285-
self, _chunk_coords: ChunkCoords, array_config: ArrayConfig, prototype: BufferPrototype
1286-
) -> ArraySpec:
1287-
assert isinstance(self.chunk_grid, RegularChunkGrid), (
1288-
"Currently, only regular chunk grid is supported"
1289-
)
1290-
return ArraySpec(
1291-
shape=self.chunk_grid.chunk_shape,
1292-
dtype=self._zdtype,
1293-
fill_value=self.metadata.fill_value,
1294-
config=array_config,
1295-
prototype=prototype,
1296-
)
1297-
12981276
async def _get_selection(
12991277
self,
13001278
indexer: Indexer,
@@ -1334,7 +1312,7 @@ async def _get_selection(
13341312
[
13351313
(
13361314
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
1337-
self.get_chunk_spec(chunk_coords, _config, prototype=prototype),
1315+
self.metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype),
13381316
chunk_selection,
13391317
out_selection,
13401318
is_complete_chunk,
@@ -1389,7 +1367,7 @@ async def getitem(
13891367
indexer = BasicIndexer(
13901368
selection,
13911369
shape=self.metadata.shape,
1392-
chunk_grid=self.chunk_grid,
1370+
chunk_grid=self.metadata.chunk_grid,
13931371
)
13941372
return await self._get_selection(indexer, prototype=prototype)
13951373

@@ -1464,7 +1442,7 @@ async def _set_selection(
14641442
[
14651443
(
14661444
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
1467-
self.get_chunk_spec(chunk_coords, _config, prototype),
1445+
self.metadata.get_chunk_spec(chunk_coords, _config, prototype),
14681446
chunk_selection,
14691447
out_selection,
14701448
is_complete_chunk,
@@ -1519,7 +1497,7 @@ async def setitem(
15191497
indexer = BasicIndexer(
15201498
selection,
15211499
shape=self.metadata.shape,
1522-
chunk_grid=self.chunk_grid,
1500+
chunk_grid=self.metadata.chunk_grid,
15231501
)
15241502
return await self._set_selection(indexer, value, prototype=prototype)
15251503

@@ -1556,8 +1534,8 @@ async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True)
15561534

15571535
if delete_outside_chunks:
15581536
# Remove all chunks outside of the new shape
1559-
old_chunk_coords = set(self.chunk_grid.all_chunk_coords(self.metadata.shape))
1560-
new_chunk_coords = set(self.chunk_grid.all_chunk_coords(new_shape))
1537+
old_chunk_coords = set(self.metadata.chunk_grid.all_chunk_coords(self.metadata.shape))
1538+
new_chunk_coords = set(self.metadata.chunk_grid.all_chunk_coords(new_shape))
15611539

15621540
async def _delete_key(key: str) -> None:
15631541
await (self.store_path / key).delete()
@@ -2687,7 +2665,7 @@ def get_basic_selection(
26872665
prototype = default_buffer_prototype()
26882666
return sync(
26892667
self._async_array._get_selection(
2690-
BasicIndexer(selection, self.shape, self._async_array.chunk_grid),
2668+
BasicIndexer(selection, self.shape, self.metadata.chunk_grid),
26912669
out=out,
26922670
fields=fields,
26932671
prototype=prototype,
@@ -2787,7 +2765,7 @@ def set_basic_selection(
27872765
"""
27882766
if prototype is None:
27892767
prototype = default_buffer_prototype()
2790-
indexer = BasicIndexer(selection, self.shape, self._async_array.chunk_grid)
2768+
indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid)
27912769
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
27922770

27932771
@_deprecate_positional_args
@@ -2908,7 +2886,7 @@ def get_orthogonal_selection(
29082886
"""
29092887
if prototype is None:
29102888
prototype = default_buffer_prototype()
2911-
indexer = OrthogonalIndexer(selection, self.shape, self._async_array.chunk_grid)
2889+
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
29122890
return sync(
29132891
self._async_array._get_selection(
29142892
indexer=indexer, out=out, fields=fields, prototype=prototype
@@ -3021,7 +2999,7 @@ def set_orthogonal_selection(
30212999
"""
30223000
if prototype is None:
30233001
prototype = default_buffer_prototype()
3024-
indexer = OrthogonalIndexer(selection, self.shape, self._async_array.chunk_grid)
3002+
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
30253003
return sync(
30263004
self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
30273005
)
@@ -3102,7 +3080,7 @@ def get_mask_selection(
31023080

31033081
if prototype is None:
31043082
prototype = default_buffer_prototype()
3105-
indexer = MaskIndexer(mask, self.shape, self._async_array.chunk_grid)
3083+
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
31063084
return sync(
31073085
self._async_array._get_selection(
31083086
indexer=indexer, out=out, fields=fields, prototype=prototype
@@ -3185,7 +3163,7 @@ def set_mask_selection(
31853163
"""
31863164
if prototype is None:
31873165
prototype = default_buffer_prototype()
3188-
indexer = MaskIndexer(mask, self.shape, self._async_array.chunk_grid)
3166+
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
31893167
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
31903168

31913169
@_deprecate_positional_args
@@ -3266,7 +3244,7 @@ def get_coordinate_selection(
32663244
"""
32673245
if prototype is None:
32683246
prototype = default_buffer_prototype()
3269-
indexer = CoordinateIndexer(selection, self.shape, self._async_array.chunk_grid)
3247+
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
32703248
out_array = sync(
32713249
self._async_array._get_selection(
32723250
indexer=indexer, out=out, fields=fields, prototype=prototype
@@ -3352,7 +3330,7 @@ def set_coordinate_selection(
33523330
if prototype is None:
33533331
prototype = default_buffer_prototype()
33543332
# setup indexer
3355-
indexer = CoordinateIndexer(selection, self.shape, self._async_array.chunk_grid)
3333+
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
33563334

33573335
# handle value - need ndarray-like flatten value
33583336
if not is_scalar(value, self.dtype):
@@ -3468,7 +3446,7 @@ def get_block_selection(
34683446
"""
34693447
if prototype is None:
34703448
prototype = default_buffer_prototype()
3471-
indexer = BlockIndexer(selection, self.shape, self._async_array.chunk_grid)
3449+
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
34723450
return sync(
34733451
self._async_array._get_selection(
34743452
indexer=indexer, out=out, fields=fields, prototype=prototype
@@ -3562,7 +3540,7 @@ def set_block_selection(
35623540
"""
35633541
if prototype is None:
35643542
prototype = default_buffer_prototype()
3565-
indexer = BlockIndexer(selection, self.shape, self._async_array.chunk_grid)
3543+
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
35663544
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
35673545

35683546
@property

src/zarr/core/metadata/v2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import base64
44
import warnings
55
from collections.abc import Iterable, Sequence
6+
from functools import cached_property
67
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict
78

89
import numcodecs.abc
910

1011
from zarr.abc.metadata import Metadata
12+
from zarr.core.chunk_grids import RegularChunkGrid
1113
from zarr.core.dtype import get_data_type_from_native_dtype
1214
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, TDType_co, TScalar_co, ZDType
1315

@@ -103,6 +105,10 @@ def __init__(
103105
def ndim(self) -> int:
104106
return len(self.shape)
105107

108+
@cached_property
109+
def chunk_grid(self) -> RegularChunkGrid:
110+
return RegularChunkGrid(chunk_shape=self.chunks)
111+
106112
@property
107113
def shards(self) -> ChunkCoords | None:
108114
return None

src/zarr/core/metadata/v3.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,20 @@ def inner_codecs(self) -> tuple[Codec, ...]:
269269
return self.codecs[0].codecs
270270
return self.codecs
271271

272+
def get_chunk_spec(
273+
self, _chunk_coords: ChunkCoords, array_config: ArrayConfig, prototype: BufferPrototype
274+
) -> ArraySpec:
275+
assert isinstance(self.chunk_grid, RegularChunkGrid), (
276+
"Currently, only regular chunk grid is supported"
277+
)
278+
return ArraySpec(
279+
shape=self.chunk_grid.chunk_shape,
280+
dtype=self.dtype,
281+
fill_value=self.fill_value,
282+
config=array_config,
283+
prototype=prototype,
284+
)
285+
272286
def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
273287
return self.chunk_key_encoding.encode_chunk_key(chunk_coords)
274288

tests/test_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1363,7 +1363,7 @@ async def test_with_data(impl: Literal["sync", "async"], store: Store) -> None:
13631363
elif impl == "async":
13641364
arr = await create_array(store, name=name, data=data, zarr_format=3)
13651365
stored = await arr._get_selection(
1366-
BasicIndexer(..., shape=arr.shape, chunk_grid=arr.chunk_grid),
1366+
BasicIndexer(..., shape=arr.shape, chunk_grid=arr.metadata.chunk_grid),
13671367
prototype=default_buffer_prototype(),
13681368
)
13691369
else:

tests/test_group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,7 @@ async def test_asyncgroup_create_array(
10071007
assert subnode.dtype == dtype
10081008
# todo: fix the type annotation of array.metadata.chunk_grid so that we get some autocomplete
10091009
# here.
1010-
assert subnode.chunk_grid.chunk_shape == chunk_shape
1010+
assert subnode.metadata.chunk_grid.chunk_shape == chunk_shape
10111011
assert subnode.metadata.zarr_format == zarr_format
10121012

10131013

0 commit comments

Comments
 (0)