|
15 | 15 | from typing import Any |
16 | 16 |
|
17 | 17 | import arviz as az |
18 | | -import numcodecs |
19 | 18 | import numpy as np |
20 | 19 | import xarray as xr |
21 | | -import zarr |
22 | 20 |
|
23 | 21 | from arviz.data.base import make_attrs |
24 | 22 | from arviz.data.inference_data import WARMUP_TAG |
25 | | -from numcodecs.abc import Codec |
26 | 23 | from pytensor.tensor.variable import TensorVariable |
27 | 24 |
|
28 | 25 | import pymc |
|
44 | 41 | from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name |
45 | 42 |
|
46 | 43 | try: |
| 44 | + import numcodecs |
| 45 | + import zarr |
| 46 | + |
| 47 | + from numcodecs.abc import Codec |
| 48 | + from zarr import Group |
47 | 49 | from zarr.storage import BaseStore, default_compressor |
48 | 50 | from zarr.sync import Synchronizer |
49 | 51 |
|
50 | 52 | _zarr_available = True |
51 | 53 | except ImportError: |
| 54 | + from typing import TYPE_CHECKING, TypeVar |
| 55 | + |
| 56 | + if not TYPE_CHECKING: |
| 57 | + Codec = TypeVar("Codec") |
| 58 | + Group = TypeVar("Group") |
| 59 | + BaseStore = TypeVar("BaseStore") |
| 60 | + Synchronizer = TypeVar("Synchronizer") |
52 | 61 | _zarr_available = False |
53 | 62 |
|
54 | 63 |
|
@@ -243,7 +252,7 @@ def flush(self): |
243 | 252 |
|
244 | 253 | def get_initial_fill_value_and_codec( |
245 | 254 | dtype: Any, |
246 | | -) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]: |
| 255 | +) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]: |
247 | 256 | _dtype = np.dtype(dtype) |
248 | 257 | fill_value: FILL_VALUE_TYPE = None |
249 | 258 | codec = None |
@@ -366,27 +375,27 @@ def groups(self) -> list[str]: |
366 | 375 | return [str(group_name) for group_name, _ in self.root.groups()] |
367 | 376 |
|
368 | 377 | @property |
369 | | - def posterior(self) -> zarr.Group: |
| 378 | + def posterior(self) -> Group: |
370 | 379 | return self.root.posterior |
371 | 380 |
|
372 | 381 | @property |
373 | | - def unconstrained_posterior(self) -> zarr.Group: |
| 382 | + def unconstrained_posterior(self) -> Group: |
374 | 383 | return self.root.unconstrained_posterior |
375 | 384 |
|
376 | 385 | @property |
377 | | - def sample_stats(self) -> zarr.Group: |
| 386 | + def sample_stats(self) -> Group: |
378 | 387 | return self.root.sample_stats |
379 | 388 |
|
380 | 389 | @property |
381 | | - def constant_data(self) -> zarr.Group: |
| 390 | + def constant_data(self) -> Group: |
382 | 391 | return self.root.constant_data |
383 | 392 |
|
384 | 393 | @property |
385 | | - def observed_data(self) -> zarr.Group: |
| 394 | + def observed_data(self) -> Group: |
386 | 395 | return self.root.observed_data |
387 | 396 |
|
388 | 397 | @property |
389 | | - def _sampling_state(self) -> zarr.Group: |
| 398 | + def _sampling_state(self) -> Group: |
390 | 399 | return self.root._sampling_state |
391 | 400 |
|
392 | 401 | def init_trace( |
@@ -646,12 +655,12 @@ def init_sampling_state_group(self, tune: int, chains: int): |
646 | 655 |
|
647 | 656 | def init_group_with_empty( |
648 | 657 | self, |
649 | | - group: zarr.Group, |
| 658 | + group: Group, |
650 | 659 | var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]], |
651 | 660 | chains: int, |
652 | 661 | draws: int, |
653 | 662 | extra_var_attrs: dict | None = None, |
654 | | - ) -> zarr.Group: |
| 663 | + ) -> Group: |
655 | 664 | group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)} |
656 | 665 | for name, (_dtype, shape) in var_dtype_and_shape.items(): |
657 | 666 | fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype) |
@@ -689,8 +698,8 @@ def init_group_with_empty( |
689 | 698 | array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) |
690 | 699 | return group |
691 | 700 |
|
692 | | - def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None: |
693 | | - group: zarr.Group | None = None |
| 701 | + def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None: |
| 702 | + group: Group | None = None |
694 | 703 | if data_dict: |
695 | 704 | group_coords = {} |
696 | 705 | group = self.root.create_group(name=name, overwrite=True) |
|
0 commit comments