Skip to content

Commit 646454e

Browse files
committed
Merge branch 'byterange-dataclass' into literate-byte-ranges
2 parents 8464094 + 46070f4 commit 646454e

File tree

11 files changed

+97
-81
lines changed

11 files changed

+97
-81
lines changed

src/zarr/abc/store.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from abc import ABC, abstractmethod
44
from asyncio import gather
5+
from dataclasses import dataclass
56
from itertools import starmap
6-
from typing import TYPE_CHECKING, Protocol, TypedDict, runtime_checkable
7+
from typing import TYPE_CHECKING, Protocol, runtime_checkable
78

89
from zarr.core.buffer.core import default_buffer_prototype
910
from zarr.core.common import concurrent_map
@@ -20,21 +21,33 @@
2021
__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]
2122

2223

23-
class OffsetRange(TypedDict):
24+
@dataclass
25+
class ExplicitRange:
26+
"""Request a specific byte range"""
27+
28+
start: int
29+
"""The start of the byte range request (inclusive)."""
30+
end: int
31+
"""The end of the byte range request (exclusive)."""
32+
33+
34+
@dataclass
35+
class OffsetRange:
2436
"""Request all bytes starting from a given byte offset"""
2537

2638
offset: int
2739
"""The byte offset for the offset range request."""
2840

2941

30-
class SuffixRange(TypedDict):
42+
@dataclass
43+
class SuffixRange:
3144
"""Request up to the last `n` bytes"""
3245

3346
suffix: int
3447
"""The number of bytes from the suffix to request."""
3548

3649

37-
ByteRangeRequest: TypeAlias = tuple[int, int] | OffsetRange | SuffixRange
50+
ByteRangeRequest: TypeAlias = ExplicitRange | OffsetRange | SuffixRange
3851

3952

4053
class Store(ABC):

