Skip to content

Commit 707981a

Browse files
authored
allow any object with render method to be used for label distribution (#83)
* allow any object with render method to be used for label distribution * make base public
1 parent 9669067 commit 707981a

8 files changed

Lines changed: 74 additions & 21 deletions

File tree

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
from ._distributions._base import BaseDistribution
12
from ._distributions.cosem import CosemLabel
23
from ._distributions.matslines import MatsLines
34
from .fluorophore import Fluorophore
4-
from .sample import Distribution, FluorophoreDistribution, Sample
5+
from .sample import AnyDistribution, FluorophoreDistribution, Sample
56

67
__all__ = [
7-
"MatsLines",
8-
"Sample",
8+
"AnyDistribution",
9+
"BaseDistribution",
910
"CosemLabel",
10-
"Distribution",
11-
"FluorophoreDistribution",
1211
"Fluorophore",
12+
"FluorophoreDistribution",
13+
"MatsLines",
14+
"Sample",
1315
]

src/microsim/schema/sample/_distributions/_base.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,43 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Annotated, Any, Protocol
55

6-
from pydantic import BaseModel
6+
from pydantic import BaseModel, GetCoreSchemaHandler
7+
from pydantic_core import core_schema
8+
from typing_extensions import runtime_checkable
79

810
if TYPE_CHECKING:
911
from microsim._data_array import xrDataArray
1012
from microsim.schema.backend import NumpyAPI
1113

1214

13-
class _BaseDistribution(BaseModel, ABC):
15+
@runtime_checkable
16+
class Renderable(Protocol):
17+
@abstractmethod
18+
def render(self, space: xrDataArray, xp: NumpyAPI | None = None) -> xrDataArray:
19+
"""Render the distribution into the given space."""
20+
21+
22+
class _IsInstanceAnySer:
23+
@classmethod
24+
def __get_pydantic_core_schema__(
25+
cls, source_type: Any, handler: GetCoreSchemaHandler
26+
) -> core_schema.CoreSchema:
27+
def _validate(obj: Any) -> Any:
28+
if not isinstance(obj, source_type):
29+
raise ValueError(f"Expected {source_type}, got {type(obj)}")
30+
return obj
31+
32+
return core_schema.no_info_before_validator_function(
33+
_validate, core_schema.any_schema()
34+
)
35+
36+
37+
RenderableType = Annotated[Renderable, _IsInstanceAnySer()]
38+
39+
40+
class BaseDistribution(BaseModel, ABC):
1441
@classmethod
1542
def is_random(cls) -> bool:
1643
"""Return True if this distribution generates randomized results."""

src/microsim/schema/sample/_distributions/cosem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from microsim.cosem.models import CosemDataset, CosemImage
88
from microsim.schema.backend import NumpyAPI
99

10-
from ._base import _BaseDistribution
10+
from ._base import BaseDistribution
1111

1212
if TYPE_CHECKING:
1313
from microsim._data_array import xrDataArray
@@ -25,7 +25,7 @@ def _validate_dataset(v: Any) -> CosemDataset:
2525
Dataset = Annotated[CosemDataset, BeforeValidator(_validate_dataset)]
2626

2727

28-
class CosemLabel(_BaseDistribution):
28+
class CosemLabel(BaseDistribution):
2929
"""Renders ground truth based on a specific layer from a COSEM dataset.
3030
3131
Go to https://openorganelle.janelia.org/datasets/ to find a dataset, and then

src/microsim/schema/sample/_distributions/direct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
from microsim.schema.backend import NumpyAPI
66

7-
from ._base import _BaseDistribution
7+
from ._base import BaseDistribution
88

99
if TYPE_CHECKING:
1010
from microsim._data_array import xrDataArray
1111

1212

13-
class FixedArrayTruth(_BaseDistribution):
13+
class FixedArrayTruth(BaseDistribution):
1414
type: Literal["fixed-array"] = "fixed-array"
1515
array: Any
1616

src/microsim/schema/sample/_distributions/matslines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import numpy as np
66

77
from microsim.schema.backend import NumpyAPI
8-
from microsim.schema.sample._distributions._base import _BaseDistribution
8+
from microsim.schema.sample._distributions._base import BaseDistribution
99

1010
if TYPE_CHECKING:
1111
import numpy.typing as npt
1212

1313
from microsim._data_array import xrDataArray
1414

1515

16-
class MatsLines(_BaseDistribution):
16+
class MatsLines(BaseDistribution):
1717
type: Literal["matslines"] = "matslines"
1818
density: float = 1
1919
length: int = 10

src/microsim/schema/sample/sample.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable
2-
from typing import Any, get_args
2+
from typing import Any
33

44
import numpy as np
55
from pydantic import Field, model_validator
@@ -9,14 +9,15 @@
99
from microsim.schema.backend import NumpyAPI
1010
from microsim.schema.spectrum import Spectrum
1111

12+
from ._distributions._base import Renderable, RenderableType
1213
from ._distributions.cosem import CosemLabel
1314
from ._distributions.direct import FixedArrayTruth
1415
from ._distributions.matslines import MatsLines
1516
from .fluorophore import Fluorophore
1617

17-
Distribution = MatsLines | CosemLabel | FixedArrayTruth
18-
DistributionTypes = get_args(Distribution)
19-
18+
AnyDistribution = MatsLines | CosemLabel | FixedArrayTruth | RenderableType
19+
# TODO: this feels like an unDRY hack
20+
DistributionTypes = MatsLines | CosemLabel | FixedArrayTruth | Renderable
2021

2122
# This is a placeholder fluorophore for when no fluorophore is specified
2223
# it has broad excitation and emission spectra, high extinction coefficient.
@@ -34,7 +35,7 @@
3435

3536

3637
class FluorophoreDistribution(SimBaseModel):
37-
distribution: Distribution = Field(...)
38+
distribution: AnyDistribution = Field(union_mode="left_to_right")
3839
fluorophore: Fluorophore = MOCK_FLUOR
3940
# either a scalar that will be multiplied by the distribution
4041
# (e.g. to increase/decrease concentration of fluorophore)
@@ -47,7 +48,7 @@ def __hash__(self) -> int:
4748

4849
def cache_path(self) -> tuple[str, ...] | None:
4950
if hasattr(self.distribution, "cache_path"):
50-
return self.distribution.cache_path()
51+
return self.distribution.cache_path() # type: ignore [no-any-return]
5152
return None
5253

5354
def render(self, space: xrDataArray, xp: NumpyAPI | None = None) -> xrDataArray:

src/microsim/schema/simulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def _truth_cache_path(
307307
scale = f'scale{"_".join(str(x) for x in truth_space.scale)}'
308308
conc = f"conc{label.concentration}"
309309
truth_cache = truth_cache / shape / scale / conc
310-
if label.distribution.is_random():
310+
if hasattr(label.distribution, "is_random") and label.distribution.is_random():
311311
truth_cache = truth_cache / f"seed{seed}"
312312
return truth_cache
313313

tests/test_simulation.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import pytest
6+
from pydantic import ValidationError
67

78
import microsim.schema as ms
89
from microsim.schema.optical_config.lib import FITC
@@ -120,6 +121,28 @@ def test_simulation_from_ground_truth() -> None:
120121
np.testing.assert_array_almost_equal(sim_truth, ground_truth)
121122

122123

124+
def test_simulation_custom_distribution() -> None:
125+
"""Test that we can use any class with a render method as a distribution."""
126+
127+
class BadDistribution:
128+
def have_no_render_method(self, space, xp: ms.NumpyAPI | None = None):
129+
return space
130+
131+
with pytest.raises(ValidationError):
132+
ms.Sample(labels=[BadDistribution()])
133+
134+
class GoodDistribution:
135+
def render(self, space, xp: ms.NumpyAPI | None = None):
136+
return space
137+
138+
sim = ms.Simulation(
139+
truth_space=ms.ShapeScaleSpace(shape=(64, 128, 128), scale=(0.2, 0.1, 0.1)),
140+
output_space={"downscale": 1},
141+
sample=ms.Sample(labels=[GoodDistribution()]),
142+
)
143+
sim.run()
144+
145+
123146
def test_pickle(sim1: ms.Simulation) -> None:
124147
pickled = pickle.dumps(sim1)
125148
assert pickle.loads(pickled) == sim1

0 commit comments

Comments
 (0)