Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,192 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
**kwargs,
)

@classmethod
def parse( # noqa: D102
cls,
data: ArrayLike | DataArray | DaskArray,
dims: Sequence[str] | None = None,
c_coords: str | list[str] | None = None,
transformations: MappingToCoordinateSystem_t | None = None,
scale_factors: ScaleFactors_t | None = None,
method: Methods | None = None,
chunks: Chunks_t | None = None,
delay_z_scaling: bool = False,
**kwargs: Any,
) -> DataArray | DataTree:
r"""
Validate (or parse) 3D image data.

Parameters
----------
data
Data to validate (or parse). The shape of the data should be czyx for 3D images.
dims
Dimensions of the data (e.g. ['c', 'z', 'y', 'x'] for 3D image data). If the data is a
:class:`xarray.DataArray`, the dimensions can also be inferred from the data. If the dimensions
are not in the order czyx, the data will be transposed to match the order.
c_coords : str | list[str] | None
Channel names of image data. Must be equal to the length of dimension 'c'.
transformations
Dictionary of transformations to apply to the data. The key is the name of the target coordinate
system, the value is the transformation to apply. By default, a single `Identity` transformation
mapping to the `"global"` coordinate system is applied.
scale_factors
Scale factors to apply to construct a multiscale image (:class:`datatree.DataTree`).
If `None`, a :class:`xarray.DataArray` is returned instead.
Importantly, each scale factor is relative to the previous scale factor. For example, if the
scale factors are `[2, 2, 2]`, the returned multiscale image will have 4 scales. The original
image and then the 2x, 4x and 8x downsampled images.
method
Method to use for multiscale downsampling (default is `'nearest'`). Please refer to
:class:`multiscale_spatial_image.to_multiscale` for details.
chunks
Chunks to use for dask array.
delay_z_scaling : bool
If True, delay scaling along the Z dimension. When enabled, only X and Y dimensions are scaled
by a factor of 2 until min(X_size, Y_size) becomes less than the original Z_size. Only then
does Z dimension scaling begin. This is useful for 3D image visualization where you want to
preserve Z resolution longer. When this option is used, the `scale_factors` parameter is ignored
and computed automatically. Default is False.
kwargs
Additional arguments for :func:`to_spatial_image`. In particular the `c_coords` kwargs argument
(an iterable) can be used to set the channel coordinates for image data.

Returns
-------
:class:`xarray.DataArray` or :class:`datatree.DataTree`

Notes
-----
When `delay_z_scaling` is enabled, the method automatically computes scale factors that scale only
X and Y dimensions by 2 until min(X_size, Y_size) < original_Z_size, then scales all dimensions
X, Y, Z by 2 for subsequent levels.
"""
if delay_z_scaling:
if scale_factors is not None:
import warnings
warnings.warn(
"When delay_z_scaling=True, the scale_factors parameter is ignored and computed automatically.",
UserWarning,
stacklevel=2
)

# Compute scale factors for delayed Z scaling
scale_factors = cls._compute_delayed_z_scale_factors(data, dims)

return super().parse(
data=data,
dims=dims,
c_coords=c_coords,
transformations=transformations,
scale_factors=scale_factors,
method=method,
chunks=chunks,
**kwargs,
)

@classmethod
def _compute_delayed_z_scale_factors(
cls,
data: ArrayLike | DataArray | DaskArray,
dims: Sequence[str] | None = None,
) -> ScaleFactors_t:
"""
Compute scale factors for delayed Z dimension scaling.

Scale only X and Y dimensions by 2 until min(X_size, Y_size) < original_Z_size,
then scale all dimensions X, Y, Z by 2.

Parameters
----------
data
The input data to analyze for computing scale factors.
dims
Dimensions of the data. If None, uses the model's default dims.

Returns
-------
List of scale factor dictionaries for each scale level.
"""
# Get data shape
if isinstance(data, DataArray):
data_dims = data.dims if dims is None else dims
shape = data.shape
else:
# For numpy arrays or dask arrays
if dims is None:
data_dims = cls.dims.dims
else:
data_dims = dims
shape = data.shape

# Create mapping from dimension name to size
dim_to_size = dict(zip(data_dims, shape))

# Get original Z size
original_z_size = dim_to_size.get('z', dim_to_size.get(Z, 0))
if original_z_size == 0:
raise ValueError("Z dimension not found in data")

# Initialize current sizes
current_x = dim_to_size.get('x', dim_to_size.get(X, 0))
current_y = dim_to_size.get('y', dim_to_size.get(Y, 0))
current_z = original_z_size

if current_x == 0 or current_y == 0:
raise ValueError("X or Y dimension not found in data")

scale_factors = []

# Phase 1: Scale only X and Y until min(X_next, Y_next) < original_Z
while True:
next_x = current_x // 2
next_y = current_y // 2

# Check if we should stop scaling only X,Y
if min(next_x, next_y) < original_z_size:
break

# Scale factors for this level - only scale X and Y by 2
scale_factor_dict = {}
for dim in data_dims:
if dim in ('x', X):
scale_factor_dict[dim] = 2
elif dim in ('y', Y):
scale_factor_dict[dim] = 2
else:
scale_factor_dict[dim] = 1 # No scaling for C and Z

