1717from zarr .core .buffer .cpu import Buffer
1818
1919if TYPE_CHECKING :
20- from zarr .abc .store import RangeByteRequest
20+ from zarr .abc .store import ByteRequest
2121 from zarr .core .buffer import BufferPrototype
2222
2323
@@ -349,16 +349,18 @@ class MockByteGetter:
349349 return_none : bool = False
350350
351351 async def get (
352- self , prototype : BufferPrototype , byte_range : RangeByteRequest | None = None
352+ self , prototype : BufferPrototype , byte_range : ByteRequest | None = None
353353 ) -> Buffer | None :
354354 if self .return_none :
355355 return None
356356 if byte_range is None :
357357 return Buffer .from_bytes (self .data )
358- return Buffer .from_bytes (self .data [byte_range .start : byte_range .end ])
358+ # For RangeByteRequest, extract start and end
359+ start = getattr (byte_range , "start" , 0 )
360+ end = getattr (byte_range , "end" , len (self .data ))
361+ return Buffer .from_bytes (self .data [start :end ])
359362
360363
361- @pytest .mark .asyncio
362364async def test_get_group_bytes_single_chunk () -> None :
363365 """Test _get_group_bytes extracts single chunk correctly."""
364366 codec = ShardingCodec (chunk_shape = (8 ,))
@@ -372,10 +374,11 @@ async def test_get_group_bytes_single_chunk() -> None:
372374
373375 assert result is not None
374376 assert (0 ,) in result
375- assert result [(0 ,)].as_numpy_array ().tobytes () == data [10 :30 ]
377+ chunk_buf = result [(0 ,)]
378+ assert chunk_buf is not None
379+ assert chunk_buf .as_numpy_array ().tobytes () == data [10 :30 ]
376380
377381
378- @pytest .mark .asyncio
379382async def test_get_group_bytes_multiple_chunks () -> None :
380383 """Test _get_group_bytes extracts multiple chunks with correct offsets."""
381384 codec = ShardingCodec (chunk_shape = (8 ,))
@@ -391,11 +394,14 @@ async def test_get_group_bytes_multiple_chunks() -> None:
391394
392395 assert result is not None
393396 assert len (result ) == 2
394- assert result [(0 ,)].as_numpy_array ().tobytes () == data [10 :30 ]
395- assert result [(1 ,)].as_numpy_array ().tobytes () == data [30 :50 ]
397+ chunk0_buf = result [(0 ,)]
398+ chunk1_buf = result [(1 ,)]
399+ assert chunk0_buf is not None
400+ assert chunk1_buf is not None
401+ assert chunk0_buf .as_numpy_array ().tobytes () == data [10 :30 ]
402+ assert chunk1_buf .as_numpy_array ().tobytes () == data [30 :50 ]
396403
397404
398- @pytest .mark .asyncio
399405async def test_get_group_bytes_with_gap () -> None :
400406 """Test _get_group_bytes handles chunks with gaps correctly."""
401407 codec = ShardingCodec (chunk_shape = (8 ,))
@@ -412,11 +418,14 @@ async def test_get_group_bytes_with_gap() -> None:
412418 assert result is not None
413419 assert len (result ) == 2
414420 # The byte_getter.get is called with range [10, 60), then sliced
415- assert result [(0 ,)].as_numpy_array ().tobytes () == data [10 :20 ]
416- assert result [(1 ,)].as_numpy_array ().tobytes () == data [40 :60 ]
421+ chunk0_buf = result [(0 ,)]
422+ chunk1_buf = result [(1 ,)]
423+ assert chunk0_buf is not None
424+ assert chunk1_buf is not None
425+ assert chunk0_buf .as_numpy_array ().tobytes () == data [10 :20 ]
426+ assert chunk1_buf .as_numpy_array ().tobytes () == data [40 :60 ]
417427
418428
419- @pytest .mark .asyncio
420429async def test_get_group_bytes_returns_none_on_failed_read () -> None :
421430 """Test _get_group_bytes returns None when ByteGetter.get returns None."""
422431 codec = ShardingCodec (chunk_shape = (8 ,))
@@ -444,7 +453,7 @@ class MockByteGetterWithIndex:
444453 call_count : int = 0
445454
446455 async def get (
447- self , prototype : BufferPrototype , byte_range : RangeByteRequest | None = None
456+ self , prototype : BufferPrototype , byte_range : ByteRequest | None = None
448457 ) -> Buffer | None :
449458 self .call_count += 1
450459 # First call is typically for the index
@@ -457,17 +466,19 @@ async def get(
457466 return None
458467 if byte_range is None :
459468 return Buffer .from_bytes (self .chunk_data )
460- return Buffer .from_bytes (self .chunk_data [byte_range .start : byte_range .end ])
469+ # For RangeByteRequest, extract start and end
470+ start = getattr (byte_range , "start" , 0 )
471+ end = getattr (byte_range , "end" , len (self .chunk_data ))
472+ return Buffer .from_bytes (self .chunk_data [start :end ])
461473
462474
463- @pytest .mark .asyncio
464475async def test_load_partial_shard_maybe_index_load_fails () -> None :
465476 """Test _load_partial_shard_maybe returns None when index load fails."""
466477 codec = ShardingCodec (chunk_shape = (8 ,))
467478 byte_getter = MockByteGetterWithIndex (index_data = None , chunk_data = None )
468479
469480 chunks_per_shard = (2 ,)
470- all_chunk_coords = {(0 ,)}
481+ all_chunk_coords : set [ tuple [ int , ...]] = {(0 ,)}
471482
472483 result = await codec ._load_partial_shard_maybe (
473484 byte_getter = byte_getter ,
@@ -482,7 +493,6 @@ async def test_load_partial_shard_maybe_index_load_fails() -> None:
482493 assert result is None
483494
484495
485- @pytest .mark .asyncio
486496async def test_load_partial_shard_maybe_with_empty_chunks (
487497 monkeypatch : pytest .MonkeyPatch ,
488498) -> None :
@@ -510,7 +520,7 @@ async def mock_load_index(
510520 byte_getter = MockByteGetter (data = chunk_data )
511521
512522 # Request chunks including the empty one
513- all_chunk_coords = {(0 ,), (1 ,), (2 ,)}
523+ all_chunk_coords : set [ tuple [ int , ...]] = {(0 ,), (1 ,), (2 ,)}
514524
515525 result = await codec ._load_partial_shard_maybe (
516526 byte_getter = byte_getter ,
@@ -529,7 +539,6 @@ async def mock_load_index(
529539 assert (2 ,) in result
530540
531541
532- @pytest .mark .asyncio
533542async def test_load_partial_shard_maybe_all_chunks_empty (
534543 monkeypatch : pytest .MonkeyPatch ,
535544) -> None :
@@ -551,7 +560,7 @@ async def mock_load_index(
551560 byte_getter = MockByteGetter (data = b"" )
552561
553562 # Request some chunks - all will be empty
554- all_chunk_coords = {(0 ,), (1 ,), (2 ,)}
563+ all_chunk_coords : set [ tuple [ int , ...]] = {(0 ,), (1 ,), (2 ,)}
555564
556565 result = await codec ._load_partial_shard_maybe (
557566 byte_getter = byte_getter ,
@@ -619,7 +628,7 @@ def test_is_total_shard_full() -> None:
619628 """Test _is_total_shard returns True when all chunk coords are present."""
620629 codec = ShardingCodec (chunk_shape = (8 ,))
621630 chunks_per_shard = (2 , 2 )
622- all_chunk_coords = {(0 , 0 ), (0 , 1 ), (1 , 0 ), (1 , 1 )}
631+ all_chunk_coords : set [ tuple [ int , ...]] = {(0 , 0 ), (0 , 1 ), (1 , 0 ), (1 , 1 )}
623632
624633 assert codec ._is_total_shard (all_chunk_coords , chunks_per_shard ) is True
625634
@@ -628,7 +637,7 @@ def test_is_total_shard_partial() -> None:
628637 """Test _is_total_shard returns False for partial chunk coords."""
629638 codec = ShardingCodec (chunk_shape = (8 ,))
630639 chunks_per_shard = (2 , 2 )
631- all_chunk_coords = {(0 , 0 ), (1 , 1 )} # Missing (0, 1) and (1, 0)
640+ all_chunk_coords : set [ tuple [ int , ...]] = {(0 , 0 ), (1 , 1 )} # Missing (0, 1) and (1, 0)
632641
633642 assert codec ._is_total_shard (all_chunk_coords , chunks_per_shard ) is False
634643
@@ -646,10 +655,10 @@ def test_is_total_shard_1d() -> None:
646655 """Test _is_total_shard works with 1D shards."""
647656 codec = ShardingCodec (chunk_shape = (8 ,))
648657 chunks_per_shard = (4 ,)
649- all_chunk_coords = {(0 ,), (1 ,), (2 ,), (3 ,)}
658+ all_chunk_coords : set [ tuple [ int , ...]] = {(0 ,), (1 ,), (2 ,), (3 ,)}
650659
651660 assert codec ._is_total_shard (all_chunk_coords , chunks_per_shard ) is True
652661
653662 # Partial
654- partial_coords = {(0 ,), (2 ,)}
663+ partial_coords : set [ tuple [ int , ...]] = {(0 ,), (2 ,)}
655664 assert codec ._is_total_shard (partial_coords , chunks_per_shard ) is False
0 commit comments