+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import numpy as np
+from numcodecs.vlen import VLenBytes, VLenUTF8
+
+from zarr.abc.codec import ArrayBytesCodec
+from zarr.core.buffer import Buffer, NDBuffer
+from zarr.core.common import JSON, parse_named_configuration
+from zarr.core.strings import cast_to_string_dtype
+from zarr.registry import register_codec
+
+if TYPE_CHECKING:
+    from typing import Self
+
+    from zarr.core.array_spec import ArraySpec
+
+
+# module-level singletons are fine here because these codecs take no parameters
+_vlen_utf8_codec = VLenUTF8()
+_vlen_bytes_codec = VLenBytes()
+
+
+@dataclass(frozen=True)
+class VLenUTF8Codec(ArrayBytesCodec):
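+    """Array-to-bytes codec for variable-length UTF-8 strings, delegating to numcodecs' VLenUTF8."""
+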
+    @classmethod
+    def from_dict(cls, data: dict[str, JSON]) -> Self:
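+        # the "configuration" key is optional for this codec, hence require_configuration=False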
+        _, configuration_parsed = parse_named_configuration(
+            data, "vlen-utf8", require_configuration=False
+        )
+        configuration_parsed = configuration_parsed or {}
+        return cls(**configuration_parsed)
+
+    def to_dict(self) -> dict[str, JSON]:
+        return {"name": "vlen-utf8", "configuration": {}}
+
+    def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
+        return self
+
+    async def _decode_single(
+        self,
+        chunk_bytes: Buffer,
+        chunk_spec: ArraySpec,
+    ) -> NDBuffer:
+        assert isinstance(chunk_bytes, Buffer)
+
+        raw_bytes = chunk_bytes.as_array_like()
+        decoded = _vlen_utf8_codec.decode(raw_bytes)
+        assert decoded.dtype == np.object_
+        decoded.shape = chunk_spec.shape
+        # coming out of the codec, we know this cast is safe, so don't issue a warning
+        as_string_dtype = cast_to_string_dtype(decoded, safe=True)
+        return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype)
+
+    async def _encode_single(
+        self,
+        chunk_array: NDBuffer,
+        chunk_spec: ArraySpec,
+    ) -> Buffer | None:
+        assert isinstance(chunk_array, NDBuffer)
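+        # encode the string chunk with numcodecs' VLenUTF8 and wrap the raw bytes in a Buffer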
+        return chunk_spec.prototype.buffer.from_bytes(
+            _vlen_utf8_codec.encode(chunk_array.as_numpy_array())
+        )
+
+    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
+        # what is input_byte_length for an object dtype?
+        raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs")
+
+
+@dataclass(frozen=True)
+class VLenBytesCodec(ArrayBytesCodec):
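+    """Array-to-bytes codec for variable-length byte strings, delegating to numcodecs' VLenBytes."""
+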
+    @classmethod
+    def from_dict(cls, data: dict[str, JSON]) -> Self:
+        _, configuration_parsed = parse_named_configuration(
+            data, "vlen-bytes", require_configuration=False
+        )
+        configuration_parsed = configuration_parsed or {}
+        return cls(**configuration_parsed)
+
+    def to_dict(self) -> dict[str, JSON]:
+        return {"name": "vlen-bytes", "configuration": {}}
+
+    def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
+        return self
+
+    async def _decode_single(
+        self,
+        chunk_bytes: Buffer,
+        chunk_spec: ArraySpec,
+    ) -> NDBuffer:
+        assert isinstance(chunk_bytes, Buffer)
+
+        raw_bytes = chunk_bytes.as_array_like()
+        decoded = _vlen_bytes_codec.decode(raw_bytes)
+        assert decoded.dtype == np.object_
+        decoded.shape = chunk_spec.shape
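+        # unlike the UTF-8 codec, no string cast is needed: the object array of bytes is returned as-is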
+        return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded)
+
+    async def _encode_single(
+        self,
+        chunk_array: NDBuffer,
+        chunk_spec: ArraySpec,
+    ) -> Buffer | None:
+        assert isinstance(chunk_array, NDBuffer)
+        return chunk_spec.prototype.buffer.from_bytes(
+            _vlen_bytes_codec.encode(chunk_array.as_numpy_array())
+        )
+
+    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
+        # what is input_byte_length for an object dtype?
+        raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs")
+
+
+register_codec("vlen-utf8", VLenUTF8Codec)
+register_codec("vlen-bytes", VLenBytesCodec)