3 | 3 | import asyncio |
4 | 4 | from collections import defaultdict |
5 | 5 | from collections.abc import Iterable |
6 | | -from typing import TYPE_CHECKING, Any |
| 6 | +from typing import TYPE_CHECKING, Any, TypedDict |
7 | 7 |
8 | 8 | import obstore as obs |
9 | 9 |
15 | 15 | from collections.abc import AsyncGenerator, Coroutine, Iterable |
16 | 16 | from typing import Any |
17 | 17 |
18 | | - from obstore import Buffer as ObjectStoreBuffer |
19 | | - from obstore import ListStream, ObjectMeta |
| 18 | + from obstore import ListStream, ObjectMeta, OffsetRange, SuffixRange |
20 | 19 | from obstore.store import ObjectStore as _ObjectStore |
21 | 20 |
22 | 21 | from zarr.core.buffer import Buffer, BufferPrototype |
@@ -62,33 +61,7 @@ async def get_partial_values( |
62 | 61 | prototype: BufferPrototype, |
63 | 62 | key_ranges: Iterable[tuple[str, ByteRangeRequest]], |
64 | 63 | ) -> list[Buffer | None]: |
65 | | - # TODO: this is a bit hacky and untested. ObjectStore has a `get_ranges` method |
66 | | - # that will additionally merge nearby ranges, but it's _per_ file. So we need to |
67 | | - # split these key_ranges into **per-file** key ranges, and then reassemble the |
68 | | - # results in the original order. |
69 | | - key_ranges = list(key_ranges) |
70 | | - |
71 | | - per_file_requests: dict[str, list[tuple[int | None, int | None, int]]] = defaultdict(list) |
72 | | - for idx, (path, range_) in enumerate(key_ranges): |
73 | | - per_file_requests[path].append((range_[0], range_[1], idx)) |
74 | | - |
75 | | - futs: list[Coroutine[Any, Any, list[ObjectStoreBuffer]]] = [] |
76 | | - for path, ranges in per_file_requests.items(): |
77 | | - starts = [r[0] for r in ranges] |
78 | | - ends = [r[1] for r in ranges] |
79 | | - fut = obs.get_ranges_async(self.store, path, starts=starts, ends=ends) |
80 | | - futs.append(fut) |
81 | | - |
82 | | - result = await asyncio.gather(*futs) |
83 | | - |
84 | | - output_buffers: list[type[BufferPrototype]] = [b""] * len(key_ranges) |
85 | | - for per_file_request, buffers in zip(per_file_requests.items(), result, strict=True): |
86 | | - path, ranges = per_file_request |
87 | | - for buffer, ranges_ in zip(buffers, ranges, strict=True): |
88 | | - initial_index = ranges_[2] |
89 | | - output_buffers[initial_index] = prototype.buffer.from_buffer(memoryview(buffer)) |
90 | | - |
91 | | - return output_buffers |
| 64 | + return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) |
92 | 65 |
93 | 66 | async def exists(self, key: str) -> bool: |
94 | 67 | try: |
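
As context for the one-line delegation above, here is a minimal sketch of the three request shapes the new `_get_partial_values` helper accepts. The key names are illustrative, and `default_buffer_prototype` is assumed to come from `zarr.core.buffer`:

```python
from zarr.core.buffer import default_buffer_prototype

async def demo(store) -> None:
    # `store` is an instance of the ObjectStore class above.
    buffers = await store.get_partial_values(
        default_buffer_prototype(),
        [
            ("c/0/0", (0, 100)),     # bounded: explicit start and end byte
            ("c/0/0", (100, None)),  # offset: byte 100 through the end
            ("c/0/1", (-100, None)), # suffix: the final 100 bytes
        ],
    )
    assert len(buffers) == 3  # results come back in the original request order
```
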
@@ -163,3 +136,156 @@ async def _transform_list_dir( |
163 | 136 | # Yield this item if "/" does not exist after the prefix. |
164 | 137 | if "/" not in item["path"][prefix_len:]: |
165 | 138 | yield item["path"] |
| 139 | + |
| 140 | + |
| 141 | +class BoundedRequest(TypedDict): |
| 142 | + """Range request with a known start and end byte. |
| 143 | +
| 144 | + These requests can be multiplexed natively on the Rust side with |
| 145 | + `obstore.get_ranges_async`. |
| 146 | + """ |
| 147 | + |
| 148 | + original_request_index: int |
| 149 | + """The positional index in the original key_ranges input""" |
| 150 | + |
| 151 | + start: int |
| 152 | + """Start byte offset.""" |
| 153 | + |
| 154 | + end: int |
| 155 | + """End byte offset.""" |
| 156 | + |
| 157 | + |
| 158 | +class OtherRequest(TypedDict): |
| 159 | + """Offset or suffix range requests. |
| 160 | +
| 161 | + These requests cannot be multiplexed on the Rust side; each needs its own call |
| 162 | + to `obstore.get_async`, passing in the `range` parameter. |
| 163 | + """ |
| 164 | + |
| 165 | + original_request_index: int |
| 166 | + """The positional index in the original key_ranges input""" |
| 167 | + |
| 168 | + path: str |
| 169 | + """The path to request from.""" |
| 170 | + |
| 171 | + range: OffsetRange | SuffixRange |
| 172 | + """The range request type.""" |
| 173 | + |
| 174 | + |
| 175 | +class Response(TypedDict): |
| 176 | + """A response buffer associated with the original index that it should be restored to.""" |
| 177 | + |
| 178 | + original_request_index: int |
| 179 | + """The positional index in the original key_ranges input""" |
| 180 | + |
| 181 | + buffer: Buffer |
| 182 | + """The buffer returned from obstore's range request.""" |
| 183 | + |
| 184 | + |
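
The `original_request_index` field carried by these TypedDicts is what lets responses, which come back grouped per file and per request type, be stitched back into the caller's order. A toy illustration with placeholder string buffers:

```python
responses = [
    {"original_request_index": 2, "buffer": "c"},
    {"original_request_index": 0, "buffer": "a"},
    {"original_request_index": 1, "buffer": "b"},
]
output = [None] * len(responses)
for resp in responses:
    # Write each buffer back into the slot the caller asked for.
    output[resp["original_request_index"]] = resp["buffer"]
assert output == ["a", "b", "c"]
```
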
| 185 | +async def _make_bounded_requests( |
| 186 | + store: obs.store.ObjectStore, |
| 187 | + path: str, |
| 188 | + requests: list[BoundedRequest], |
| 189 | + prototype: BufferPrototype, |
| 190 | +) -> list[Response]: |
| 191 | + """Make all bounded requests for a specific file. |
| 192 | +
| 193 | + `obstore.get_ranges_async` allows for making concurrent requests for multiple ranges |
| 194 | + within a single file, and will, e.g., merge nearby ranges into fewer network |
| 195 | + requests. This uses only a single Python coroutine. |
| 196 | + """ |
| 197 | + |
| 198 | + starts = [r["start"] for r in requests] |
| 199 | + ends = [r["end"] for r in requests] |
| 200 | + responses = await obs.get_ranges_async(store, path=path, starts=starts, ends=ends) |
| 201 | + |
| 202 | + buffer_responses: list[Response] = [] |
| 203 | + for request, response in zip(requests, responses, strict=True): |
| 204 | + buffer_responses.append( |
| 205 | + { |
| 206 | + "original_request_index": request["original_request_index"], |
| 207 | + "buffer": prototype.buffer.from_bytes(memoryview(response)), |
| 208 | + } |
| 209 | + ) |
| 210 | + |
| 211 | + return buffer_responses |
| 212 | + |
| 213 | + |
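
A standalone sketch of the underlying call, assuming obstore's in-memory `MemoryStore`: `get_ranges_async` fetches several ranges of one object from a single coroutine, with end offsets exclusive, and nearby ranges may be coalesced on the Rust side:

```python
import asyncio

import obstore as obs
from obstore.store import MemoryStore

async def demo() -> None:
    store = MemoryStore()
    await obs.put_async(store, "key", b"0123456789")
    # Two ranges of the same object in one call: bytes [0, 3) and [5, 8).
    chunks = await obs.get_ranges_async(store, path="key", starts=[0, 5], ends=[3, 8])
    assert [bytes(c) for c in chunks] == [b"012", b"567"]

asyncio.run(demo())
```
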
| 214 | +async def _make_other_request( |
| 215 | + store: obs.store.ObjectStore, |
| 216 | + request: OtherRequest, |
| 217 | + prototype: BufferPrototype, |
| 218 | +) -> list[Response]: |
| 219 | + """Make suffix or offset requests. |
| 220 | +
| 221 | + We return a `list[Response]` for symmetry with `_make_bounded_requests` so that all |
| 222 | + futures can be gathered together. |
| 223 | + """ |
| 224 | + resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) |
| 225 | + buffer = await resp.bytes_async() |
| 226 | + return [ |
| 227 | + { |
| 228 | + "original_request_index": request["original_request_index"], |
| 229 | + "buffer": prototype.buffer.from_bytes(buffer), |
| 230 | + } |
| 231 | + ] |
| 232 | + |
| 233 | + |
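
The non-multiplexed counterpart: each offset or suffix request is its own `obs.get_async` call with a `range` option, mirroring what `_make_other_request` passes through (again a sketch against an assumed in-memory store):

```python
import asyncio

import obstore as obs
from obstore.store import MemoryStore

async def demo() -> None:
    store = MemoryStore()
    await obs.put_async(store, "key", b"0123456789")
    head = await obs.get_async(store, "key", options={"range": {"offset": 6}})
    assert bytes(await head.bytes_async()) == b"6789"  # byte 6 through the end
    tail = await obs.get_async(store, "key", options={"range": {"suffix": 4}})
    assert bytes(await tail.bytes_async()) == b"6789"  # the final 4 bytes

asyncio.run(demo())
```
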
| 234 | +async def _get_partial_values( |
| 235 | + store: obs.store.ObjectStore, |
| 236 | + prototype: BufferPrototype, |
| 237 | + key_ranges: Iterable[tuple[str, ByteRangeRequest]], |
| 238 | +) -> list[Buffer | None]: |
| 239 | + """Make multiple range requests. |
| 240 | +
| 241 | + ObjectStore has a `get_ranges` method that will additionally merge nearby ranges, |
| 242 | + but it's _per_ file. So we need to split these key_ranges into **per-file** key |
| 243 | + ranges, and then reassemble the results in the original order. |
| 244 | +
| 245 | + We separate into different requests: |
| 246 | +
| 247 | + - One call to `obstore.get_ranges_async` **per target file** |
| 248 | + - One call to `obstore.get_async` for each other request. |
| 249 | + """ |
| 250 | + key_ranges = list(key_ranges) |
| 251 | + per_file_bounded_requests: dict[str, list[BoundedRequest]] = defaultdict(list) |
| 252 | + other_requests: list[OtherRequest] = [] |
| 253 | + |
| 254 | + for idx, (path, (start, end)) in enumerate(key_ranges): |
| 255 | + if start is None: |
| 256 | + raise ValueError("Cannot pass `None` for the start of the range request.") |
| 257 | + |
| 258 | + if end is not None: |
| 259 | + # This is a bounded request with known start and end byte. |
| 260 | + per_file_bounded_requests[path].append( |
| 261 | + {"original_request_index": idx, "start": start, "end": end} |
| 262 | + ) |
| 263 | + elif end is None and start < 0: |
| 264 | + # Suffix request from the end |
| 265 | + other_requests.append( |
| 266 | + {"original_request_index": idx, "path": path, "range": {"suffix": abs(start)}} |
| 267 | + ) |
| 268 | + elif end is None and start >= 0: |
| 269 | + # Offset request to the end; start may be 0 for a full read |
| 270 | + other_requests.append( |
| 271 | + {"original_request_index": idx, "path": path, "range": {"offset": start}} |
| 272 | + ) |
| 273 | + else: |
| 274 | + raise ValueError(f"Unsupported range input: {start=}, {end=}") |
| 275 | + |
| 276 | + futs: list[Coroutine[Any, Any, list[Response]]] = [] |
| 277 | + for path, bounded_ranges in per_file_bounded_requests.items(): |
| 278 | + futs.append(_make_bounded_requests(store, path, bounded_ranges, prototype)) |
| 279 | + |
| 280 | + for request in other_requests: |
| 281 | + futs.append(_make_other_request(store, request, prototype)) # noqa: PERF401 |
| 282 | + |
| 283 | + buffers: list[Buffer | None] = [None] * len(key_ranges) |
| 284 | + |
| 285 | + # TODO: this gathers a list of lists of Response; not sure if there's a way to |
| 286 | + # unpack these lists inside of an `asyncio.gather`? |
| 287 | + for responses in await asyncio.gather(*futs): |
| 288 | + for resp in responses: |
| 289 | + buffers[resp["original_request_index"]] = resp["buffer"] |
| 290 | + |
| 291 | + return buffers |
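
Putting it together, an end-to-end sketch of `_get_partial_values` against an in-memory store. `MemoryStore` and `default_buffer_prototype` are assumed imports, and `Buffer.to_bytes()` is assumed from zarr's buffer API; output order matches input order even though execution is grouped per file and per request type:

```python
import asyncio

import obstore as obs
from obstore.store import MemoryStore
from zarr.core.buffer import default_buffer_prototype

async def demo() -> None:
    store = MemoryStore()
    await obs.put_async(store, "a", b"0123456789")
    await obs.put_async(store, "b", b"abcdefghij")
    buffers = await _get_partial_values(
        store,
        prototype=default_buffer_prototype(),
        key_ranges=[
            ("a", (2, 5)),      # bounded -> one get_ranges_async call for "a"
            ("b", (-3, None)),  # suffix  -> its own get_async call
            ("a", (7, None)),   # offset  -> its own get_async call
        ],
    )
    assert [b.to_bytes() for b in buffers] == [b"234", b"hij", b"789"]

asyncio.run(demo())
```
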