Skip to content

Commit 6437fb6

Browse files
committed
Fix typing errors in test_sharing_unit.py
1 parent 24f6f1c commit 6437fb6

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

tests/test_codecs/test_sharding_unit.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from zarr.core.buffer.cpu import Buffer
1818

1919
if TYPE_CHECKING:
20-
from zarr.abc.store import RangeByteRequest
20+
from zarr.abc.store import ByteRequest
2121
from zarr.core.buffer import BufferPrototype
2222

2323

@@ -349,16 +349,18 @@ class MockByteGetter:
349349
return_none: bool = False
350350

351351
async def get(
352-
self, prototype: BufferPrototype, byte_range: RangeByteRequest | None = None
352+
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
353353
) -> Buffer | None:
354354
if self.return_none:
355355
return None
356356
if byte_range is None:
357357
return Buffer.from_bytes(self.data)
358-
return Buffer.from_bytes(self.data[byte_range.start : byte_range.end])
358+
# For RangeByteRequest, extract start and end
359+
start = getattr(byte_range, "start", 0)
360+
end = getattr(byte_range, "end", len(self.data))
361+
return Buffer.from_bytes(self.data[start:end])
359362

360363

361-
@pytest.mark.asyncio
362364
async def test_get_group_bytes_single_chunk() -> None:
363365
"""Test _get_group_bytes extracts single chunk correctly."""
364366
codec = ShardingCodec(chunk_shape=(8,))
@@ -372,10 +374,11 @@ async def test_get_group_bytes_single_chunk() -> None:
372374

373375
assert result is not None
374376
assert (0,) in result
375-
assert result[(0,)].as_numpy_array().tobytes() == data[10:30]
377+
chunk_buf = result[(0,)]
378+
assert chunk_buf is not None
379+
assert chunk_buf.as_numpy_array().tobytes() == data[10:30]
376380

377381

378-
@pytest.mark.asyncio
379382
async def test_get_group_bytes_multiple_chunks() -> None:
380383
"""Test _get_group_bytes extracts multiple chunks with correct offsets."""
381384
codec = ShardingCodec(chunk_shape=(8,))
@@ -391,11 +394,14 @@ async def test_get_group_bytes_multiple_chunks() -> None:
391394

392395
assert result is not None
393396
assert len(result) == 2
394-
assert result[(0,)].as_numpy_array().tobytes() == data[10:30]
395-
assert result[(1,)].as_numpy_array().tobytes() == data[30:50]
397+
chunk0_buf = result[(0,)]
398+
chunk1_buf = result[(1,)]
399+
assert chunk0_buf is not None
400+
assert chunk1_buf is not None
401+
assert chunk0_buf.as_numpy_array().tobytes() == data[10:30]
402+
assert chunk1_buf.as_numpy_array().tobytes() == data[30:50]
396403

397404

398-
@pytest.mark.asyncio
399405
async def test_get_group_bytes_with_gap() -> None:
400406
"""Test _get_group_bytes handles chunks with gaps correctly."""
401407
codec = ShardingCodec(chunk_shape=(8,))
@@ -412,11 +418,14 @@ async def test_get_group_bytes_with_gap() -> None:
412418
assert result is not None
413419
assert len(result) == 2
414420
# The byte_getter.get is called with range [10, 60), then sliced
415-
assert result[(0,)].as_numpy_array().tobytes() == data[10:20]
416-
assert result[(1,)].as_numpy_array().tobytes() == data[40:60]
421+
chunk0_buf = result[(0,)]
422+
chunk1_buf = result[(1,)]
423+
assert chunk0_buf is not None
424+
assert chunk1_buf is not None
425+
assert chunk0_buf.as_numpy_array().tobytes() == data[10:20]
426+
assert chunk1_buf.as_numpy_array().tobytes() == data[40:60]
417427

418428

419-
@pytest.mark.asyncio
420429
async def test_get_group_bytes_returns_none_on_failed_read() -> None:
421430
"""Test _get_group_bytes returns None when ByteGetter.get returns None."""
422431
codec = ShardingCodec(chunk_shape=(8,))
@@ -444,7 +453,7 @@ class MockByteGetterWithIndex:
444453
call_count: int = 0
445454

