@@ -16,7 +16,7 @@
 )
 from zarr.core.common import ChunkCoords, concurrent_map
 from zarr.core.config import config
-from zarr.core.indexing import SelectorTuple, is_scalar, is_total_slice
+from zarr.core.indexing import SelectorTuple, is_scalar
 from zarr.core.metadata.v2 import _default_fill_value
 from zarr.registry import register_pipeline
 
@@ -243,18 +243,18 @@ async def encode_partial_batch(
 
     async def read_batch(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
         if self.supports_partial_decode:
             chunk_array_batch = await self.decode_partial_batch(
                 [
                     (byte_getter, chunk_selection, chunk_spec)
-                    for byte_getter, chunk_spec, chunk_selection, _ in batch_info
+                    for byte_getter, chunk_spec, chunk_selection, *_ in batch_info
                 ]
             )
-            for chunk_array, (_, chunk_spec, _, out_selection) in zip(
+            for chunk_array, (_, chunk_spec, _, out_selection, _) in zip(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:
@@ -263,22 +263,19 @@ async def read_batch(
                     out[out_selection] = fill_value_or_default(chunk_spec)
         else:
             chunk_bytes_batch = await concurrent_map(
-                [
-                    (byte_getter, array_spec.prototype)
-                    for byte_getter, array_spec, _, _ in batch_info
-                ],
+                [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
                 lambda byte_getter, prototype: byte_getter.get(prototype),
                 config.get("async.concurrency"),
             )
             chunk_array_batch = await self.decode_batch(
                 [
                     (chunk_bytes, chunk_spec)
-                    for chunk_bytes, (_, chunk_spec, _, _) in zip(
+                    for chunk_bytes, (_, chunk_spec, *_) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
             )
-            for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
+            for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:
@@ -296,9 +293,10 @@ def _merge_chunk_array(
         out_selection: SelectorTuple,
         chunk_spec: ArraySpec,
         chunk_selection: SelectorTuple,
+        is_complete_chunk: bool,
         drop_axes: tuple[int, ...],
     ) -> NDBuffer:
-        if is_total_slice(chunk_selection, chunk_spec.shape) and value.shape == chunk_spec.shape:
+        if is_complete_chunk and value.shape == chunk_spec.shape:
             return value
         if existing_chunk_array is None:
             chunk_array = chunk_spec.prototype.nd_buffer.create(
@@ -327,7 +325,7 @@ def _merge_chunk_array(
 
     async def write_batch(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
@@ -337,14 +335,14 @@ async def write_batch(
                 await self.encode_partial_batch(
                     [
                         (byte_setter, value, chunk_selection, chunk_spec)
-                        for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
+                        for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
                     ],
                 )
             else:
                 await self.encode_partial_batch(
                     [
                         (byte_setter, value[out_selection], chunk_selection, chunk_spec)
-                        for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
+                        for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
                     ],
                 )
 
@@ -361,33 +359,43 @@ async def _read_key(
             chunk_bytes_batch = await concurrent_map(
                 [
                     (
-                        None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,
+                        None if is_complete_chunk else byte_setter,
                         chunk_spec.prototype,
                     )
-                    for byte_setter, chunk_spec, chunk_selection, _ in batch_info
+                    for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
                 ],
                 _read_key,
                 config.get("async.concurrency"),
             )
             chunk_array_decoded = await self.decode_batch(
                 [
                     (chunk_bytes, chunk_spec)
-                    for chunk_bytes, (_, chunk_spec, _, _) in zip(
+                    for chunk_bytes, (_, chunk_spec, *_) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
             )
 
             chunk_array_merged = [
                 self._merge_chunk_array(
-                    chunk_array, value, out_selection, chunk_spec, chunk_selection, drop_axes
-                )
-                for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
-                    chunk_array_decoded, batch_info, strict=False
+                    chunk_array,
+                    value,
+                    out_selection,
+                    chunk_spec,
+                    chunk_selection,
+                    is_complete_chunk,
+                    drop_axes,
                 )
+                for chunk_array, (
+                    _,
+                    chunk_spec,
+                    chunk_selection,
+                    out_selection,
+                    is_complete_chunk,
+                ) in zip(chunk_array_decoded, batch_info, strict=False)
             ]
             chunk_array_batch: list[NDBuffer | None] = []
-            for chunk_array, (_, chunk_spec, _, _) in zip(
+            for chunk_array, (_, chunk_spec, *_) in zip(
                 chunk_array_merged, batch_info, strict=False
             ):
                 if chunk_array is None:
@@ -403,7 +411,7 @@ async def _read_key(
             chunk_bytes_batch = await self.encode_batch(
                 [
                     (chunk_array, chunk_spec)
-                    for chunk_array, (_, chunk_spec, _, _) in zip(
+                    for chunk_array, (_, chunk_spec, *_) in zip(
                         chunk_array_batch, batch_info, strict=False
                     )
                 ],
@@ -418,7 +426,7 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non
             await concurrent_map(
                 [
                     (byte_setter, chunk_bytes)
-                    for chunk_bytes, (byte_setter, _, _, _) in zip(
+                    for chunk_bytes, (byte_setter, *_) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
@@ -446,7 +454,7 @@ async def encode(
 
     async def read(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
@@ -461,7 +469,7 @@ async def read(
 
     async def write(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
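Taken together, the change widens every batch_info entry from a 4-tuple to a 5-tuple (ByteGetter/ByteSetter, ArraySpec, chunk selection, output selection, bool): the trailing flag, is_complete_chunk, is computed by the caller and consulted directly, so the pipeline no longer re-derives chunk coverage with is_total_slice. The sketch below is illustrative only and not zarr's API; FakeChunkSpec and selection_covers_chunk are hypothetical stand-ins, and the flag's meaning (the selection spans the entire chunk) is inferred from the renamed check in _merge_chunk_array.

# Hypothetical, self-contained sketch (not zarr's API): it only illustrates the
# new 5-tuple batch_info shape and the *_ unpacking pattern used in this diff.
from dataclasses import dataclass


@dataclass
class FakeChunkSpec:  # stand-in for zarr's ArraySpec
    shape: tuple[int, ...]


def selection_covers_chunk(chunk_selection, chunk_shape) -> bool:
    # Assumption: the caller now precomputes this flag (the role is_total_slice
    # used to play inside the pipeline) and passes it along as is_complete_chunk.
    return all(
        isinstance(s, slice) and s.start in (0, None) and s.stop in (dim, None)
        for s, dim in zip(chunk_selection, chunk_shape, strict=True)
    )


chunk_spec = FakeChunkSpec(shape=(4, 4))
chunk_selection = (slice(0, 4), slice(0, 4))
out_selection = (slice(0, 4), slice(0, 4))
byte_getter = object()  # stand-in for a ByteGetter / ByteSetter

batch_info = [
    (
        byte_getter,
        chunk_spec,
        chunk_selection,
        out_selection,
        selection_covers_chunk(chunk_selection, chunk_spec.shape),  # is_complete_chunk
    )
]

# Consumers that only need a prefix of each entry unpack with a star target,
# so they stay valid as the tuple grows:
for getter, spec, *_ in batch_info:
    print(spec.shape)  # (4, 4)

# Consumers that need the flag unpack all five elements:
for getter, spec, sel, out_sel, is_complete_chunk in batch_info:
    print(is_complete_chunk)  # True: the selection covers the whole chunk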