Skip to content

Commit b976450

Browse files
committed
Fixes to _get_partial_values
1 parent 619df43 commit b976450

File tree

1 file changed

+156
-30
lines changed

1 file changed

+156
-30
lines changed

src/zarr/storage/object_store.py

Lines changed: 156 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
from collections import defaultdict
55
from collections.abc import Iterable
6-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING, Any, TypedDict
77

88
import obstore as obs
99

@@ -15,8 +15,7 @@
1515
from collections.abc import AsyncGenerator, Coroutine, Iterable
1616
from typing import Any
1717

18-
from obstore import Buffer as ObjectStoreBuffer
19-
from obstore import ListStream, ObjectMeta
18+
from obstore import ListStream, ObjectMeta, OffsetRange, SuffixRange
2019
from obstore.store import ObjectStore as _ObjectStore
2120

2221
from zarr.core.buffer import Buffer, BufferPrototype
@@ -62,33 +61,7 @@ async def get_partial_values(
6261
prototype: BufferPrototype,
6362
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
6463
) -> list[Buffer | None]:
65-
# TODO: this is a bit hacky and untested. ObjectStore has a `get_ranges` method
66-
# that will additionally merge nearby ranges, but it's _per_ file. So we need to
67-
# split these key_ranges into **per-file** key ranges, and then reassemble the
68-
# results in the original order.
69-
key_ranges = list(key_ranges)
70-
71-
per_file_requests: dict[str, list[tuple[int | None, int | None, int]]] = defaultdict(list)
72-
for idx, (path, range_) in enumerate(key_ranges):
73-
per_file_requests[path].append((range_[0], range_[1], idx))
74-
75-
futs: list[Coroutine[Any, Any, list[ObjectStoreBuffer]]] = []
76-
for path, ranges in per_file_requests.items():
77-
starts = [r[0] for r in ranges]
78-
ends = [r[1] for r in ranges]
79-
fut = obs.get_ranges_async(self.store, path, starts=starts, ends=ends)
80-
futs.append(fut)
81-
82-
result = await asyncio.gather(*futs)
83-
84-
output_buffers: list[type[BufferPrototype]] = [b""] * len(key_ranges)
85-
for per_file_request, buffers in zip(per_file_requests.items(), result, strict=True):
86-
path, ranges = per_file_request
87-
for buffer, ranges_ in zip(buffers, ranges, strict=True):
88-
initial_index = ranges_[2]
89-
output_buffers[initial_index] = prototype.buffer.from_buffer(memoryview(buffer))
90-
91-
return output_buffers
64+
return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges)
9265