scale_factors.append(scale_factor_dict)

# Update current sizes
current_x = next_x
current_y = next_y

# Phase 2: Scale all dimensions (X, Y, Z) by 2 for subsequent levels
# Continue until the image becomes very small
max_levels = 10 # Reasonable limit to prevent infinite loops
level_count = len(scale_factors)

while level_count < max_levels and min(current_x // 2, current_y // 2) >= 1:
# Scale factors for this level - scale all dimensions by 2
scale_factor_dict = {}
for dim in data_dims:
if dim in ('c', C):
scale_factor_dict[dim] = 1 # Never scale channels
else:
scale_factor_dict[dim] = 2 # Scale Z, Y, X

scale_factors.append(scale_factor_dict)

# Update current sizes
current_x //= 2
current_y //= 2
current_z //= 2
level_count += 1

return scale_factors


class ShapesModel:
GEOMETRY_KEY = "geometry"
Expand Down
99 changes: 99 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,3 +782,102 @@ def test_warning_on_large_chunks():
assert len(w) == 1, "Warning should be raised for large chunk size"
assert issubclass(w[-1].category, UserWarning)
assert "Detected chunks larger than:" in str(w[-1].message)


def test_image3d_delayed_z_scaling():
"""Test delayed Z dimension scaling for 3D images."""
# Create test data similar to the issue example
# Issue shows TCZYX [1, 12, 194, 3181, 4045]
# For Image3DModel (CZYX), we use proportionally smaller data: [2, 19, 32, 40]
# where Z=19, Y=32, X=40, and we expect Z to be preserved until min(X,Y) < 19
data = np.random.random((2, 19, 32, 40)).astype(np.float32)

# Test that delay_z_scaling=False works as normal (backward compatibility)
standard_result = Image3DModel.parse(data, delay_z_scaling=False)
assert isinstance(standard_result, DataArray)

# Test with delay_z_scaling=True and verify the computed scale factors
delayed_result = Image3DModel.parse(data, delay_z_scaling=True)
assert isinstance(delayed_result, DataTree)

# Check that we have multiple scales
scale_keys = list(delayed_result.keys())
assert len(scale_keys) > 1
assert all(key.startswith("scale") for key in scale_keys)

# Verify the shapes follow the expected pattern
original_z = 19 # Z dimension from our test data
scales = []
for scale_key in sorted(scale_keys):
scale_data = delayed_result[scale_key]["image"]
scales.append(scale_data.shape)

# Check that first few scales preserve Z dimension
# Original: (2, 19, 32, 40) where C=2, Z=19, Y=32, X=40
# Scale 1: (2, 19, 16, 20) - only X,Y scaled by 2
# Scale 2: (2, 19, 8, 10) - only X,Y scaled by 2
# Eventually Z should start scaling when min(X,Y) < original_Z=19

assert scales[0] == (2, 19, 32, 40) # Original scale

# Find where Z scaling starts
z_scaling_started = False
for i in range(1, len(scales)):
prev_shape = scales[i-1]
curr_shape = scales[i]

# Check if Z dimension was scaled
if curr_shape[1] < prev_shape[1]:
z_scaling_started = True
# Once Z scaling starts, min(X,Y) should be < original_Z
min_xy = min(curr_shape[2], curr_shape[3])
assert min_xy < original_z, f"Z scaling started too early at scale {i}"
break

# Verify that before Z scaling starts, only X and Y are scaled
for i in range(1, len(scales)):
prev_shape = scales[i-1]
curr_shape = scales[i]

if curr_shape[1] == prev_shape[1]: # Z not scaled yet
# X and Y should be scaled by 2
assert curr_shape[2] <= prev_shape[2] // 2 + 1 # Account for rounding
assert curr_shape[3] <= prev_shape[3] // 2 + 1
# C should never change
assert curr_shape[0] == prev_shape[0]
else:
# Z scaling has started
break


def test_image3d_delayed_z_scaling_with_warning():
"""Test that providing scale_factors with delay_z_scaling=True produces a warning."""
data = np.random.random((2, 19, 32, 40)).astype(np.float32)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = Image3DModel.parse(data, scale_factors=[2, 2], delay_z_scaling=True)

assert len(w) == 1
assert issubclass(w[0].category, UserWarning)
assert "scale_factors parameter is ignored" in str(w[0].message)

# Should still work and return a DataTree
assert isinstance(result, DataTree)


def test_image3d_delayed_z_scaling_edge_cases():
"""Test edge cases for delayed Z scaling."""
# Test with very small image where Z is already larger than X,Y
small_data = np.random.random((1, 20, 8, 6)).astype(np.float32) # Z=20, Y=8, X=6

result = Image3DModel.parse(small_data, delay_z_scaling=True)

# Should still work and create multiscale
assert isinstance(result, DataTree)

# Test with invalid data (missing dimensions)
with pytest.raises(ValueError, match="dimension not found"):
# Create 2D data and try to use it with Image3DModel
invalid_data = np.random.random((10, 10))
Image3DModel.parse(invalid_data, delay_z_scaling=True)
Loading