3838from zarr .core .common import (
3939 ChunkCoords ,
4040 ChunkCoordsLike ,
41+ concurrent_map ,
4142 parse_enum ,
4243 parse_named_configuration ,
4344 parse_shapelike ,
4445 product ,
4546)
47+ from zarr .core .config import config
4648from zarr .core .indexing import (
4749 BasicIndexer ,
4850 SelectorTuple ,
@@ -327,6 +329,11 @@ async def finalize(
327329 return await shard_builder .finalize (index_location , index_encoder )
328330
329331
class _ChunkCoordsByteSlice(NamedTuple):
    """Pairs a chunk's grid coordinates with the byte range its encoded
    bytes occupy inside the serialized shard."""

    # Position of the chunk within the shard's chunk grid.
    coords: ChunkCoords
    # Start/stop byte offsets of the chunk's data in the shard.
    byte_slice: slice
336+
330337@dataclass (frozen = True )
331338class ShardingCodec (
332339 ArrayBytesCodec , ArrayBytesCodecPartialDecodeMixin , ArrayBytesCodecPartialEncodeMixin
@@ -490,32 +497,21 @@ async def _decode_partial_single(
490497 all_chunk_coords = {chunk_coords for chunk_coords , * _ in indexed_chunks }
491498
492499 # reading bytes of all requested chunks
493- shard_dict : ShardMapping = {}
500+ shard_dict_maybe : ShardMapping | None = {}
494501 if self ._is_total_shard (all_chunk_coords , chunks_per_shard ):
495502 # read entire shard
496503 shard_dict_maybe = await self ._load_full_shard_maybe (
497- byte_getter = byte_getter ,
498- prototype = chunk_spec .prototype ,
499- chunks_per_shard = chunks_per_shard ,
504+ byte_getter , chunk_spec .prototype , chunks_per_shard
500505 )
501- if shard_dict_maybe is None :
502- return None
503- shard_dict = shard_dict_maybe
504506 else :
505507 # read some chunks within the shard
506- shard_index = await self ._load_shard_index_maybe (byte_getter , chunks_per_shard )
507- if shard_index is None :
508- return None
509- shard_dict = {}
510- for chunk_coords in all_chunk_coords :
511- chunk_byte_slice = shard_index .get_chunk_slice (chunk_coords )
512- if chunk_byte_slice :
513- chunk_bytes = await byte_getter .get (
514- prototype = chunk_spec .prototype ,
515- byte_range = RangeByteRequest (chunk_byte_slice [0 ], chunk_byte_slice [1 ]),
516- )
517- if chunk_bytes :
518- shard_dict [chunk_coords ] = chunk_bytes
508+ shard_dict_maybe = await self ._load_partial_shard_maybe (
509+ byte_getter , chunk_spec .prototype , chunks_per_shard , all_chunk_coords
510+ )
511+
512+ if shard_dict_maybe is None :
513+ return None
514+ shard_dict = shard_dict_maybe
519515
520516 # decoding chunks and writing them into the output buffer
521517 await self .codec_pipeline .read (
@@ -537,6 +533,96 @@ async def _decode_partial_single(
537533 else :
538534 return out
539535
536+ async def _load_partial_shard_maybe (
537+ self ,
538+ byte_getter : ByteGetter ,
539+ prototype : BufferPrototype ,
540+ chunks_per_shard : ChunkCoords ,
541+ all_chunk_coords : set [ChunkCoords ],
542+ ) -> ShardMapping | None :
543+ shard_index = await self ._load_shard_index_maybe (byte_getter , chunks_per_shard )
544+ if shard_index is None :
545+ return None
546+
547+ chunks = [
548+ _ChunkCoordsByteSlice (chunk_coords , slice (* chunk_byte_slice ))
549+ for chunk_coords in all_chunk_coords
550+ if (chunk_byte_slice := shard_index .get_chunk_slice (chunk_coords ))
551+ ]
552+ if len (chunks ) == 0 :
553+ return {}
554+
555+ groups = self ._coalesce_chunks (chunks )
556+
557+ shard_dicts = await concurrent_map (
558+ [(group , byte_getter , prototype ) for group in groups ],
559+ self ._get_group_bytes ,
560+ config .get ("async.concurrency" ),
561+ )
562+
563+ shard_dict : ShardMutableMapping = {}
564+ for d in shard_dicts :
565+ shard_dict .update (d )
566+
567+ return shard_dict
568+
569+ def _coalesce_chunks (
570+ self ,
571+ chunks : list [_ChunkCoordsByteSlice ],
572+ max_gap_bytes : int = 2 ** 20 , # 1MiB
573+ coalesce_max_bytes : int = 100 * 2 ** 20 , # 100MiB
574+ ) -> list [list [_ChunkCoordsByteSlice ]]:
575+ sorted_chunks = sorted (chunks , key = lambda c : c .byte_slice .start )
576+
577+ groups = []
578+ current_group = [sorted_chunks [0 ]]
579+
580+ for chunk in sorted_chunks [1 :]:
581+ gap_to_chunk = chunk .byte_slice .start - current_group [- 1 ].byte_slice .stop
582+ current_group_size = (
583+ current_group [- 1 ].byte_slice .stop - current_group [0 ].byte_slice .start
584+ )
585+ if gap_to_chunk < max_gap_bytes and current_group_size < coalesce_max_bytes :
586+ current_group .append (chunk )
587+ else :
588+ groups .append (current_group )
589+ current_group = [chunk ]
590+
591+ groups .append (current_group )
592+
593+ from pprint import pprint
594+
595+ pprint (
596+ [
597+ f"{ len (g )} chunks, { (g [- 1 ].byte_slice .stop - g [0 ].byte_slice .start ) / 1e6 :.1f} MB"
598+ for g in groups
599+ ]
600+ )
601+
602+ return groups
603+
604+ async def _get_group_bytes (
605+ self ,
606+ group : list [_ChunkCoordsByteSlice ],
607+ byte_getter : ByteGetter ,
608+ prototype : BufferPrototype ,
609+ ) -> ShardMapping :
610+ group_start = group [0 ].byte_slice .start
611+ group_end = group [- 1 ].byte_slice .stop
612+
613+ group_bytes = await byte_getter .get (
614+ prototype = prototype ,
615+ byte_range = RangeByteRequest (group_start , group_end ),
616+ )
617+ if group_bytes is None :
618+ return {}
619+
620+ shard_dict = {}
621+ for chunk in group :
622+ s = slice (chunk .byte_slice .start - group_start , chunk .byte_slice .stop - group_start )
623+ shard_dict [chunk .coords ] = group_bytes [s ]
624+ return shard_dict
625+
540626 async def _encode_single (
541627 self ,
542628 shard_array : NDBuffer ,
0 commit comments