9366
async def exists(self, key: str) -> bool:
9467
try:
@@ -163,3 +136,156 @@ async def _transform_list_dir(
163136
# Yield this item if "/" does not exist after the prefix.
164137
if "/" not in item["path"][prefix_len:]:
165138
yield item["path"]
139+
140+
141+
class BoundedRequest(TypedDict):
    """A range request whose start and end bytes are both known.

    Requests of this shape can be multiplexed natively on the Rust side with
    `obstore.get_ranges_async`.
    """

    # The positional index in the original key_ranges input.
    original_request_index: int
    # Start byte offset.
    start: int
    # End byte offset.
    end: int
156+
157+
158+
class OtherRequest(TypedDict):
    """An offset or suffix range request.

    These cannot be batched on the Rust side; each one needs its own call to
    `obstore.get_async`, passing in the `range` parameter.
    """

    # The positional index in the original key_ranges input.
    original_request_index: int
    # The path to request from.
    path: str
    # The range request type.
    range: OffsetRange | SuffixRange
173+
174+
175+
class Response(TypedDict):
    """A response buffer paired with the original request index it should be restored to."""

    # The positional index in the original key_ranges input.
    original_request_index: int
    # The buffer returned from obstore's range request.
    buffer: Buffer
183+
184+
185+
async def _make_bounded_requests(
    store: obs.store.ObjectStore,
    path: str,
    requests: list[BoundedRequest],
    prototype: BufferPrototype,
) -> list[Response]:
    """Make all bounded requests for a specific file.

    `obstore.get_ranges_async` allows for making concurrent requests for multiple ranges
    within a single file, and will e.g. merge concurrent requests. This only uses one
    single Python coroutine.
    """
    raw_buffers = await obs.get_ranges_async(
        store,
        path=path,
        starts=[req["start"] for req in requests],
        ends=[req["end"] for req in requests],
    )
    # Pair each obstore buffer back up with the index it must be restored to.
    return [
        {
            "original_request_index": req["original_request_index"],
            "buffer": prototype.buffer.from_bytes(memoryview(raw)),
        }
        for req, raw in zip(requests, raw_buffers, strict=True)
    ]
212+
213+
214+
async def _make_other_request(
    store: obs.store.ObjectStore,
    request: OtherRequest,
    prototype: BufferPrototype,
) -> list[Response]:
    """Make a suffix or offset request.

    We return a `list[Response]` for symmetry with `_make_bounded_requests` so that all
    futures can be gathered together.
    """
    get_result = await obs.get_async(store, request["path"], options={"range": request["range"]})
    raw = await get_result.bytes_async()
    response: Response = {
        "original_request_index": request["original_request_index"],
        "buffer": prototype.buffer.from_bytes(raw),
    }
    return [response]
232+
233+
234+
async def _get_partial_values(
    store: obs.store.ObjectStore,
    prototype: BufferPrototype,
    key_ranges: Iterable[tuple[str, ByteRangeRequest]],
) -> list[Buffer | None]:
    """Make multiple range requests.

    ObjectStore has a `get_ranges` method that will additionally merge nearby ranges,
    but it's _per_ file. So we need to split these key_ranges into **per-file** key
    ranges, and then reassemble the results in the original order.

    We separate into different requests:

    - One call to `obstore.get_ranges_async` **per target file**
    - One call to `obstore.get_async` for each other request.
    """
    key_ranges = list(key_ranges)
    per_file_bounded_requests: dict[str, list[BoundedRequest]] = defaultdict(list)
    other_requests: list[OtherRequest] = []

    for idx, (path, (start, end)) in enumerate(key_ranges):
        if start is None:
            raise ValueError("Cannot pass `None` for the start of the range request.")

        # `start is not None` from here on; classify by the end byte and the sign
        # of the start (no need to re-test `end is None` in the remaining arms).
        if end is not None:
            # Bounded request with known start and end bytes; these can be merged
            # per-file on the Rust side.
            per_file_bounded_requests[path].append(
                {"original_request_index": idx, "start": start, "end": end}
            )
        elif start < 0:
            # Suffix request for the last `abs(start)` bytes.
            other_requests.append(
                {"original_request_index": idx, "path": path, "range": {"suffix": abs(start)}}
            )
        else:
            # Offset request from `start` through the end of the object. This
            # includes `start == 0` (a full-object read), which the previous
            # `start > 0` check incorrectly rejected as unsupported.
            other_requests.append(
                {"original_request_index": idx, "path": path, "range": {"offset": start}}
            )

    futs: list[Coroutine[Any, Any, list[Response]]] = [
        _make_bounded_requests(store, path, bounded_ranges, prototype)
        for path, bounded_ranges in per_file_bounded_requests.items()
    ]
    futs.extend(_make_other_request(store, request, prototype) for request in other_requests)

    # Pre-size the output so each response can be written back to its original slot.
    buffers: list[Buffer | None] = [None] * len(key_ranges)

    # TODO: this gathers a list of list of Response; not sure if there's a way to
    # unpack these lists inside of an `asyncio.gather`?
    for responses in await asyncio.gather(*futs):
        for resp in responses:
            buffers[resp["original_request_index"]] = resp["buffer"]

    return buffers

0 commit comments

Comments
 (0)