446455
async def get(
447-
self, prototype: BufferPrototype, byte_range: RangeByteRequest | None = None
456+
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
448457
) -> Buffer | None:
449458
self.call_count += 1
450459
# First call is typically for the index
@@ -457,17 +466,19 @@ async def get(
457466
return None
458467
if byte_range is None:
459468
return Buffer.from_bytes(self.chunk_data)
460-
return Buffer.from_bytes(self.chunk_data[byte_range.start : byte_range.end])
469+
# For RangeByteRequest, extract start and end
470+
start = getattr(byte_range, "start", 0)
471+
end = getattr(byte_range, "end", len(self.chunk_data))
472+
return Buffer.from_bytes(self.chunk_data[start:end])
461473

462474

463-
@pytest.mark.asyncio
464475
async def test_load_partial_shard_maybe_index_load_fails() -> None:
465476
"""Test _load_partial_shard_maybe returns None when index load fails."""
466477
codec = ShardingCodec(chunk_shape=(8,))
467478
byte_getter = MockByteGetterWithIndex(index_data=None, chunk_data=None)
468479

469480
chunks_per_shard = (2,)
470-
all_chunk_coords = {(0,)}
481+
all_chunk_coords: set[tuple[int, ...]] = {(0,)}
471482

472483
result = await codec._load_partial_shard_maybe(
473484
byte_getter=byte_getter,
@@ -482,7 +493,6 @@ async def test_load_partial_shard_maybe_index_load_fails() -> None:
482493
assert result is None
483494

484495

485-
@pytest.mark.asyncio
486496
async def test_load_partial_shard_maybe_with_empty_chunks(
487497
monkeypatch: pytest.MonkeyPatch,
488498
) -> None:
@@ -510,7 +520,7 @@ async def mock_load_index(
510520
byte_getter = MockByteGetter(data=chunk_data)
511521

512522
# Request chunks including the empty one
513-
all_chunk_coords = {(0,), (1,), (2,)}
523+
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)}
514524

515525
result = await codec._load_partial_shard_maybe(
516526
byte_getter=byte_getter,
@@ -529,7 +539,6 @@ async def mock_load_index(
529539
assert (2,) in result
530540

531541

532-
@pytest.mark.asyncio
533542
async def test_load_partial_shard_maybe_all_chunks_empty(
534543
monkeypatch: pytest.MonkeyPatch,
535544
) -> None:
@@ -551,7 +560,7 @@ async def mock_load_index(
551560
byte_getter = MockByteGetter(data=b"")
552561

553562
# Request some chunks - all will be empty
554-
all_chunk_coords = {(0,), (1,), (2,)}
563+
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)}
555564

556565
result = await codec._load_partial_shard_maybe(
557566
byte_getter=byte_getter,
@@ -619,7 +628,7 @@ def test_is_total_shard_full() -> None:
619628
"""Test _is_total_shard returns True when all chunk coords are present."""
620629
codec = ShardingCodec(chunk_shape=(8,))
621630
chunks_per_shard = (2, 2)
622-
all_chunk_coords = {(0, 0), (0, 1), (1, 0), (1, 1)}
631+
all_chunk_coords: set[tuple[int, ...]] = {(0, 0), (0, 1), (1, 0), (1, 1)}
623632

624633
assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is True
625634

@@ -628,7 +637,7 @@ def test_is_total_shard_partial() -> None:
628637
"""Test _is_total_shard returns False for partial chunk coords."""
629638
codec = ShardingCodec(chunk_shape=(8,))
630639
chunks_per_shard = (2, 2)
631-
all_chunk_coords = {(0, 0), (1, 1)} # Missing (0, 1) and (1, 0)
640+
all_chunk_coords: set[tuple[int, ...]] = {(0, 0), (1, 1)} # Missing (0, 1) and (1, 0)
632641

633642
assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is False
634643

@@ -646,10 +655,10 @@ def test_is_total_shard_1d() -> None:
646655
"""Test _is_total_shard works with 1D shards."""
647656
codec = ShardingCodec(chunk_shape=(8,))
648657
chunks_per_shard = (4,)
649-
all_chunk_coords = {(0,), (1,), (2,), (3,)}
658+
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,), (3,)}
650659

651660
assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is True
652661

653662
# Partial
654-
partial_coords = {(0,), (2,)}
663+
partial_coords: set[tuple[int, ...]] = {(0,), (2,)}
655664
assert codec._is_total_shard(partial_coords, chunks_per_shard) is False

0 commit comments

Comments
 (0)