Skip to content

Commit 125a729

Browse files
committed
Use TypedDicts for more literate ByteRangeRequests
1 parent fb11810 commit 125a729

File tree

11 files changed

+102
-87
lines changed

11 files changed

+102
-87
lines changed

src/zarr/abc/store.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44
from asyncio import gather
55
from itertools import starmap
6-
from typing import TYPE_CHECKING, Protocol, runtime_checkable
6+
from typing import TYPE_CHECKING, Protocol, TypedDict, runtime_checkable
77

88
from zarr.core.buffer.core import default_buffer_prototype
99
from zarr.core.common import concurrent_map
@@ -19,7 +19,22 @@
1919

2020
__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]
2121

22-
ByteRangeRequest: TypeAlias = tuple[int | None, int | None]
22+
23+
class OffsetRange(TypedDict):
24+
"""Request all bytes starting from a given byte offset"""
25+
26+
offset: int
27+
"""The byte offset for the offset range request."""
28+
29+
30+
class SuffixRange(TypedDict):
31+
"""Request up to the last `n` bytes"""
32+
33+
suffix: int
34+
"""The number of bytes from the suffix to request."""
35+
36+
37+
ByteRangeRequest: TypeAlias = None | tuple[int, int] | OffsetRange | SuffixRange
2338

2439

