1616)
1717from zarr .core .common import ChunkCoords , concurrent_map
1818from zarr .core .config import config
19- from zarr .core .indexing import SelectorTuple , is_scalar , is_total_slice
19+ from zarr .core .indexing import SelectorTuple , is_scalar
2020from zarr .core .metadata .v2 import _default_fill_value
2121from zarr .registry import register_pipeline
2222
@@ -243,18 +243,18 @@ async def encode_partial_batch(
243243
244244 async def read_batch (
245245 self ,
246- batch_info : Iterable [tuple [ByteGetter , ArraySpec , SelectorTuple , SelectorTuple ]],
246+ batch_info : Iterable [tuple [ByteGetter , ArraySpec , SelectorTuple , SelectorTuple , bool ]],
247247 out : NDBuffer ,
248248 drop_axes : tuple [int , ...] = (),
249249 ) -> None :
250250 if self .supports_partial_decode :
251251 chunk_array_batch = await self .decode_partial_batch (
252252 [
253253 (byte_getter , chunk_selection , chunk_spec )
254- for byte_getter , chunk_spec , chunk_selection , _ in batch_info
254+ for byte_getter , chunk_spec , chunk_selection , * _ in batch_info
255255 ]
256256 )
257- for chunk_array , (_ , chunk_spec , _ , out_selection ) in zip (
257+ for chunk_array , (_ , chunk_spec , _ , out_selection , _ ) in zip (
258258 chunk_array_batch , batch_info , strict = False
259259 ):
260260 if chunk_array is not None :
@@ -263,22 +263,19 @@ async def read_batch(
263263 out [out_selection ] = fill_value_or_default (chunk_spec )
264264 else :
265265 chunk_bytes_batch = await concurrent_map (
266- [
267- (byte_getter , array_spec .prototype )
268- for byte_getter , array_spec , _ , _ in batch_info
269- ],
266+ [(byte_getter , array_spec .prototype ) for byte_getter , array_spec , * _ in batch_info ],
270267 lambda byte_getter , prototype : byte_getter .get (prototype ),
271268 config .get ("async.concurrency" ),
272269 )
273270 chunk_array_batch = await self .decode_batch (
274271 [
275272 (chunk_bytes , chunk_spec )
276- for chunk_bytes , (_ , chunk_spec , _ , _ ) in zip (
273+ for chunk_bytes , (_ , chunk_spec , * _ ) in zip (
277274 chunk_bytes_batch , batch_info , strict = False
278275 )
279276 ],
280277 )
281- for chunk_array , (_ , chunk_spec , chunk_selection , out_selection ) in zip (
278+ for chunk_array , (_ , chunk_spec , chunk_selection , out_selection , _ ) in zip (
282279 chunk_array_batch , batch_info , strict = False
283280 ):
284281 if chunk_array is not None :
@@ -296,9 +293,10 @@ def _merge_chunk_array(
296293 out_selection : SelectorTuple ,
297294 chunk_spec : ArraySpec ,
298295 chunk_selection : SelectorTuple ,
296+ is_complete_chunk : bool ,
299297 drop_axes : tuple [int , ...],
300298 ) -> NDBuffer :
301- if is_total_slice ( chunk_selection , chunk_spec . shape ) and value .shape == chunk_spec .shape :
299+ if is_complete_chunk and value .shape == chunk_spec .shape :
302300 return value
303301 if existing_chunk_array is None :
304302 chunk_array = chunk_spec .prototype .nd_buffer .create (
@@ -327,7 +325,7 @@ def _merge_chunk_array(
327325
328326 async def write_batch (
329327 self ,
330- batch_info : Iterable [tuple [ByteSetter , ArraySpec , SelectorTuple , SelectorTuple ]],
328+ batch_info : Iterable [tuple [ByteSetter , ArraySpec , SelectorTuple , SelectorTuple , bool ]],
331329 value : NDBuffer ,
332330 drop_axes : tuple [int , ...] = (),
333331 ) -> None :
@@ -337,14 +335,14 @@ async def write_batch(
337335 await self .encode_partial_batch (
338336 [
339337 (byte_setter , value , chunk_selection , chunk_spec )
340- for byte_setter , chunk_spec , chunk_selection , out_selection in batch_info
338+ for byte_setter , chunk_spec , chunk_selection , out_selection , _ in batch_info
341339 ],
342340 )
343341 else :
344342 await self .encode_partial_batch (
345343 [
346344 (byte_setter , value [out_selection ], chunk_selection , chunk_spec )
347- for byte_setter , chunk_spec , chunk_selection , out_selection in batch_info
345+ for byte_setter , chunk_spec , chunk_selection , out_selection , _ in batch_info
348346 ],
349347 )
350348
@@ -361,33 +359,43 @@ async def _read_key(
361359 chunk_bytes_batch = await concurrent_map (
362360 [
363361 (
364- None if is_total_slice ( chunk_selection , chunk_spec . shape ) else byte_setter ,
362+ None if is_complete_chunk else byte_setter ,
365363 chunk_spec .prototype ,
366364 )
367- for byte_setter , chunk_spec , chunk_selection , _ in batch_info
365+ for byte_setter , chunk_spec , chunk_selection , _ , is_complete_chunk in batch_info
368366 ],
369367 _read_key ,
370368 config .get ("async.concurrency" ),
371369 )
372370 chunk_array_decoded = await self .decode_batch (
373371 [
374372 (chunk_bytes , chunk_spec )
375- for chunk_bytes , (_ , chunk_spec , _ , _ ) in zip (
373+ for chunk_bytes , (_ , chunk_spec , * _ ) in zip (
376374 chunk_bytes_batch , batch_info , strict = False
377375 )
378376 ],
379377 )
380378
381379 chunk_array_merged = [
382380 self ._merge_chunk_array (
383- chunk_array , value , out_selection , chunk_spec , chunk_selection , drop_axes
384- )
385- for chunk_array , (_ , chunk_spec , chunk_selection , out_selection ) in zip (
386- chunk_array_decoded , batch_info , strict = False
381+ chunk_array ,
382+ value ,
383+ out_selection ,
384+ chunk_spec ,
385+ chunk_selection ,
386+ is_complete_chunk ,
387+ drop_axes ,
387388 )
389+ for chunk_array , (
390+ _ ,
391+ chunk_spec ,
392+ chunk_selection ,
393+ out_selection ,
394+ is_complete_chunk ,
395+ ) in zip (chunk_array_decoded , batch_info , strict = False )
388396 ]
389397 chunk_array_batch : list [NDBuffer | None ] = []
390- for chunk_array , (_ , chunk_spec , _ , _ ) in zip (
398+ for chunk_array , (_ , chunk_spec , * _ ) in zip (
391399 chunk_array_merged , batch_info , strict = False
392400 ):
393401 if chunk_array is None :
@@ -403,7 +411,7 @@ async def _read_key(
403411 chunk_bytes_batch = await self .encode_batch (
404412 [
405413 (chunk_array , chunk_spec )
406- for chunk_array , (_ , chunk_spec , _ , _ ) in zip (
414+ for chunk_array , (_ , chunk_spec , * _ ) in zip (
407415 chunk_array_batch , batch_info , strict = False
408416 )
409417 ],
@@ -418,7 +426,7 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non
418426 await concurrent_map (
419427 [
420428 (byte_setter , chunk_bytes )
421- for chunk_bytes , (byte_setter , _ , _ , _ ) in zip (
429+ for chunk_bytes , (byte_setter , * _ ) in zip (
422430 chunk_bytes_batch , batch_info , strict = False
423431 )
424432 ],
@@ -446,7 +454,7 @@ async def encode(
446454
447455 async def read (
448456 self ,
449- batch_info : Iterable [tuple [ByteGetter , ArraySpec , SelectorTuple , SelectorTuple ]],
457+ batch_info : Iterable [tuple [ByteGetter , ArraySpec , SelectorTuple , SelectorTuple , bool ]],
450458 out : NDBuffer ,
451459 drop_axes : tuple [int , ...] = (),
452460 ) -> None :
@@ -461,7 +469,7 @@ async def read(
461469
462470 async def write (
463471 self ,
464- batch_info : Iterable [tuple [ByteSetter , ArraySpec , SelectorTuple , SelectorTuple ]],
472+ batch_info : Iterable [tuple [ByteSetter , ArraySpec , SelectorTuple , SelectorTuple , bool ]],
465473 value : NDBuffer ,
466474 drop_axes : tuple [int , ...] = (),
467475 ) -> None :
0 commit comments