88import warnings
99from abc import abstractmethod
1010from collections .abc import Sequence
11- from dataclasses import dataclass
11+ from dataclasses import dataclass , replace
1212from functools import cached_property , reduce
13- from typing import TYPE_CHECKING , Any , Literal , TypeAlias , TypedDict , Union
13+ from typing import TYPE_CHECKING , Any , Literal , TypeAlias , TypedDict , Union , cast
1414
1515import numpy as np
1616
3232
3333
3434# Type alias for chunk edge length specification
35- # Can be either an integer or a run-length encoded tuple [value, count]
36- ChunkEdgeLength = int | tuple [int , int ]
35+ # Can be either an integer or a run-length encoded pair [value, count]
36+ # The pair can be a tuple or list (common in JSON/test code)
37+ ChunkEdgeLength = int | tuple [int , int ] | list [int ]
3738
3839# User-facing chunk specification types
3940# Note: ChunkGrid is defined later in this file but can be used via string literal
4041ChunksLike : TypeAlias = Union [
4142 tuple [int , ...], # Regular chunks: (10, 10) → RegularChunkGrid
4243 int , # Uniform chunks: 10 → RegularChunkGrid
43- Sequence [Sequence [int ]], # Variable chunks: [[10,20],[5,5]] → RectilinearChunkGrid
44+ Sequence [Sequence [ChunkEdgeLength ]], # Variable chunks with optional RLE → RectilinearChunkGrid
4445 "ChunkGrid" , # Explicit ChunkGrid instance (forward reference)
4546 Literal ["auto" ], # Auto-chunking → RegularChunkGrid
4647]
@@ -295,6 +296,7 @@ def _normalize_chunks(chunks: Any, shape: tuple[int, ...], typesize: int) -> tup
295296 chunks = tuple (int (chunks ) for _ in shape )
296297
297298 # handle dask-style chunks (iterable of iterables)
299+ # TODO
298300 if all (isinstance (c , (tuple | list )) for c in chunks ):
299301 # Check for irregular chunks and warn user
300302 for dim_idx , c in enumerate (chunks ):
@@ -338,13 +340,20 @@ def from_dict(cls, data: dict[str, JSON] | ChunkGrid) -> ChunkGrid:
338340 if isinstance (data , ChunkGrid ):
339341 return data
340342
341- name_parsed , _ = parse_named_configuration (data )
343+ # After isinstance check, data must be dict[str, JSON]
344+ # Cast needed for older mypy versions that don't narrow types properly
345+ data_dict = cast (dict [str , JSON ], data ) # type: ignore[redundant-cast]
346+ name_parsed , _ = parse_named_configuration (data_dict )
342347 if name_parsed == "regular" :
343- return RegularChunkGrid ._from_dict (data )
348+ return RegularChunkGrid ._from_dict (data_dict )
344349 elif name_parsed == "rectilinear" :
345- return RectilinearChunkGrid ._from_dict (data )
350+ return RectilinearChunkGrid ._from_dict (data_dict )
346351 raise ValueError (f"Unknown chunk grid. Got { name_parsed } ." )
347352
353+ @abstractmethod
354+ def update_shape (self , new_shape : tuple [int , ...]) -> Self :
355+ pass
356+
348357 @abstractmethod
349358 def all_chunk_coords (self , array_shape : tuple [int , ...]) -> Iterator [tuple [int , ...]]:
350359 pass
@@ -466,6 +475,9 @@ def _from_dict(cls, data: dict[str, JSON]) -> Self:
466475 def to_dict (self ) -> dict [str , JSON ]:
467476 return {"name" : "regular" , "configuration" : {"chunk_shape" : tuple (self .chunk_shape )}}
468477
478+ def update_shape (self , new_shape : tuple [int , ...]) -> Self :
479+ return self
480+
469481 def all_chunk_coords (self , array_shape : tuple [int , ...]) -> Iterator [tuple [int , ...]]:
470482 return itertools .product (
471483 * (range (ceildiv (s , c )) for s , c in zip (array_shape , self .chunk_shape , strict = False ))
@@ -645,6 +657,46 @@ def to_dict(self) -> dict[str, JSON]:
645657 },
646658 }
647659
660+ def update_shape (self , new_shape : tuple [int , ...]) -> Self :
661+ """TODO - write docstring"""
662+
663+ if len (new_shape ) != len (self .chunk_shapes ):
664+ raise ValueError (
665+ f"new_shape has { len (new_shape )} dimensions but "
666+ f"chunk_shapes has { len (self .chunk_shapes )} dimensions"
667+ )
668+
669+ new_chunk_shapes : list [tuple [int , ...]] = []
670+ for dim in range (len (new_shape )):
671+ old_dim_length = sum (self .chunk_shapes [dim ])
672+ new_dim_chunks : tuple [int , ...]
673+ if new_shape [dim ] == old_dim_length :
674+ new_dim_chunks = self .chunk_shapes [dim ] # no changes
675+
676+ elif new_shape [dim ] > old_dim_length :
677+ # we have a decision to make on chunk size...
678+ # options:
679+ # - repeat the last chunk size
680+ # - use the size of the new data
681+ # - use some other heuristic
682+ # for now, we'll use the difference in size between the old and new dim length
683+ new_dim_chunks = (* self .chunk_shapes [dim ], new_shape [dim ] - old_dim_length )
684+ else :
685+ # drop chunk sizes that are not inside the shape anymore
686+ total = 0
687+ i = 0
688+ for c in self .chunk_shapes [dim ]:
689+ i += 1
690+ total += c
691+ if total >= new_shape [dim ]:
692+ break
693+ # keep the last chunk (it may be too long)
694+ new_dim_chunks = self .chunk_shapes [dim ][:i ]
695+
696+ new_chunk_shapes .append (new_dim_chunks )
697+
698+ return replace (self , chunk_shapes = tuple (new_chunk_shapes ))
699+
648700 def all_chunk_coords (self , array_shape : tuple [int , ...]) -> Iterator [tuple [int , ...]]:
649701 """
650702 Generate all chunk coordinates for the given array shape.
@@ -664,22 +716,9 @@ def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int,
664716 ValueError
665717 If array_shape doesn't match chunk_shapes
666718 """
667- if len (array_shape ) != len (self .chunk_shapes ):
668- raise ValueError (
669- f"array_shape has { len (array_shape )} dimensions but "
670- f"chunk_shapes has { len (self .chunk_shapes )} dimensions"
671- )
672719
673- # Validate that chunk sizes sum to array shape
674- for axis , (arr_size , axis_chunks ) in enumerate (
675- zip (array_shape , self .chunk_shapes , strict = False )
676- ):
677- chunk_sum = sum (axis_chunks )
678- if chunk_sum != arr_size :
679- raise ValueError (
680- f"Sum of chunk sizes along axis { axis } is { chunk_sum } "
681- f"but array shape is { arr_size } "
682- )
720+ # check array_shape is compatible with chunk grid
721+ self ._validate_array_shape (array_shape )
683722
684723 # Generate coordinates
685724 # For each axis, we have len(axis_chunks) chunks
@@ -705,22 +744,8 @@ def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
705744 ValueError
706745 If array_shape doesn't match chunk_shapes
707746 """
708- if len (array_shape ) != len (self .chunk_shapes ):
709- raise ValueError (
710- f"array_shape has { len (array_shape )} dimensions but "
711- f"chunk_shapes has { len (self .chunk_shapes )} dimensions"
712- )
713-
714- # Validate that chunk sizes sum to array shape
715- for axis , (arr_size , axis_chunks ) in enumerate (
716- zip (array_shape , self .chunk_shapes , strict = False )
717- ):
718- chunk_sum = sum (axis_chunks )
719- if chunk_sum != arr_size :
720- raise ValueError (
721- f"Sum of chunk sizes along axis { axis } is { chunk_sum } "
722- f"but array shape is { arr_size } "
723- )
747+ # check array_shape is compatible with chunk grid
748+ self ._validate_array_shape (array_shape )
724749
725750 # Total chunks is the product of number of chunks per axis
726751 return reduce (operator .mul , (len (axis_chunks ) for axis_chunks in self .chunk_shapes ), 1 )
@@ -749,10 +774,11 @@ def _validate_array_shape(self, array_shape: tuple[int, ...]) -> None:
749774 zip (array_shape , self .chunk_shapes , strict = False )
750775 ):
751776 chunk_sum = sum (axis_chunks )
752- if chunk_sum != arr_size :
777+ if chunk_sum < arr_size :
753778 raise ValueError (
754779 f"Sum of chunk sizes along axis { axis } is { chunk_sum } "
755- f"but array shape is { arr_size } "
780+ f"but array shape is { arr_size } . This is invalid for the "
781+ "RectilinearChunkGrid."
756782 )
757783
758784 @cached_property
@@ -816,12 +842,6 @@ def get_chunk_start(
816842 """
817843 self ._validate_array_shape (array_shape )
818844
819- if len (chunk_coord ) != len (self .chunk_shapes ):
820- raise IndexError (
821- f"chunk_coord has { len (chunk_coord )} dimensions but "
822- f"chunk_shapes has { len (self .chunk_shapes )} dimensions"
823- )
824-
825845 # Validate chunk coordinates are in bounds
826846 for axis , (coord , axis_chunks ) in enumerate (
827847 zip (chunk_coord , self .chunk_shapes , strict = False )
@@ -869,12 +889,6 @@ def get_chunk_shape(
869889 """
870890 self ._validate_array_shape (array_shape )
871891
872- if len (chunk_coord ) != len (self .chunk_shapes ):
873- raise IndexError (
874- f"chunk_coord has { len (chunk_coord )} dimensions but "
875- f"chunk_shapes has { len (self .chunk_shapes )} dimensions"
876- )
877-
878892 # Validate chunk coordinates are in bounds
879893 for axis , (coord , axis_chunks ) in enumerate (
880894 zip (chunk_coord , self .chunk_shapes , strict = False )
@@ -994,12 +1008,6 @@ def array_index_to_chunk_coord(
9941008 """
9951009 self ._validate_array_shape (array_shape )
9961010
997- if len (array_index ) != len (array_shape ):
998- raise IndexError (
999- f"array_index has { len (array_index )} dimensions but "
1000- f"array_shape has { len (array_shape )} dimensions"
1001- )
1002-
10031011 # Validate array index is in bounds
10041012 for axis , (idx , size ) in enumerate (zip (array_index , array_shape , strict = False )):
10051013 if not (0 <= idx < size ):
@@ -1047,12 +1055,6 @@ def chunks_in_selection(
10471055 """
10481056 self ._validate_array_shape (array_shape )
10491057
1050- if len (selection ) != len (array_shape ):
1051- raise ValueError (
1052- f"selection has { len (selection )} dimensions but "
1053- f"array_shape has { len (array_shape )} dimensions"
1054- )
1055-
10561058 # Normalize slices and find chunk ranges for each axis
10571059 chunk_ranges = []
10581060 for axis , (sel , size ) in enumerate (zip (selection , array_shape , strict = False )):
@@ -1248,17 +1250,17 @@ def _normalize_rectilinear_chunks(
12481250 # Validate that chunks sum to shape for each dimension
12491251 for i , (dim_chunks , dim_size ) in enumerate (zip (chunk_shapes , shape , strict = False )):
12501252 chunk_sum = sum (dim_chunks )
1251- if chunk_sum != dim_size :
1253+ if chunk_sum < dim_size :
12521254 raise ValueError (
12531255 f"Variable chunks along dimension { i } sum to { chunk_sum } "
1254- f"but array shape is { dim_size } . Chunks must sum exactly to shape."
1256+ f"but array shape is { dim_size } . Chunks must sum to be greater than or equal to the shape."
12551257 )
12561258
12571259 return chunk_shapes
12581260
12591261
12601262def parse_chunk_grid (
1261- chunks : tuple [ int , ...] | Sequence [ Sequence [ int ]] | ChunkGrid | Literal [ "auto" ] | int ,
1263+ chunks : ChunksLike ,
12621264 * ,
12631265 shape : ShapeLike ,
12641266 item_size : int = 1 ,
@@ -1277,7 +1279,7 @@ def parse_chunk_grid(
12771279
12781280 Parameters
12791281 ----------
1280- chunks : tuple[int, ...] | Sequence[Sequence[int]] | ChunkGrid | Literal["auto"] | int
1282+ chunks : ChunksLike
12811283 The chunks parameter to parse. Can be:
12821284 - A ChunkGrid instance
12831285 - A nested sequence for variable-sized chunks (supports RLE format)
@@ -1523,7 +1525,7 @@ def _validate_data_compatibility(
15231525
15241526def resolve_chunk_spec (
15251527 * ,
1526- chunks : tuple [ int , ...] | Sequence [ Sequence [ int ]] | ChunkGrid | Literal [ "auto" ] | int ,
1528+ chunks : ChunksLike ,
15271529 shards : ShardsLike | None ,
15281530 shape : tuple [int , ...],
15291531 dtype_itemsize : int ,
@@ -1542,7 +1544,7 @@ def resolve_chunk_spec(
15421544
15431545 Parameters
15441546 ----------
1545- chunks : tuple[int, ...] | Sequence[Sequence[int]] | ChunkGrid | Literal["auto"] | int
1547+ chunks : ChunksLike
15461548 The chunks specification from the user. Can be:
15471549 - A ChunkGrid instance (Zarr v3 only)
15481550 - A nested sequence for variable-sized chunks (Zarr v3 only)
0 commit comments