2540
class Store(ABC):
@@ -148,7 +163,13 @@ async def get(
148163
Parameters
149164
----------
150165
key : str
151-
byte_range : tuple[int | None, int | None], optional
166+
byte_range : ByteRangeRequest, optional
167+
168+
The semantics of this argument are:
169+
170+
- tuple (int, int): Request a specific range of bytes (start, end). The end offset is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned.
171+
- {"offset": int}: Request all bytes starting from a given byte offset. This is equivalent to bytes={int}- as an HTTP header.
172+
- {"suffix": int}: Request the last int bytes. Note that here, int is the size of the request, not the byte offset. This is equivalent to bytes=-{int} as an HTTP header.
152173
153174
Returns
154175
-------

src/zarr/storage/_utils.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING
66

77
if TYPE_CHECKING:
8+
from zarr.abc.store import ByteRangeRequest
89
from zarr.core.buffer import Buffer
910

1011

@@ -44,25 +45,19 @@ def normalize_path(path: str | bytes | Path | None) -> str:
4445
return result
4546

4647

47-
def _normalize_interval_index(
48-
data: Buffer, interval: tuple[int | None, int | None] | None
49-
) -> tuple[int, int]:
48+
def _normalize_byte_range_index(data: Buffer, byte_range: ByteRangeRequest) -> tuple[int, int]:
5049
"""
51-
Convert an implicit interval into an explicit start and length
50+
Convert an ByteRangeRequest into an explicit start and length
5251
"""
53-
if interval is None:
52+
if byte_range is None:
5453
start = 0
5554
length = len(data)
56-
else:
57-
maybe_start, maybe_len = interval
58-
if maybe_start is None:
59-
start = 0
60-
else:
61-
start = maybe_start
62-
63-
if maybe_len is None:
64-
length = len(data) - start
65-
else:
66-
length = maybe_len
67-
55+
elif isinstance(byte_range, tuple):
56+
start = byte_range[0]
57+
length = byte_range[1] - start
58+
elif start := byte_range.get("offset"):
59+
length = len(data) - start
60+
elif suffix := byte_range.get("suffix"):
61+
start = len(data) - suffix
62+
length = len(data) - start
6863
return (start, length)

src/zarr/storage/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ async def open(
102102
async def get(
103103
self,
104104
prototype: BufferPrototype | None = None,
105-
byte_range: ByteRangeRequest | None = None,
105+
byte_range: ByteRangeRequest = None,
106106
) -> Buffer | None:
107107
"""
108108
Read bytes from the store.

src/zarr/storage/fsspec.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -199,31 +199,33 @@ async def get(
199199
self,
200200
key: str,
201201
prototype: BufferPrototype,
202-
byte_range: ByteRangeRequest | None = None,
202+
byte_range: ByteRangeRequest = None,
203203
) -> Buffer | None:
204204
# docstring inherited
205205
if not self._is_open:
206206
await self._open()
207207
path = _dereference_path(self.path, key)
208208

209209
try:
210-
if byte_range:
211-
# fsspec uses start/end, not start/length
212-
start, length = byte_range
213-
if start is not None and length is not None:
214-
end = start + length
215-
elif length is not None:
216-
end = length
217-
else:
218-
end = None
219-
value = prototype.buffer.from_bytes(
220-
await (
221-
self.fs._cat_file(path, start=byte_range[0], end=end)
222-
if byte_range
223-
else self.fs._cat_file(path)
210+
if byte_range is None:
211+
value = prototype.buffer.from_bytes(await self.fs._cat_file(path))
212+
elif isinstance(byte_range, tuple):
213+
value = prototype.buffer.from_bytes(
214+
await self.fs._cat_file(path, start=byte_range[0], end=byte_range[1])
224215
)
225-
)
226-
216+
elif isinstance(byte_range, dict):
217+
if "offset" in byte_range:
218+
value = prototype.buffer.from_bytes(
219+
await self.fs._cat_file(path, start=byte_range["offset"], end=None)
220+
)
221+
elif "suffix" in byte_range:
222+
value = prototype.buffer.from_bytes(
223+
await self.fs._cat_file(path, start=-byte_range["suffix"], end=None)
224+
)
225+
else:
226+
raise ValueError("Invalid format for ByteRangeRequest")
227+
else:
228+
raise ValueError("Invalid format for ByteRangeRequest")
227229
except self.allowed_exceptions:
228230
return None
229231
except OSError as e:

src/zarr/storage/local.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,25 @@
1818
from zarr.core.buffer import BufferPrototype
1919

2020

21-
def _get(
22-
path: Path, prototype: BufferPrototype, byte_range: tuple[int | None, int | None] | None
23-
) -> Buffer:
24-
if byte_range is not None:
25-
if byte_range[0] is None:
26-
start = 0
27-
else:
28-
start = byte_range[0]
29-
30-
end = (start + byte_range[1]) if byte_range[1] is not None else None
31-
else:
21+
def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRangeRequest | None) -> Buffer:
22+
if byte_range is None:
3223
return prototype.buffer.from_bytes(path.read_bytes())
3324
with path.open("rb") as f:
3425
size = f.seek(0, io.SEEK_END)
35-
if start is not None:
36-
if start >= 0:
37-
f.seek(start)
38-
else:
39-
f.seek(max(0, size + start))
40-
if end is not None:
41-
if end < 0:
42-
end = size + end
26+
if isinstance(byte_range, tuple):
27+
start, end = byte_range
28+
f.seek(start)
4329
return prototype.buffer.from_bytes(f.read(end - f.tell()))
44-
return prototype.buffer.from_bytes(f.read())
30+
elif isinstance(byte_range, dict):
31+
if "offset" in byte_range:
32+
f.seek(byte_range["offset"])
33+
elif "suffix" in byte_range:
34+
f.seek(max(0, size - byte_range["suffix"]))
35+
else:
36+
raise TypeError("Invalid format for ByteRangeRequest")
37+
return prototype.buffer.from_bytes(f.read())
38+
else:
39+
raise TypeError("Invalid format for ByteRangeRequest")
4540

4641

4742
def _put(
@@ -127,7 +122,7 @@ async def get(
127122
self,
128123
key: str,
129124
prototype: BufferPrototype | None = None,
130-
byte_range: tuple[int | None, int | None] | None = None,
125+
byte_range: ByteRangeRequest = None,
131126
) -> Buffer | None:
132127
# docstring inherited
133128
if prototype is None:

src/zarr/storage/logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ async def get(
161161
self,
162162
key: str,
163163
prototype: BufferPrototype,
164-
byte_range: tuple[int | None, int | None] | None = None,
164+
byte_range: ByteRangeRequest = None,
165165
) -> Buffer | None:
166166
# docstring inherited
167167
with self.log(key):

src/zarr/storage/memory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from zarr.abc.store import ByteRangeRequest, Store
77
from zarr.core.buffer import Buffer, gpu
88
from zarr.core.common import concurrent_map
9-
from zarr.storage._utils import _normalize_interval_index
9+
from zarr.storage._utils import _normalize_byte_range_index
1010

1111
if TYPE_CHECKING:
1212
from collections.abc import AsyncIterator, Iterable, MutableMapping
@@ -75,15 +75,15 @@ async def get(
7575
self,
7676
key: str,
7777
prototype: BufferPrototype,
78-
byte_range: tuple[int | None, int | None] | None = None,
78+
byte_range: ByteRangeRequest = None,
7979
) -> Buffer | None:
8080
# docstring inherited
8181
if not self._is_open:
8282
await self._open()
8383
assert isinstance(key, str)
8484
try:
8585
value = self._store_dict[key]
86-
start, length = _normalize_interval_index(value, byte_range)
86+
start, length = _normalize_byte_range_index(value, byte_range)
8787
return prototype.buffer.from_buffer(value[start : start + length])
8888
except KeyError:
8989
return None

src/zarr/storage/zip.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,23 +138,28 @@ def _get(
138138
self,
139139
key: str,
140140
prototype: BufferPrototype,
141-
byte_range: ByteRangeRequest | None = None,
141+
byte_range: ByteRangeRequest = None,
142142
) -> Buffer | None:
143143
# docstring inherited
144144
try:
145145
with self._zf.open(key) as f: # will raise KeyError
146146
if byte_range is None:
147147
return prototype.buffer.from_bytes(f.read())
148-
start, length = byte_range
149-
if start:
150-
if start < 0:
151-
start = f.seek(start, os.SEEK_END) + start
148+
if isinstance(byte_range, tuple):
149+
start, end = byte_range
150+
f.seek(start)
151+
return prototype.buffer.from_bytes(f.read(end - f.tell()))
152+
elif isinstance(byte_range, dict):
153+
size = f.seek(0, os.SEEK_END)
154+
if "offset" in byte_range:
155+
f.seek(byte_range["offset"])
156+
elif "suffix" in byte_range:
157+
f.seek(max(0, size - byte_range["suffix"]))
152158
else:
153-
start = f.seek(start, os.SEEK_SET)
154-
if length:
155-
return prototype.buffer.from_bytes(f.read(length))
156-
else:
159+
raise TypeError("Invalid format for ByteRangeRequest")
157160
return prototype.buffer.from_bytes(f.read())
161+
else:
162+
raise TypeError("Invalid format for ByteRangeRequest")
158163
except KeyError:
159164
return None
160165

src/zarr/testing/stateful.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,7 @@ def get_partial_values(self, data: DataObject) -> None:
356356

357357
for key, byte_range in key_range:
358358
start = byte_range[0] or 0
359-
step = byte_range[1]
360-
stop = start + step if step is not None else None
359+
stop = byte_range[1]
361360
model_vals_ls.append(self.model[key][start:stop])
362361

363362
assert all(

src/zarr/testing/store.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from zarr.abc.store import ByteRangeRequest, Store
1818
from zarr.core.buffer import Buffer, default_buffer_prototype
1919
from zarr.core.sync import _collect_aiterator
20-
from zarr.storage._utils import _normalize_interval_index
20+
from zarr.storage._utils import _normalize_byte_range_index
2121
from zarr.testing.utils import assert_bytes_equal
2222

2323
__all__ = ["StoreTests"]
@@ -115,17 +115,15 @@ def test_store_supports_listing(self, store: S) -> None:
115115

116116
@pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
117117
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
118-
@pytest.mark.parametrize("byte_range", [None, (0, None), (1, None), (1, 2), (None, 1)])
119-
async def test_get(
120-
self, store: S, key: str, data: bytes, byte_range: tuple[int | None, int | None] | None
121-
) -> None:
118+
@pytest.mark.parametrize("byte_range", [None, (1, 3), {"offset": 1}, {"suffix": 1}])
119+
async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRangeRequest) -> None:
122120
"""
123121
Ensure that data can be read from the store using the store.get method.
124122
"""
125123
data_buf = self.buffer_cls.from_bytes(data)
126124
await self.set(store, key, data_buf)
127125
observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range)
128-
start, length = _normalize_interval_index(data_buf, interval=byte_range)
126+
start, length = _normalize_byte_range_index(data_buf, byte_range=byte_range)
129127
expected = data_buf[start : start + length]
130128
assert_bytes_equal(observed, expected)
131129

@@ -180,12 +178,12 @@ async def test_set_many(self, store: S) -> None:
180178
[
181179
[],
182180
[("zarr.json", (0, 1))],
183-
[("c/0", (0, 1)), ("zarr.json", (0, None))],
184-
[("c/0/0", (0, 1)), ("c/0/1", (None, 2)), ("c/0/2", (0, 3))],
181+
[("c/0", (0, 1)), ("zarr.json", None)],
182+
[("c/0/0", (0, 1)), ("c/0/1", {"suffix": 2}), ("c/0/2", {"offset": 2})],
185183
],
186184
)
187185
async def test_get_partial_values(
188-
self, store: S, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
186+
self, store: S, key_ranges: list[tuple[str, ByteRangeRequest]]
189187
) -> None:
190188
# put all of the data
191189
for key, _ in key_ranges:

0 commit comments

Comments
 (0)