@@ -90,9 +90,9 @@ async def get(
9090 self , prototype : BufferPrototype , byte_range : ByteRequest | None = None
9191 ) -> Buffer | None :
9292 assert byte_range is None , "byte_range is not supported within shards"
93- assert prototype == default_buffer_prototype (), (
94- f" prototype is not supported within shards currently. diff: { prototype } != { default_buffer_prototype ()} "
95- )
93+ assert (
94+ prototype == default_buffer_prototype ()
95+ ), f"prototype is not supported within shards currently. diff: { prototype } != { default_buffer_prototype () } "
9696 return self .shard_dict .get (self .chunk_coords )
9797
9898
@@ -124,7 +124,9 @@ def chunks_per_shard(self) -> ChunkCoords:
124124 def _localize_chunk (self , chunk_coords : ChunkCoords ) -> ChunkCoords :
125125 return tuple (
126126 chunk_i % shard_i
127- for chunk_i , shard_i in zip (chunk_coords , self .offsets_and_lengths .shape , strict = False )
127+ for chunk_i , shard_i in zip (
128+ chunk_coords , self .offsets_and_lengths .shape , strict = False
129+ )
128130 )
129131
130132 def is_all_empty (self ) -> bool :
@@ -141,7 +143,9 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None:
141143 else :
142144 return (int (chunk_start ), int (chunk_start + chunk_len ))
143145
144- def set_chunk_slice (self , chunk_coords : ChunkCoords , chunk_slice : slice | None ) -> None :
146+ def set_chunk_slice (
147+ self , chunk_coords : ChunkCoords , chunk_slice : slice | None
148+ ) -> None :
145149 localized_chunk = self ._localize_chunk (chunk_coords )
146150 if chunk_slice is None :
147151 self .offsets_and_lengths [localized_chunk ] = (MAX_UINT_64 , MAX_UINT_64 )
@@ -163,7 +167,11 @@ def is_dense(self, chunk_byte_length: int) -> bool:
163167
164168 # Are all non-empty offsets unique?
165169 if len (
166- {offset for offset , _ in sorted_offsets_and_lengths if offset != MAX_UINT_64 }
170+ {
171+ offset
172+ for offset , _ in sorted_offsets_and_lengths
173+ if offset != MAX_UINT_64
174+ }
167175 ) != len (sorted_offsets_and_lengths ):
168176 return False
169177
@@ -267,7 +275,9 @@ def __setitem__(self, chunk_coords: ChunkCoords, value: Buffer) -> None:
267275 chunk_start = len (self .buf )
268276 chunk_length = len (value )
269277 self .buf += value
270- self .index .set_chunk_slice (chunk_coords , slice (chunk_start , chunk_start + chunk_length ))
278+ self .index .set_chunk_slice (
279+ chunk_coords , slice (chunk_start , chunk_start + chunk_length )
280+ )
271281
272282 def __delitem__ (self , chunk_coords : ChunkCoords ) -> None :
273283 raise NotImplementedError
@@ -281,7 +291,9 @@ async def finalize(
281291 if index_location == ShardingCodecIndexLocation .start :
282292 empty_chunks_mask = self .index .offsets_and_lengths [..., 0 ] == MAX_UINT_64
283293 self .index .offsets_and_lengths [~ empty_chunks_mask , 0 ] += len (index_bytes )
284- index_bytes = await index_encoder (self .index ) # encode again with corrected offsets
294+ index_bytes = await index_encoder (
295+ self .index
296+ ) # encode again with corrected offsets
285297 out_buf = index_bytes + self .buf
286298 else :
287299 out_buf = self .buf + index_bytes
@@ -359,7 +371,8 @@ def __init__(
359371 chunk_shape : ChunkCoordsLike ,
360372 codecs : Iterable [Codec | dict [str , JSON ]] = (BytesCodec (),),
361373 index_codecs : Iterable [Codec | dict [str , JSON ]] = (BytesCodec (), Crc32cCodec ()),
362- index_location : ShardingCodecIndexLocation | str = ShardingCodecIndexLocation .end ,
374+ index_location : ShardingCodecIndexLocation
375+ | str = ShardingCodecIndexLocation .end ,
363376 ) -> None :
364377 chunk_shape_parsed = parse_shapelike (chunk_shape )
365378 codecs_parsed = parse_codecs (codecs )
@@ -389,7 +402,9 @@ def __setstate__(self, state: dict[str, Any]) -> None:
389402 object .__setattr__ (self , "chunk_shape" , parse_shapelike (config ["chunk_shape" ]))
390403 object .__setattr__ (self , "codecs" , parse_codecs (config ["codecs" ]))
391404 object .__setattr__ (self , "index_codecs" , parse_codecs (config ["index_codecs" ]))
392- object .__setattr__ (self , "index_location" , parse_index_location (config ["index_location" ]))
405+ object .__setattr__ (
406+ self , "index_location" , parse_index_location (config ["index_location" ])
407+ )
393408
394409 # Use instance-local lru_cache to avoid memory leaks
395410 # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
@@ -418,7 +433,9 @@ def to_dict(self) -> dict[str, JSON]:
418433
419434 def evolve_from_array_spec (self , array_spec : ArraySpec ) -> Self :
420435 shard_spec = self ._get_chunk_spec (array_spec )
421- evolved_codecs = tuple (c .evolve_from_array_spec (array_spec = shard_spec ) for c in self .codecs )
436+ evolved_codecs = tuple (
437+ c .evolve_from_array_spec (array_spec = shard_spec ) for c in self .codecs
438+ )
422439 if evolved_codecs != self .codecs :
423440 return replace (self , codecs = evolved_codecs )
424441 return self
@@ -469,7 +486,7 @@ async def _decode_single(
469486 shape = shard_shape ,
470487 dtype = shard_spec .dtype .to_native_dtype (),
471488 order = shard_spec .order ,
472- fill_value = 0 ,
489+ fill_value = shard_spec . fill_value ,
473490 )
474491 shard_dict = await _ShardReader .from_bytes (shard_bytes , self , chunks_per_shard )
475492
@@ -516,7 +533,7 @@ async def _decode_partial_single(
516533 shape = indexer .shape ,
517534 dtype = shard_spec .dtype .to_native_dtype (),
518535 order = shard_spec .order ,
519- fill_value = 0 ,
536+ fill_value = shard_spec . fill_value ,
520537 )
521538
522539 indexed_chunks = list (indexer )
@@ -593,7 +610,9 @@ async def _encode_single(
593610 shard_array ,
594611 )
595612
596- return await shard_builder .finalize (self .index_location , self ._encode_shard_index )
613+ return await shard_builder .finalize (
614+ self .index_location , self ._encode_shard_index
615+ )
597616
598617 async def _encode_partial_single (
599618 self ,
@@ -653,7 +672,8 @@ def _is_total_shard(
653672 self , all_chunk_coords : set [ChunkCoords ], chunks_per_shard : ChunkCoords
654673 ) -> bool :
655674 return len (all_chunk_coords ) == product (chunks_per_shard ) and all (
656- chunk_coords in all_chunk_coords for chunk_coords in c_order_iter (chunks_per_shard )
675+ chunk_coords in all_chunk_coords
676+ for chunk_coords in c_order_iter (chunks_per_shard )
657677 )
658678
659679 async def _decode_shard_index (
@@ -679,7 +699,9 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:
679699 .encode (
680700 [
681701 (
682- get_ndbuffer_class ().from_numpy_array (index .offsets_and_lengths ),
702+ get_ndbuffer_class ().from_numpy_array (
703+ index .offsets_and_lengths
704+ ),
683705 self ._get_index_chunk_spec (index .chunks_per_shard ),
684706 )
685707 ],
@@ -790,8 +812,8 @@ async def _load_partial_shard_maybe(
790812 # Drop chunks where index lookup fails
791813 if (chunk_byte_slice := shard_index .get_chunk_slice (chunk_coords ))
792814 ]
793- if len (chunks ) == 0 :
794- return {}
815+ if len (chunks ) < len ( all_chunk_coords ) :
816+ return None
795817
796818 groups = self ._coalesce_chunks (chunks )
797819
@@ -803,6 +825,8 @@ async def _load_partial_shard_maybe(
803825
804826 shard_dict : ShardMutableMapping = {}
805827 for d in shard_dicts :
828+ if d is None :
829+ return None
806830 shard_dict .update (d )
807831
808832 return shard_dict
@@ -830,7 +854,9 @@ def _coalesce_chunks(
830854
831855 for chunk in sorted_chunks [1 :]:
832856 gap_to_chunk = chunk .byte_slice .start - current_group [- 1 ].byte_slice .stop
833- size_if_coalesced = chunk .byte_slice .stop - current_group [0 ].byte_slice .start
857+ size_if_coalesced = (
858+ chunk .byte_slice .stop - current_group [0 ].byte_slice .start
859+ )
834860 if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes :
835861 current_group .append (chunk )
836862 else :
@@ -846,7 +872,7 @@ async def _get_group_bytes(
846872 group : list [_ChunkCoordsByteSlice ],
847873 byte_getter : ByteGetter ,
848874 prototype : BufferPrototype ,
849- ) -> ShardMapping :
875+ ) -> ShardMapping | None :
850876 """
851877 Reads a possibly coalesced group of one or more chunks from a shard.
852878 Returns a mapping of chunk coordinates to bytes.
@@ -860,7 +886,7 @@ async def _get_group_bytes(
860886 byte_range = RangeByteRequest (group_start , group_end ),
861887 )
862888 if group_bytes is None :
863- return {}
889+ return None
864890
865891 # Extract the bytes corresponding to each chunk in group from group_bytes.
866892 shard_dict = {}
@@ -873,7 +899,9 @@ async def _get_group_bytes(
873899
874900 return shard_dict
875901
876- def compute_encoded_size (self , input_byte_length : int , shard_spec : ArraySpec ) -> int :
902+ def compute_encoded_size (
903+ self , input_byte_length : int , shard_spec : ArraySpec
904+ ) -> int :
877905 chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
878906 return input_byte_length + self ._shard_index_size (chunks_per_shard )
879907
0 commit comments