Skip to content

Commit c9b1d85

Browse files
authored
(fix): make pipeline pickleable (#67)
* (fix): make pipeline pickleable * (fix): add type to helper * (fix): proper type * (fix): dataclasses are not frozen by default * (fix): format * (chore): update docs
1 parent 4ef3899 commit c9b1d85

File tree

3 files changed

+48
-18
lines changed

3 files changed

+48
-18
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ zarr.config.set({
6666
})
6767
```
6868

69+
If the `ZarrsCodecPipeline` is pickled, and then un-pickled, and during that time one of `store_empty_chunks`, `chunk_concurrent_minimum`, `chunk_concurrent_maximum`, or `num_threads` has changed, the newly un-pickled version will pick up the new value. However, once a `ZarrsCodecPipeline` object has been instantiated, these values are then fixed. This may change in the future as guidance from the `zarr` community becomes clear.
70+
6971
## Concurrency
7072

7173
Concurrency can be classified into two types:

python/zarrs/pipeline.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import json
55
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING, TypedDict
77

88
import numpy as np
99
from zarr.abc.codec import (
@@ -14,7 +14,7 @@
1414

1515
if TYPE_CHECKING:
1616
from collections.abc import Iterable, Iterator
17-
from typing import Self
17+
from typing import Any, Self
1818

1919
from zarr.abc.store import ByteGetter, ByteSetter
2020
from zarr.core.array_spec import ArraySpec
@@ -32,10 +32,40 @@
3232
)
3333

3434

35-
@dataclass(frozen=True)
35+
def get_codec_pipeline_impl(codec_metadata_json: str) -> CodecPipelineImpl:
36+
return CodecPipelineImpl(
37+
codec_metadata_json,
38+
validate_checksums=config.get("codec_pipeline.validate_checksums", None),
39+
# TODO: upstream zarr-python array.write_empty_chunks is not merged yet #2429
40+
store_empty_chunks=config.get("array.write_empty_chunks", None),
41+
chunk_concurrent_minimum=config.get(
42+
"codec_pipeline.chunk_concurrent_minimum", None
43+
),
44+
chunk_concurrent_maximum=config.get(
45+
"codec_pipeline.chunk_concurrent_maximum", None
46+
),
47+
num_threads=config.get("threading.max_workers", None),
48+
)
49+
50+
51+
class ZarrsCodecPipelineState(TypedDict):
52+
codec_metadata_json: str
53+
codecs: tuple[Codec, ...]
54+
55+
56+
@dataclass
3657
class ZarrsCodecPipeline(CodecPipeline):
3758
codecs: tuple[Codec, ...]
3859
impl: CodecPipelineImpl
60+
codec_metadata_json: str
61+
62+
def __getstate__(self) -> ZarrsCodecPipelineState:
63+
return {"codec_metadata_json": self.codec_metadata_json, "codecs": self.codecs}
64+
65+
def __setstate__(self, state: ZarrsCodecPipelineState):
66+
self.codecs = state["codecs"]
67+
self.codec_metadata_json = state["codec_metadata_json"]
68+
self.impl = get_codec_pipeline_impl(self.codec_metadata_json)
3969

4070
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
4171
raise NotImplementedError("evolve_from_array_spec")
@@ -49,22 +79,9 @@ def from_codecs(cls, codecs: Iterable[Codec]) -> Self:
4979
# https://github.com/zarr-developers/zarr-python/issues/2409
5080
# https://github.com/zarr-developers/zarr-python/pull/2429
5181
return cls(
82+
codec_metadata_json=codec_metadata_json,
5283
codecs=tuple(codecs),
53-
impl=CodecPipelineImpl(
54-
codec_metadata_json,
55-
validate_checksums=config.get(
56-
"codec_pipeline.validate_checksums", None
57-
),
58-
# TODO: upstream zarr-python array.write_empty_chunks is not merged yet #2429
59-
store_empty_chunks=config.get("array.write_empty_chunks", None),
60-
chunk_concurrent_minimum=config.get(
61-
"codec_pipeline.chunk_concurrent_minimum", None
62-
),
63-
chunk_concurrent_maximum=config.get(
64-
"codec_pipeline.chunk_concurrent_maximum", None
65-
),
66-
num_threads=config.get("threading.max_workers", None),
67-
),
84+
impl=get_codec_pipeline_impl(codec_metadata_json),
6885
)
6986

7087
@property

tests/test_pipeline.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
import operator
4+
import pickle
45
import tempfile
56
from collections.abc import Callable
67
from contextlib import contextmanager
@@ -229,3 +230,13 @@ def test_ellipsis_indexing_invalid(arr: zarr.Array):
229230
# zarrs-python error: ValueError: operands could not be broadcast together with shapes (4,) (3,)
230231
# numpy error: ValueError: could not broadcast input array from shape (3,) into shape (4,)
231232
arr[2, ...] = stored_value
233+
234+
235+
def test_pickle(arr: zarr.Array, tmp_path: Path):
236+
arr[:] = np.arange(reduce(operator.mul, arr.shape, 1)).reshape(arr.shape)
237+
expected = arr[:]
238+
with Path.open(tmp_path / "arr.pickle", "wb") as f:
239+
pickle.dump(arr._async_array.codec_pipeline, f)
240+
with Path.open(tmp_path / "arr.pickle", "rb") as f:
241+
object.__setattr__(arr._async_array, "codec_pipeline", pickle.load(f))
242+
assert (arr[:] == expected).all()

0 commit comments

Comments
 (0)