|
3 | 3 | import warnings |
4 | 4 | from typing import TYPE_CHECKING, Any |
5 | 5 |
|
6 | | -from zarr.abc.store import ByteRangeRequest, Store |
| 6 | +from zarr.abc.store import ( |
| 7 | + ByteRequest, |
| 8 | + OffsetByteRequest, |
| 9 | + RangeByteRequest, |
| 10 | + Store, |
| 11 | + SuffixByteRequest, |
| 12 | +) |
7 | 13 | from zarr.storage._common import _dereference_path |
8 | 14 |
|
9 | 15 | if TYPE_CHECKING: |
@@ -199,31 +205,34 @@ async def get( |
199 | 205 | self, |
200 | 206 | key: str, |
201 | 207 | prototype: BufferPrototype, |
202 | | - byte_range: ByteRangeRequest | None = None, |
| 208 | + byte_range: ByteRequest | None = None, |
203 | 209 | ) -> Buffer | None: |
204 | 210 | # docstring inherited |
205 | 211 | if not self._is_open: |
206 | 212 | await self._open() |
207 | 213 | path = _dereference_path(self.path, key) |
208 | 214 |
|
209 | 215 | try: |
210 | | - if byte_range: |
211 | | - # fsspec uses start/end, not start/length |
212 | | - start, length = byte_range |
213 | | - if start is not None and length is not None: |
214 | | - end = start + length |
215 | | - elif length is not None: |
216 | | - end = length |
217 | | - else: |
218 | | - end = None |
219 | | - value = prototype.buffer.from_bytes( |
220 | | - await ( |
221 | | - self.fs._cat_file(path, start=byte_range[0], end=end) |
222 | | - if byte_range |
223 | | - else self.fs._cat_file(path) |
| 216 | + if byte_range is None: |
| 217 | + value = prototype.buffer.from_bytes(await self.fs._cat_file(path)) |
| 218 | + elif isinstance(byte_range, RangeByteRequest): |
| 219 | + value = prototype.buffer.from_bytes( |
| 220 | + await self.fs._cat_file( |
| 221 | + path, |
| 222 | + start=byte_range.start, |
| 223 | + end=byte_range.end, |
| 224 | + ) |
224 | 225 | ) |
225 | | - ) |
226 | | - |
| 226 | + elif isinstance(byte_range, OffsetByteRequest): |
| 227 | + value = prototype.buffer.from_bytes( |
| 228 | + await self.fs._cat_file(path, start=byte_range.offset, end=None) |
| 229 | + ) |
| 230 | + elif isinstance(byte_range, SuffixByteRequest): |
| 231 | + value = prototype.buffer.from_bytes( |
| 232 | + await self.fs._cat_file(path, start=-byte_range.suffix, end=None) |
| 233 | + ) |
| 234 | + else: |
| 235 | + raise ValueError(f"Unexpected byte_range, got {byte_range}.") |
227 | 236 | except self.allowed_exceptions: |
228 | 237 | return None |
229 | 238 | except OSError as e: |
@@ -270,25 +279,35 @@ async def exists(self, key: str) -> bool: |
270 | 279 | async def get_partial_values( |
271 | 280 | self, |
272 | 281 | prototype: BufferPrototype, |
273 | | - key_ranges: Iterable[tuple[str, ByteRangeRequest]], |
| 282 | + key_ranges: Iterable[tuple[str, ByteRequest | None]], |
274 | 283 | ) -> list[Buffer | None]: |
275 | 284 | # docstring inherited |
276 | 285 | if key_ranges: |
277 | | - paths, starts, stops = zip( |
278 | | - *( |
279 | | - ( |
280 | | - _dereference_path(self.path, k[0]), |
281 | | - k[1][0], |
282 | | - ((k[1][0] or 0) + k[1][1]) if k[1][1] is not None else None, |
283 | | - ) |
284 | | - for k in key_ranges |
285 | | - ), |
286 | | - strict=False, |
287 | | - ) |
| 286 | + # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. |
| 287 | + key_ranges = list(key_ranges) |
| 288 | + paths: list[str] = [] |
| 289 | + starts: list[int | None] = [] |
| 290 | + stops: list[int | None] = [] |
| 291 | + for key, byte_range in key_ranges: |
| 292 | + paths.append(_dereference_path(self.path, key)) |
| 293 | + if byte_range is None: |
| 294 | + starts.append(None) |
| 295 | + stops.append(None) |
| 296 | + elif isinstance(byte_range, RangeByteRequest): |
| 297 | + starts.append(byte_range.start) |
| 298 | + stops.append(byte_range.end) |
| 299 | + elif isinstance(byte_range, OffsetByteRequest): |
| 300 | + starts.append(byte_range.offset) |
| 301 | + stops.append(None) |
| 302 | + elif isinstance(byte_range, SuffixByteRequest): |
| 303 | + starts.append(-byte_range.suffix) |
| 304 | + stops.append(None) |
| 305 | + else: |
| 306 | + raise ValueError(f"Unexpected byte_range, got {byte_range}.") |
288 | 307 | else: |
289 | 308 | return [] |
290 | 309 | # TODO: expectations for exceptions or missing keys? |
291 | | - res = await self.fs._cat_ranges(list(paths), starts, stops, on_error="return") |
| 310 | + res = await self.fs._cat_ranges(paths, starts, stops, on_error="return") |
292 | 311 | # the following is an s3-specific condition we probably don't want to leak |
293 | 312 | res = [b"" if (isinstance(r, OSError) and "not satisfiable" in str(r)) else r for r in res] |
294 | 313 | for r in res: |
|
0 commit comments