src/zarr/codecs/sharding.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
Codec,
1818
CodecPipeline,
1919
)
20-
from zarr.abc.store import ByteGetter, ByteRangeRequest, ByteSetter
20+
from zarr.abc.store import (
21+
ByteGetter,
22+
ByteRangeRequest,
23+
ByteSetter,
24+
ExplicitRange,
25+
SuffixRange,
26+
)
2127
from zarr.codecs.bytes import BytesCodec
2228
from zarr.codecs.crc32c_ import Crc32cCodec
2329
from zarr.core.array_spec import ArrayConfig, ArraySpec
@@ -504,7 +510,8 @@ async def _decode_partial_single(
504510
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
505511
if chunk_byte_slice:
506512
chunk_bytes = await byte_getter.get(
507-
prototype=chunk_spec.prototype, byte_range=chunk_byte_slice
513+
prototype=chunk_spec.prototype,
514+
byte_range=ExplicitRange(chunk_byte_slice[0], chunk_byte_slice[1]),
508515
)
509516
if chunk_bytes:
510517
shard_dict[chunk_coords] = chunk_bytes
@@ -696,11 +703,11 @@ async def _load_shard_index_maybe(
696703
shard_index_size = self._shard_index_size(chunks_per_shard)
697704
if self.index_location == ShardingCodecIndexLocation.start:
698705
index_bytes = await byte_getter.get(
699-
prototype=numpy_buffer_prototype(), byte_range=(0, shard_index_size)
706+
prototype=numpy_buffer_prototype(), byte_range=ExplicitRange(0, shard_index_size)
700707
)
701708
else:
702709
index_bytes = await byte_getter.get(
703-
prototype=numpy_buffer_prototype(), byte_range={"suffix": shard_index_size}
710+
prototype=numpy_buffer_prototype(), byte_range=SuffixRange(shard_index_size)
704711
)
705712
if index_bytes is not None:
706713
return await self._decode_shard_index(index_bytes, chunks_per_shard)

src/zarr/storage/_fsspec.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from typing import TYPE_CHECKING, Any
55

6-
from zarr.abc.store import ByteRangeRequest, Store
6+
from zarr.abc.store import ByteRangeRequest, ExplicitRange, OffsetRange, Store, SuffixRange
77
from zarr.storage._common import _dereference_path
88

99
if TYPE_CHECKING:
@@ -209,21 +209,22 @@ async def get(
209209
try:
210210
if byte_range is None:
211211
value = prototype.buffer.from_bytes(await self.fs._cat_file(path))
212-
elif isinstance(byte_range, tuple):
212+
elif isinstance(byte_range, ExplicitRange):
213213
value = prototype.buffer.from_bytes(
214-
await self.fs._cat_file(path, start=byte_range[0], end=byte_range[1])
215-
)
216-
elif isinstance(byte_range, dict):
217-
if "offset" in byte_range:
218-
value = prototype.buffer.from_bytes(
219-
await self.fs._cat_file(path, start=byte_range["offset"], end=None) # type: ignore[typeddict-item]
220-
)
221-
elif "suffix" in byte_range:
222-
value = prototype.buffer.from_bytes(
223-
await self.fs._cat_file(path, start=-byte_range["suffix"], end=None)
214+
await self.fs._cat_file(
215+
path,
216+
start=byte_range.start,
217+
end=byte_range.end,
224218
)
225-
else:
226-
raise ValueError("Invalid format for ByteRangeRequest")
219+
)
220+
elif isinstance(byte_range, OffsetRange):
221+
value = prototype.buffer.from_bytes(
222+
await self.fs._cat_file(path, start=byte_range.offset, end=None)
223+
)
224+
elif isinstance(byte_range, SuffixRange):
225+
value = prototype.buffer.from_bytes(
226+
await self.fs._cat_file(path, start=-byte_range.suffix, end=None)
227+
)
227228
else:
228229
raise ValueError("Invalid format for ByteRangeRequest")
229230
except self.allowed_exceptions:
@@ -286,18 +287,15 @@ async def get_partial_values(
286287
if byte_range is None:
287288
starts.append(None)
288289
stops.append(None)
289-
elif isinstance(byte_range, tuple):
290-
starts.append(byte_range[0])
291-
stops.append(byte_range[1])
292-
elif isinstance(byte_range, dict):
293-
if "offset" in byte_range:
294-
starts.append(byte_range["offset"]) # type: ignore[typeddict-item]
295-
stops.append(None)
296-
elif "suffix" in byte_range:
297-
starts.append(-byte_range["suffix"])
298-
stops.append(None)
299-
else:
300-
raise ValueError("Invalid format for ByteRangeRequest")
290+
elif isinstance(byte_range, ExplicitRange):
291+
starts.append(byte_range.start)
292+
stops.append(byte_range.end)
293+
elif isinstance(byte_range, OffsetRange):
294+
starts.append(byte_range.offset)
295+
stops.append(None)
296+
elif isinstance(byte_range, SuffixRange):
297+
starts.append(-byte_range.suffix)
298+
stops.append(None)
301299
else:
302300
raise ValueError("Invalid format for ByteRangeRequest")
303301
else:

src/zarr/storage/_local.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pathlib import Path
88
from typing import TYPE_CHECKING
99

10-
from zarr.abc.store import ByteRangeRequest, Store
10+
from zarr.abc.store import ByteRangeRequest, ExplicitRange, OffsetRange, Store, SuffixRange
1111
from zarr.core.buffer import Buffer
1212
from zarr.core.buffer.core import default_buffer_prototype
1313
from zarr.core.common import concurrent_map
@@ -23,20 +23,16 @@ def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRangeRequest |
2323
return prototype.buffer.from_bytes(path.read_bytes())
2424
with path.open("rb") as f:
2525
size = f.seek(0, io.SEEK_END)
26-
if isinstance(byte_range, tuple):
27-
start, end = byte_range
28-
f.seek(start)
29-
return prototype.buffer.from_bytes(f.read(end - f.tell()))
30-
elif isinstance(byte_range, dict):
31-
if "offset" in byte_range:
32-
f.seek(byte_range["offset"]) # type: ignore[typeddict-item]
33-
elif "suffix" in byte_range:
34-
f.seek(max(0, size - byte_range["suffix"]))
35-
else:
36-
raise TypeError("Invalid format for ByteRangeRequest")
37-
return prototype.buffer.from_bytes(f.read())
26+
if isinstance(byte_range, ExplicitRange):
27+
f.seek(byte_range.start)
28+
return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell()))
29+
elif isinstance(byte_range, OffsetRange):
30+
f.seek(byte_range.offset)
31+
elif isinstance(byte_range, SuffixRange):
32+
f.seek(max(0, size - byte_range.suffix))
3833
else:
3934
raise TypeError("Invalid format for ByteRangeRequest")
35+
return prototype.buffer.from_bytes(f.read())
4036

4137

4238
def _put(

src/zarr/storage/_memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async def get_partial_values(
9696
# docstring inherited
9797

9898
# All the key-ranges arguments goes with the same prototype
99-
async def _get(key: str, byte_range: ByteRangeRequest) -> Buffer | None:
99+
async def _get(key: str, byte_range: ByteRangeRequest | None) -> Buffer | None:
100100
return await self.get(key, prototype=prototype, byte_range=byte_range)
101101

102102
return await concurrent_map(key_ranges, _get, limit=None)

src/zarr/storage/_utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING
66

7+
from zarr.abc.store import ExplicitRange, OffsetRange, SuffixRange
8+
79
if TYPE_CHECKING:
810
from zarr.abc.store import ByteRangeRequest
911
from zarr.core.buffer import Buffer
@@ -54,14 +56,13 @@ def _normalize_byte_range_index(
5456
if byte_range is None:
5557
start = 0
5658
stop = len(data) + 1
57-
elif isinstance(byte_range, tuple):
58-
start = byte_range[0]
59-
stop = byte_range[1]
60-
elif "offset" in byte_range:
61-
# See https://github.com/python/mypy/issues/17087 for typeddict-item ignore explanation
62-
start = byte_range["offset"] # type: ignore[typeddict-item]
59+
elif isinstance(byte_range, ExplicitRange):
60+
start = byte_range.start
61+
stop = byte_range.end
62+
elif isinstance(byte_range, OffsetRange):
63+
start = byte_range.offset
6364
stop = len(data) + 1
64-
elif "suffix" in byte_range:
65-
start = len(data) - byte_range["suffix"]
65+
elif isinstance(byte_range, SuffixRange):
66+
start = len(data) - byte_range.suffix
6667
stop = len(data) + 1
6768
return (start, stop)

src/zarr/storage/_zip.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pathlib import Path
88
from typing import TYPE_CHECKING, Any, Literal
99

10-
from zarr.abc.store import ByteRangeRequest, Store
10+
from zarr.abc.store import ByteRangeRequest, ExplicitRange, OffsetRange, Store, SuffixRange
1111
from zarr.core.buffer import Buffer, BufferPrototype
1212

1313
if TYPE_CHECKING:
@@ -145,21 +145,17 @@ def _get(
145145
with self._zf.open(key) as f: # will raise KeyError
146146
if byte_range is None:
147147
return prototype.buffer.from_bytes(f.read())
148-
if isinstance(byte_range, tuple):
149-
start, end = byte_range
150-
f.seek(start)
151-
return prototype.buffer.from_bytes(f.read(end - f.tell()))
152-
elif isinstance(byte_range, dict):
153-
size = f.seek(0, os.SEEK_END)
154-
if "offset" in byte_range:
155-
f.seek(byte_range["offset"]) # type: ignore[typeddict-item]
156-
elif "suffix" in byte_range:
157-
f.seek(max(0, size - byte_range["suffix"]))
158-
else:
159-
raise TypeError("Invalid format for ByteRangeRequest")
160-
return prototype.buffer.from_bytes(f.read())
148+
elif isinstance(byte_range, ExplicitRange):
149+
f.seek(byte_range.start)
150+
return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell()))
151+
size = f.seek(0, os.SEEK_END)
152+
if isinstance(byte_range, OffsetRange):
153+
f.seek(byte_range.offset)
154+
elif isinstance(byte_range, SuffixRange):
155+
f.seek(max(0, size - byte_range.suffix))
161156
else:
162157
raise TypeError("Invalid format for ByteRangeRequest")
158+
return prototype.buffer.from_bytes(f.read())
163159
except KeyError:
164160
return None
165161

src/zarr/testing/stateful.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,8 @@ def get_partial_values(self, data: DataObject) -> None:
355355
model_vals_ls = []
356356

357357
for key, byte_range in key_range:
358-
start = byte_range[0] or 0
359-
stop = byte_range[1]
358+
start = byte_range.start
359+
stop = byte_range.end
360360
model_vals_ls.append(self.model[key][start:stop])
361361

362362
assert all(

src/zarr/testing/store.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pytest
1616

17-
from zarr.abc.store import ByteRangeRequest, Store
17+
from zarr.abc.store import ByteRangeRequest, ExplicitRange, OffsetRange, Store, SuffixRange
1818
from zarr.core.buffer import Buffer, default_buffer_prototype
1919
from zarr.core.sync import _collect_aiterator
2020
from zarr.storage._utils import _normalize_byte_range_index
@@ -115,7 +115,9 @@ def test_store_supports_listing(self, store: S) -> None:
115115

116116
@pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
117117
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
118-
@pytest.mark.parametrize("byte_range", [None, (1, 3), {"offset": 1}, {"suffix": 1}])
118+
@pytest.mark.parametrize(
119+
"byte_range", [None, ExplicitRange(1, 4), OffsetRange(1), SuffixRange(1)]
120+
)
119121
async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRangeRequest) -> None:
120122
"""
121123
Ensure that data can be read from the store using the store.get method.
@@ -177,9 +179,9 @@ async def test_set_many(self, store: S) -> None:
177179
"key_ranges",
178180
[
179181
[],
180-
[("zarr.json", (0, 1))],
181-
[("c/0", (0, 1)), ("zarr.json", None)],
182-
[("c/0/0", (0, 1)), ("c/0/1", {"suffix": 2}), ("c/0/2", {"offset": 2})],
182+
[("zarr.json", ExplicitRange(0, 2))],
183+
[("c/0", ExplicitRange(0, 2)), ("zarr.json", None)],
184+
[("c/0/0", ExplicitRange(0, 2)), ("c/0/1", SuffixRange(2)), ("c/0/2", OffsetRange(2))],
183185
],
184186
)
185187
async def test_get_partial_values(

src/zarr/testing/strategies.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from hypothesis.strategies import SearchStrategy
88

99
import zarr
10+
from zarr.abc.store import ExplicitRange
1011
from zarr.core.array import Array
1112
from zarr.core.common import ZarrFormat
1213
from zarr.core.sync import sync
@@ -197,9 +198,10 @@ def key_ranges(
197198
[(key, (range_start, range_end)),
198199
(key, (range_start, range_end)),...]
199200
"""
200-
byte_ranges = st.tuples(
201-
st.integers(min_value=0, max_value=max_size),
202-
st.integers(min_value=0, max_value=max_size),
201+
byte_ranges = st.builds(
202+
ExplicitRange,
203+
start=st.integers(min_value=0, max_value=max_size),
204+
end=st.integers(min_value=0, max_value=max_size),
203205
)
204206
key_tuple = st.tuples(keys, byte_ranges)
205207
return st.lists(key_tuple, min_size=1, max_size=10)

0 commit comments

Comments
 (0)