Commit 0e00eb3

Different approach

1 parent 1d11fcb commit 0e00eb3
File tree: 6 files changed (+69 / -83 lines)

src/zarr/abc/codec.py

Lines changed: 2 additions & 2 deletions

@@ -357,7 +357,7 @@ async def encode(
     @abstractmethod
     async def read(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:

@@ -379,7 +379,7 @@ async def read(
     @abstractmethod
     async def write(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
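The only change to the abstract codec pipeline interface is the extra bool at the end of each batch_info tuple. As a rough illustration (hypothetical names, not zarr's real type aliases), the element shape consumers can now rely on looks like this:

from collections.abc import Iterable
from typing import Any, NamedTuple

class BatchEntry(NamedTuple):      # hypothetical helper, for readability only
    byte_getter: Any               # stands in for ByteGetter / ByteSetter
    chunk_spec: Any                # stands in for ArraySpec
    chunk_selection: tuple         # stands in for SelectorTuple
    out_selection: tuple           # stands in for SelectorTuple
    is_complete_chunk: bool        # the new fifth element

def complete_flags(batch_info: Iterable[BatchEntry]) -> list[bool]:
    # Consumers can unpack the flag positionally, e.g. `for *_, flag in batch_info`.
    return [entry.is_complete_chunk for entry in batch_info]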

src/zarr/codecs/sharding.py

Lines changed: 9 additions & 5 deletions

@@ -455,8 +455,9 @@ async def _decode_single(
                     chunk_spec,
                     chunk_selection,
                     out_selection,
+                    is_complete_shard,
                 )
-                for chunk_coords, chunk_selection, out_selection in indexer
+                for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
             ],
             out,
         )

@@ -486,7 +487,7 @@ async def _decode_partial_single(
         )

         indexed_chunks = list(indexer)
-        all_chunk_coords = {chunk_coords for chunk_coords, _, _ in indexed_chunks}
+        all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

         # reading bytes of all requested chunks
         shard_dict: ShardMapping = {}

@@ -524,8 +525,9 @@ async def _decode_partial_single(
                     chunk_spec,
                     chunk_selection,
                     out_selection,
+                    is_complete_shard,
                 )
-                for chunk_coords, chunk_selection, out_selection in indexer
+                for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
             ],
             out,
         )

@@ -558,8 +560,9 @@ async def _encode_single(
                     chunk_spec,
                     chunk_selection,
                     out_selection,
+                    is_complete_shard,
                 )
-                for chunk_coords, chunk_selection, out_selection in indexer
+                for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
             ],
             shard_array,
         )

@@ -601,8 +604,9 @@ async def _encode_partial_single(
                     chunk_spec,
                     chunk_selection,
                     out_selection,
+                    is_complete_shard,
                 )
-                for chunk_coords, chunk_selection, out_selection in indexer
+                for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
             ],
             shard_array,
         )

src/zarr/core/array.py

Lines changed: 4 additions & 2 deletions

@@ -1290,8 +1290,9 @@ async def _get_selection(
                     self.metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype),
                     chunk_selection,
                     out_selection,
+                    is_complete_chunk,
                 )
-                for chunk_coords, chunk_selection, out_selection in indexer
+                for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
             ],
             out_buffer,
             drop_axes=indexer.drop_axes,

@@ -1417,8 +1418,9 @@ async def _set_selection(
                     self.metadata.get_chunk_spec(chunk_coords, _config, prototype),
                     chunk_selection,
                     out_selection,
+                    is_complete_chunk,
                 )
-                for chunk_coords, chunk_selection, out_selection in indexer
+                for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
             ],
             value_buffer,
             drop_axes=indexer.drop_axes,
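Both the sharding codec and Array._get_selection / _set_selection assume the indexer now yields a fourth element alongside chunk_coords, chunk_selection and out_selection, and simply forward it into each batch_info entry. A self-contained toy sketch of that contract (hypothetical GridIndexer, single chunk only, not zarr's real indexing classes):

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass


@dataclass
class GridIndexer:
    chunk_shape: tuple[int, ...]
    selection: tuple[slice, ...]

    def __iter__(self) -> Iterator[tuple[tuple[int, ...], tuple, tuple, bool]]:
        # A selection is "complete" when every slice spans its full axis.
        is_complete = all(
            (sl.start or 0) == 0 and sl.stop is not None and sl.stop >= dim
            for sl, dim in zip(self.selection, self.chunk_shape)
        )
        yield (0,) * len(self.chunk_shape), self.selection, self.selection, is_complete


indexer = GridIndexer(chunk_shape=(4, 4), selection=(slice(0, 4), slice(0, 2)))
batch_info = [
    # (byte_getter, chunk_spec, chunk_selection, out_selection, is_complete_chunk)
    (f"chunk/{chunk_coords}", None, chunk_selection, out_selection, is_complete_chunk)
    for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
]
print(batch_info[0][-1])  # False: the column selection covers only half the chunk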

src/zarr/core/codec_pipeline.py

Lines changed: 34 additions & 26 deletions

@@ -16,7 +16,7 @@
 )
 from zarr.core.common import ChunkCoords, concurrent_map
 from zarr.core.config import config
-from zarr.core.indexing import SelectorTuple, is_scalar, is_total_slice
+from zarr.core.indexing import SelectorTuple, is_scalar
 from zarr.core.metadata.v2 import _default_fill_value
 from zarr.registry import register_pipeline

@@ -230,18 +230,18 @@ async def encode_partial_batch(

     async def read_batch(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
         if self.supports_partial_decode:
             chunk_array_batch = await self.decode_partial_batch(
                 [
                     (byte_getter, chunk_selection, chunk_spec)
-                    for byte_getter, chunk_spec, chunk_selection, _ in batch_info
+                    for byte_getter, chunk_spec, chunk_selection, *_ in batch_info
                 ]
             )
-            for chunk_array, (_, chunk_spec, _, out_selection) in zip(
+            for chunk_array, (_, chunk_spec, _, out_selection, _) in zip(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:

@@ -260,22 +260,19 @@ async def read_batch(
                     out[out_selection] = fill_value
         else:
             chunk_bytes_batch = await concurrent_map(
-                [
-                    (byte_getter, array_spec.prototype)
-                    for byte_getter, array_spec, _, _ in batch_info
-                ],
+                [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
                 lambda byte_getter, prototype: byte_getter.get(prototype),
                 config.get("async.concurrency"),
             )
             chunk_array_batch = await self.decode_batch(
                 [
                     (chunk_bytes, chunk_spec)
-                    for chunk_bytes, (_, chunk_spec, _, _) in zip(
+                    for chunk_bytes, (_, chunk_spec, *_) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
             )
-            for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
+            for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:

@@ -296,9 +293,10 @@ def _merge_chunk_array(
         out_selection: SelectorTuple,
         chunk_spec: ArraySpec,
         chunk_selection: SelectorTuple,
+        is_complete_chunk: bool,
         drop_axes: tuple[int, ...],
     ) -> NDBuffer:
-        if is_total_slice(chunk_selection, chunk_spec.shape) and value.shape == chunk_spec.shape:
+        if is_complete_chunk and value.shape == chunk_spec.shape:
             return value
         if existing_chunk_array is None:
             chunk_array = chunk_spec.prototype.nd_buffer.create(

@@ -327,7 +325,7 @@

     async def write_batch(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:

@@ -337,14 +335,14 @@ async def write_batch(
                 await self.encode_partial_batch(
                     [
                         (byte_setter, value, chunk_selection, chunk_spec)
-                        for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
+                        for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
                     ],
                 )
             else:
                 await self.encode_partial_batch(
                     [
                         (byte_setter, value[out_selection], chunk_selection, chunk_spec)
-                        for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
+                        for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
                     ],
                 )

@@ -361,33 +359,43 @@ async def _read_key(
             chunk_bytes_batch = await concurrent_map(
                 [
                     (
-                        None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,
+                        None if is_complete_chunk else byte_setter,
                         chunk_spec.prototype,
                     )
-                    for byte_setter, chunk_spec, chunk_selection, _ in batch_info
+                    for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
                 ],
                 _read_key,
                 config.get("async.concurrency"),
             )
             chunk_array_decoded = await self.decode_batch(
                 [
                     (chunk_bytes, chunk_spec)
-                    for chunk_bytes, (_, chunk_spec, _, _) in zip(
+                    for chunk_bytes, (_, chunk_spec, *_) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
             )

             chunk_array_merged = [
                 self._merge_chunk_array(
-                    chunk_array, value, out_selection, chunk_spec, chunk_selection, drop_axes
-                )
-                for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
-                    chunk_array_decoded, batch_info, strict=False
+                    chunk_array,
+                    value,
+                    out_selection,
+                    chunk_spec,
+                    chunk_selection,
+                    is_complete_chunk,
+                    drop_axes,
                 )
+                for chunk_array, (
+                    _,
+                    chunk_spec,
+                    chunk_selection,
+                    out_selection,
+                    is_complete_chunk,
+                ) in zip(chunk_array_decoded, batch_info, strict=False)
             ]
             chunk_array_batch: list[NDBuffer | None] = []
-            for chunk_array, (_, chunk_spec, _, _) in zip(
+            for chunk_array, (_, chunk_spec, *_) in zip(
                 chunk_array_merged, batch_info, strict=False
             ):
                 if chunk_array is None:

@@ -403,7 +411,7 @@ async def _read_key(
             chunk_bytes_batch = await self.encode_batch(
                 [
                     (chunk_array, chunk_spec)
-                    for chunk_array, (_, chunk_spec, _, _) in zip(
+                    for chunk_array, (_, chunk_spec, *_) in zip(
                         chunk_array_batch, batch_info, strict=False
                     )
                 ],

@@ -418,7 +426,7 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None:
         await concurrent_map(
             [
                 (byte_setter, chunk_bytes)
-                for chunk_bytes, (byte_setter, _, _, _) in zip(
+                for chunk_bytes, (byte_setter, *_) in zip(
                     chunk_bytes_batch, batch_info, strict=False
                 )
             ],

@@ -446,7 +454,7 @@ async def encode(

     async def read(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:

@@ -461,7 +469,7 @@ async def read(

     async def write(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
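The net effect inside the pipeline is that the per-entry is_total_slice(chunk_selection, chunk_spec.shape) check is replaced by the flag the indexer already computed, both when deciding whether a write must read the existing chunk back and when short-circuiting the merge of a value that replaces the whole chunk. A simplified NumPy sketch of that behaviour (hypothetical function names, not zarr's actual API):

from __future__ import annotations

import numpy as np


def needs_existing_bytes(is_complete_chunk: bool) -> bool:
    # Mirrors `None if is_complete_chunk else byte_setter` in write_batch:
    # a complete chunk is overwritten wholesale, so its old bytes are never read.
    return not is_complete_chunk


def merge_chunk(
    existing: np.ndarray | None,
    value: np.ndarray,
    chunk_shape: tuple[int, ...],
    chunk_selection: tuple[slice, ...],
    is_complete_chunk: bool,
    fill_value: float = 0.0,
) -> np.ndarray:
    # Fast path: the selection covers the whole chunk and the value already
    # has chunk shape, so it can be stored as-is (cf. _merge_chunk_array).
    if is_complete_chunk and value.shape == chunk_shape:
        return value
    # Otherwise read-modify-write against the existing (or fill-value) chunk.
    chunk = np.full(chunk_shape, fill_value) if existing is None else existing.copy()
    chunk[chunk_selection] = value
    return chunk


# Partial write into a 4x4 chunk: the untouched columns keep their old values.
old = np.arange(16, dtype=float).reshape(4, 4)
new = merge_chunk(old, np.zeros((4, 2)), (4, 4), (slice(0, 4), slice(0, 2)), is_complete_chunk=False)
assert (new[:, 2:] == old[:, 2:]).all() and needs_existing_bytes(False)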
