Skip to content

Commit 8323e15

Browse files
authored
Fix labels multiscales method (#697)
* Add test case for generating labels pyramids * Override default scaling method for labels
1 parent 364d2d3 commit 8323e15

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

src/spatialdata/models/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
255255
**kwargs,
256256
)
257257

258+
@classmethod
259+
def parse( # noqa: D102
260+
self,
261+
*args: Any,
262+
**kwargs: Any,
263+
) -> DataArray | DataTree:
264+
if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
265+
# Override default scaling method to preserve labels
266+
kwargs["method"] = Methods.DASK_IMAGE_NEAREST
267+
return super().parse(*args, **kwargs)
268+
258269

259270
class Labels3DModel(RasterSchema):
260271
dims = DimsSchema((Z, Y, X))
@@ -270,6 +281,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
270281
**kwargs,
271282
)
272283

284+
@classmethod
285+
def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree: # noqa: D102
286+
if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
287+
# Override default scaling method to preserve labels
288+
kwargs["method"] = Methods.DASK_IMAGE_NEAREST
289+
return super().parse(*args, **kwargs)
290+
273291

274292
class Image2DModel(RasterSchema):
275293
dims = DimsSchema((C, Y, X))

tests/models/test_models.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,27 @@ def test_raster_schema(
195195
with pytest.raises(ValueError):
196196
model.parse(image, **kwargs)
197197

198+
@pytest.mark.parametrize("model", [Labels2DModel, Labels3DModel])
199+
def test_labels_model_with_multiscales(self, model):
200+
# Passing "scale_factors" should generate multiscales with a "method" appropriate for labels
201+
dims = np.array(model.dims.dims).tolist()
202+
n_dims = len(dims)
203+
204+
# A labels image with one label value 4, that partially covers 2×2 blocks.
205+
# Downsampling with interpolation would produce values 1, 2, 3, 4.
206+
image: ArrayLike = np.array([[0, 0, 0, 0], [0, 4, 4, 4], [4, 4, 4, 4], [0, 4, 4, 4]], dtype=np.uint16)
207+
if n_dims == 3:
208+
image = np.stack([image] * image.shape[0])
209+
actual = model.parse(image, scale_factors=(2,))
210+
assert isinstance(actual, DataTree)
211+
assert actual.children.keys() == {"scale0", "scale1"}
212+
assert actual.scale0.image.dtype == image.dtype
213+
assert actual.scale1.image.dtype == image.dtype
214+
assert set(np.unique(image)) == set(np.unique(actual.scale0.image)), "Scale0 should be preserved"
215+
assert set(np.unique(image)) >= set(
216+
np.unique(actual.scale1.image)
217+
), "Subsequent scales should not have interpolation artifacts"
218+
198219
@pytest.mark.parametrize("model", [ShapesModel])
199220
@pytest.mark.parametrize("path", [POLYGON_PATH, MULTIPOLYGON_PATH, POINT_PATH])
200221
def test_shapes_model(self, model: ShapesModel, path: Path) -> None:

0 commit comments

Comments
 (0)