 )
 from zarr.core.common import ChunkCoords, concurrent_map
 from zarr.core.config import config
-from zarr.core.indexing import SelectorTuple, is_scalar
+from zarr.core.indexing import SelectorTuple, is_scalar, is_total_slice
 from zarr.core.metadata.v2 import _default_fill_value
 from zarr.registry import register_pipeline
 
@@ -243,18 +243,18 @@ async def encode_partial_batch(
 
     async def read_batch(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
         if self.supports_partial_decode:
             chunk_array_batch = await self.decode_partial_batch(
                 [
                     (byte_getter, chunk_selection, chunk_spec)
-                    for byte_getter, chunk_spec, chunk_selection, *_ in batch_info
+                    for byte_getter, chunk_spec, chunk_selection, _ in batch_info
                 ]
             )
-            for chunk_array, (_, chunk_spec, _, out_selection, _) in zip(
+            for chunk_array, (_, chunk_spec, _, out_selection) in zip(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:
@@ -263,19 +263,22 @@ async def read_batch(
                     out[out_selection] = fill_value_or_default(chunk_spec)
         else:
             chunk_bytes_batch = await concurrent_map(
-                [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
+                [
+                    (byte_getter, array_spec.prototype)
+                    for byte_getter, array_spec, _, _ in batch_info
+                ],
                 lambda byte_getter, prototype: byte_getter.get(prototype),
                 config.get("async.concurrency"),
             )
             chunk_array_batch = await self.decode_batch(
                 [
                     (chunk_bytes, chunk_spec)
-                    for chunk_bytes, (_, chunk_spec, *_) in zip(
+                    for chunk_bytes, (_, chunk_spec, _, _) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
             )
-            for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
+            for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:
@@ -293,10 +296,9 @@ def _merge_chunk_array(
         out_selection: SelectorTuple,
         chunk_spec: ArraySpec,
         chunk_selection: SelectorTuple,
-        is_complete_chunk: bool,
         drop_axes: tuple[int, ...],
     ) -> NDBuffer:
-        if is_complete_chunk and value.shape == chunk_spec.shape:
+        if is_total_slice(chunk_selection, chunk_spec.shape) and value.shape == chunk_spec.shape:
             return value
         if existing_chunk_array is None:
             chunk_array = chunk_spec.prototype.nd_buffer.create(
@@ -325,7 +327,7 @@ def _merge_chunk_array(
 
     async def write_batch(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
@@ -335,14 +337,14 @@ async def write_batch(
                 await self.encode_partial_batch(
                     [
                         (byte_setter, value, chunk_selection, chunk_spec)
-                        for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
+                        for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
                     ],
                 )
             else:
                 await self.encode_partial_batch(
                     [
                         (byte_setter, value[out_selection], chunk_selection, chunk_spec)
-                        for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
+                        for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
                     ],
                 )
 
@@ -359,43 +361,33 @@ async def _read_key(
             chunk_bytes_batch = await concurrent_map(
                 [
                     (
-                        None if is_complete_chunk else byte_setter,
+                        None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,
                         chunk_spec.prototype,
                     )
-                    for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
+                    for byte_setter, chunk_spec, chunk_selection, _ in batch_info
                 ],
                 _read_key,
                 config.get("async.concurrency"),
             )
             chunk_array_decoded = await self.decode_batch(
                 [
                     (chunk_bytes, chunk_spec)
-                    for chunk_bytes, (_, chunk_spec, *_) in zip(
+                    for chunk_bytes, (_, chunk_spec, _, _) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
             )
 
             chunk_array_merged = [
                 self._merge_chunk_array(
-                    chunk_array,
-                    value,
-                    out_selection,
-                    chunk_spec,
-                    chunk_selection,
-                    is_complete_chunk,
-                    drop_axes,
+                    chunk_array, value, out_selection, chunk_spec, chunk_selection, drop_axes
+                )
+                for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
+                    chunk_array_decoded, batch_info, strict=False
                 )
-                for chunk_array, (
-                    _,
-                    chunk_spec,
-                    chunk_selection,
-                    out_selection,
-                    is_complete_chunk,
-                ) in zip(chunk_array_decoded, batch_info, strict=False)
             ]
             chunk_array_batch: list[NDBuffer | None] = []
-            for chunk_array, (_, chunk_spec, *_) in zip(
+            for chunk_array, (_, chunk_spec, _, _) in zip(
                 chunk_array_merged, batch_info, strict=False
             ):
                 if chunk_array is None:
@@ -411,7 +403,7 @@ async def _read_key(
             chunk_bytes_batch = await self.encode_batch(
                 [
                     (chunk_array, chunk_spec)
-                    for chunk_array, (_, chunk_spec, *_) in zip(
+                    for chunk_array, (_, chunk_spec, _, _) in zip(
                         chunk_array_batch, batch_info, strict=False
                     )
                 ],
@@ -426,7 +418,7 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non
         await concurrent_map(
             [
                 (byte_setter, chunk_bytes)
-                for chunk_bytes, (byte_setter, *_) in zip(
+                for chunk_bytes, (byte_setter, _, _, _) in zip(
                     chunk_bytes_batch, batch_info, strict=False
                 )
             ],
@@ -454,7 +446,7 @@ async def encode(
 
     async def read(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
@@ -469,7 +461,7 @@ async def read(
 
     async def write(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
0 commit comments