diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py
index 7aeb0b2c0..22e7fc068 100644
--- a/src/spatialdata/models/models.py
+++ b/src/spatialdata/models/models.py
@@ -401,6 +401,192 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             **kwargs,
         )
 
+    @classmethod
+    def parse(  # noqa: D102
+        cls,
+        data: ArrayLike | DataArray | DaskArray,
+        dims: Sequence[str] | None = None,
+        c_coords: str | list[str] | None = None,
+        transformations: MappingToCoordinateSystem_t | None = None,
+        scale_factors: ScaleFactors_t | None = None,
+        method: Methods | None = None,
+        chunks: Chunks_t | None = None,
+        delay_z_scaling: bool = False,
+        **kwargs: Any,
+    ) -> DataArray | DataTree:
+        r"""
+        Validate (or parse) 3D image data.
+
+        Parameters
+        ----------
+        data
+            Data to validate (or parse). The shape of the data should be czyx for 3D images.
+        dims
+            Dimensions of the data (e.g. ``['c', 'z', 'y', 'x']`` for 3D image data). If the data is a
+            :class:`xarray.DataArray`, the dimensions can also be inferred from the data. If the dimensions
+            are not in the order czyx, the data will be transposed to match that order.
+        c_coords
+            Channel names of the image data. The length must equal the size of dimension 'c'.
+        transformations
+            Dictionary of transformations to apply to the data. The key is the name of the target coordinate
+            system, the value is the transformation to apply. By default, a single `Identity` transformation
+            mapping to the `"global"` coordinate system is applied.
+        scale_factors
+            Scale factors to apply to construct a multiscale image (:class:`datatree.DataTree`).
+            If `None`, a :class:`xarray.DataArray` is returned instead.
+            Importantly, each scale factor is relative to the previous one. For example, if the
+            scale factors are `[2, 2, 2]`, the returned multiscale image will have 4 scales: the original
+            image and the 2x, 4x and 8x downsampled images.
+        method
+            Method to use for multiscale downsampling (default is `'nearest'`). Please refer to
+            :class:`multiscale_spatial_image.to_multiscale` for details.
+        chunks
+            Chunks to use for the dask array.
+        delay_z_scaling
+            If `True`, delay scaling along the Z dimension: only the X and Y dimensions are scaled by a
+            factor of 2 until min(X_size, Y_size) becomes less than the original Z_size; only then does
+            Z dimension scaling begin. This is useful for 3D image visualization where the Z resolution
+            should be preserved for longer. When this option is used, the `scale_factors` parameter is
+            ignored and computed automatically. Default is `False`.
+        kwargs
+            Additional arguments for :func:`to_spatial_image`.
+
+        Returns
+        -------
+        :class:`xarray.DataArray` or :class:`datatree.DataTree`
+
+        Notes
+        -----
+        When `delay_z_scaling` is enabled, the method automatically computes scale factors that scale only
+        the X and Y dimensions by 2 until min(X_size, Y_size) < original_Z_size, then scale all of X, Y
+        and Z by 2 for subsequent levels.
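+
+        Examples
+        --------
+        For czyx data of shape ``(2, 19, 32, 40)`` (``z=19``, the shape used in the tests below), the
+        first level halves only ``y`` and ``x``, giving ``(2, 19, 16, 20)``; from the second level on,
+        ``min(y, x) < 19``, so ``z`` is halved as well.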
+ """ + if delay_z_scaling: + if scale_factors is not None: + import warnings + warnings.warn( + "When delay_z_scaling=True, the scale_factors parameter is ignored and computed automatically.", + UserWarning, + stacklevel=2 + ) + + # Compute scale factors for delayed Z scaling + scale_factors = cls._compute_delayed_z_scale_factors(data, dims) + + return super().parse( + data=data, + dims=dims, + c_coords=c_coords, + transformations=transformations, + scale_factors=scale_factors, + method=method, + chunks=chunks, + **kwargs, + ) + + @classmethod + def _compute_delayed_z_scale_factors( + cls, + data: ArrayLike | DataArray | DaskArray, + dims: Sequence[str] | None = None, + ) -> ScaleFactors_t: + """ + Compute scale factors for delayed Z dimension scaling. + + Scale only X and Y dimensions by 2 until min(X_size, Y_size) < original_Z_size, + then scale all dimensions X, Y, Z by 2. + + Parameters + ---------- + data + The input data to analyze for computing scale factors. + dims + Dimensions of the data. If None, uses the model's default dims. + + Returns + ------- + List of scale factor dictionaries for each scale level. + """ + # Get data shape + if isinstance(data, DataArray): + data_dims = data.dims if dims is None else dims + shape = data.shape + else: + # For numpy arrays or dask arrays + if dims is None: + data_dims = cls.dims.dims + else: + data_dims = dims + shape = data.shape + + # Create mapping from dimension name to size + dim_to_size = dict(zip(data_dims, shape)) + + # Get original Z size + original_z_size = dim_to_size.get('z', dim_to_size.get(Z, 0)) + if original_z_size == 0: + raise ValueError("Z dimension not found in data") + + # Initialize current sizes + current_x = dim_to_size.get('x', dim_to_size.get(X, 0)) + current_y = dim_to_size.get('y', dim_to_size.get(Y, 0)) + current_z = original_z_size + + if current_x == 0 or current_y == 0: + raise ValueError("X or Y dimension not found in data") + + scale_factors = [] + + # Phase 1: Scale only X and Y until min(X_next, Y_next) < original_Z + while True: + next_x = current_x // 2 + next_y = current_y // 2 + + # Check if we should stop scaling only X,Y + if min(next_x, next_y) < original_z_size: + break + + # Scale factors for this level - only scale X and Y by 2 + scale_factor_dict = {} + for dim in data_dims: + if dim in ('x', X): + scale_factor_dict[dim] = 2 + elif dim in ('y', Y): + scale_factor_dict[dim] = 2 + else: + scale_factor_dict[dim] = 1 # No scaling for C and Z + + scale_factors.append(scale_factor_dict) + + # Update current sizes + current_x = next_x + current_y = next_y + + # Phase 2: Scale all dimensions (X, Y, Z) by 2 for subsequent levels + # Continue until the image becomes very small + max_levels = 10 # Reasonable limit to prevent infinite loops + level_count = len(scale_factors) + + while level_count < max_levels and min(current_x // 2, current_y // 2) >= 1: + # Scale factors for this level - scale all dimensions by 2 + scale_factor_dict = {} + for dim in data_dims: + if dim in ('c', C): + scale_factor_dict[dim] = 1 # Never scale channels + else: + scale_factor_dict[dim] = 2 # Scale Z, Y, X + + scale_factors.append(scale_factor_dict) + + # Update current sizes + current_x //= 2 + current_y //= 2 + current_z //= 2 + level_count += 1 + + return scale_factors + class ShapesModel: GEOMETRY_KEY = "geometry" diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e08268a5e..f7dead90f 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -782,3 +782,102 @@ def 
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index e08268a5e..f7dead90f 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -782,3 +782,102 @@ def test_warning_on_large_chunks():
     assert len(w) == 1, "Warning should be raised for large chunk size"
     assert issubclass(w[-1].category, UserWarning)
     assert "Detected chunks larger than:" in str(w[-1].message)
+
+
+def test_image3d_delayed_z_scaling():
+    """Test delayed Z dimension scaling for 3D images."""
+    # Create test data similar to the issue example, which shows TCZYX [1, 12, 194, 3181, 4045].
+    # For Image3DModel (CZYX) we use proportionally smaller data: [2, 19, 32, 40],
+    # i.e. Z=19, Y=32, X=40; we expect Z to be preserved until min(X, Y) < 19.
+    data = np.random.random((2, 19, 32, 40)).astype(np.float32)
+
+    # delay_z_scaling=False behaves as before (backward compatibility)
+    standard_result = Image3DModel.parse(data, delay_z_scaling=False)
+    assert isinstance(standard_result, DataArray)
+
+    # delay_z_scaling=True returns a multiscale image with computed scale factors
+    delayed_result = Image3DModel.parse(data, delay_z_scaling=True)
+    assert isinstance(delayed_result, DataTree)
+
+    # Check that we have multiple scales
+    scale_keys = list(delayed_result.keys())
+    assert len(scale_keys) > 1
+    assert all(key.startswith("scale") for key in scale_keys)
+
+    # Collect the shapes in scale order (sort numerically, since lexicographic
+    # sorting would place "scale10" before "scale2")
+    original_z = 19  # Z dimension of the test data
+    scales = []
+    for scale_key in sorted(scale_keys, key=lambda k: int(k.removeprefix("scale"))):
+        scale_data = delayed_result[scale_key]["image"]
+        scales.append(scale_data.shape)
+
+    # Expected pattern, with original shape (2, 19, 32, 40) where C=2, Z=19, Y=32, X=40:
+    # Scale 1: (2, 19, 16, 20) - only X and Y scaled by 2, since min(X, Y) >= 19
+    # Scale 2 onwards: min(X, Y) < 19, so Z is scaled by 2 as well
+    assert scales[0] == (2, 19, 32, 40)  # Original scale
+
+    # Find where Z scaling starts
+    z_scaling_started = False
+    for i in range(1, len(scales)):
+        prev_shape = scales[i - 1]
+        curr_shape = scales[i]
+
+        # Check whether the Z dimension was scaled
+        if curr_shape[1] < prev_shape[1]:
+            z_scaling_started = True
+            # Once Z scaling starts, min(X, Y) should be < original_Z
+            min_xy = min(curr_shape[2], curr_shape[3])
+            assert min_xy < original_z, f"Z scaling started too early at scale {i}"
+            break
+    assert z_scaling_started, "Z scaling should eventually start"
+
+    # Verify that before Z scaling starts, only X and Y are scaled
+    for i in range(1, len(scales)):
+        prev_shape = scales[i - 1]
+        curr_shape = scales[i]
+
+        if curr_shape[1] == prev_shape[1]:  # Z not scaled yet
+            # X and Y should be scaled by 2
+            assert curr_shape[2] <= prev_shape[2] // 2 + 1  # account for rounding
+            assert curr_shape[3] <= prev_shape[3] // 2 + 1
+            # C should never change
+            assert curr_shape[0] == prev_shape[0]
+        else:
+            # Z scaling has started
+            break
+
+
+def test_image3d_delayed_z_scaling_with_warning():
+    """Test that providing scale_factors with delay_z_scaling=True produces a warning."""
+    data = np.random.random((2, 19, 32, 40)).astype(np.float32)
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        result = Image3DModel.parse(data, scale_factors=[2, 2], delay_z_scaling=True)
+
+    assert len(w) == 1
+    assert issubclass(w[0].category, UserWarning)
+    assert "scale_factors parameter is ignored" in str(w[0].message)
+
+    # Should still work and return a DataTree
+    assert isinstance(result, DataTree)
+
+
+def test_image3d_delayed_z_scaling_edge_cases():
+    """Test edge cases for delayed Z scaling."""
+    # Very small image where Z is already larger than X and Y
+    small_data = np.random.random((1, 20, 8, 6)).astype(np.float32)  # Z=20, Y=8, X=6
+
+    result = Image3DModel.parse(small_data, delay_z_scaling=True)
+
+    # Should still work and create a multiscale image
+    assert isinstance(result, DataTree)
+
+    # Invalid data (missing spatial dimensions) should raise
+    with pytest.raises(ValueError, match="dimension not found"):
+        # 2D data cannot be parsed by Image3DModel
+        invalid_data = np.random.random((10, 10))
+        Image3DModel.parse(invalid_data, delay_z_scaling=True)
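
A possible usage sketch once the patch is applied (illustrative, not part of the patch; the
synthetic volume matches the shape used in the tests):

    import numpy as np
    from spatialdata.models import Image3DModel

    # Synthetic CZYX volume: 2 channels, Z=19, Y=32, X=40
    volume = np.random.default_rng(0).random((2, 19, 32, 40), dtype=np.float32)

    msi = Image3DModel.parse(volume, delay_z_scaling=True)

    # The returned DataTree has one group per scale ("scale0", "scale1", ...)
    for name in msi.keys():
        print(name, msi[name]["image"].shape)
    # Z stays at 19 for scale1 and only starts halving once min(Y, X) < 19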