Skip to content
47 changes: 37 additions & 10 deletions src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import SpatialElement, TableModel, get_table_keys
from spatialdata.models import TableModel, get_table_keys
from spatialdata.transformations import (
get_transformation,
remove_transformation,
set_transformation,
)

__all__ = [
"concatenate",
Expand Down Expand Up @@ -81,7 +86,8 @@ def concatenate(
concatenate_tables: bool = False,
obs_names_make_unique: bool = True,
modify_tables_inplace: bool = False,
attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None,
merge_coordinate_systems_on_name: bool = False,
attrs_merge: (StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None) = None,
**kwargs: Any,
) -> SpatialData:
"""
Expand Down Expand Up @@ -141,6 +147,7 @@ def concatenate(
rename_tables=not concatenate_tables,
rename_obs_names=obs_names_make_unique and concatenate_tables,
modify_tables_inplace=modify_tables_inplace,
merge_coordinate_systems_on_name=merge_coordinate_systems_on_name,
)

ERROR_STR = (
Expand Down Expand Up @@ -222,12 +229,14 @@ def _fix_ensure_unique_element_names(
rename_tables: bool,
rename_obs_names: bool,
modify_tables_inplace: bool,
merge_coordinate_systems_on_name: bool,
) -> list[SpatialData]:
elements_by_sdata: list[dict[str, SpatialElement]] = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this refactoring is fine

tables_by_sdata: list[dict[str, AnnData]] = []
sdatas_fixed = []
for suffix, sdata in sdatas.items():
# Create new elements dictionary with suffixed names
elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()}
elements_by_sdata.append(elements)

# Handle tables with suffix
tables = {}
for name, table in sdata.tables.items():
if not modify_tables_inplace:
Expand All @@ -251,9 +260,27 @@ def _fix_ensure_unique_element_names(
# fix the table name
new_name = f"{name}-{suffix}" if rename_tables else name
tables[new_name] = table
tables_by_sdata.append(tables)
sdatas_fixed = []
for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True):
sdata = SpatialData.init_from_elements(elements, tables=tables)
sdatas_fixed.append(sdata)

# Create new SpatialData object with suffixed elements and tables
sdata_fixed = SpatialData.init_from_elements(elements, tables=tables)

# Handle coordinate systems and transformations
for element_name, element in elements.items():
# Get the original element from the input sdata
original_name = element_name.replace(f"-{suffix}", "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for this, we already know el when constructing the elements dict.

original_element = sdata.get(original_name)

# Get transformations from original element
transformations = get_transformation(original_element, get_all=True)
assert isinstance(transformations, dict)

# Remove any existing transformations from the new element
remove_transformation(element, remove_all=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code removes all the transformations, then checks (at every iteration) for the value of merge_coordinate_system_on_name, and when it is true it restores the original transformation.

This is convoluted and note needed, as one could check merge_coordinate_system_on_name before calling remove_transformation.


# Set new transformations with suffixed coordinate system names
for cs, t in transformations.items():
new_cs = cs if merge_coordinate_systems_on_name else f"{cs}-{suffix}"
set_transformation(element, t, to_coordinate_system=new_cs)

sdatas_fixed.append(sdata_fixed)
return sdatas_fixed
36 changes: 36 additions & 0 deletions tests/core/test_concatenate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the tests for concatenate are in test_spatialdata_operations.py, I will move this code there.


import spatialdata as sd
from spatialdata.datasets import blobs


@pytest.mark.parametrize("merge_coordinate_systems_on_name", [True, False])
def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_on_name):
blob1 = blobs()
blob2 = blobs()

sdata_keys = ["blob1", "blob2"]
sdata = sd.concatenate(
dict(zip(sdata_keys, [blob1, blob2], strict=True)),
merge_coordinate_systems_on_name=merge_coordinate_systems_on_name,
)

expected_images = ["blobs_image", "blobs_multiscale_image"]
expected_labels = ["blobs_labels", "blobs_multiscale_labels"]
expected_points = ["blobs_points"]
expected_shapes = ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]

expected_suffixed_images = [f"{name}-{key}" for key in sdata_keys for name in expected_images]
expected_suffixed_labels = [f"{name}-{key}" for key in sdata_keys for name in expected_labels]
expected_suffixed_points = [f"{name}-{key}" for key in sdata_keys for name in expected_points]
expected_suffixed_shapes = [f"{name}-{key}" for key in sdata_keys for name in expected_shapes]

assert set(sdata.images.keys()) == set(expected_suffixed_images)
assert set(sdata.labels.keys()) == set(expected_suffixed_labels)
assert set(sdata.points.keys()) == set(expected_suffixed_points)
assert set(sdata.shapes.keys()) == set(expected_suffixed_shapes)

if merge_coordinate_systems_on_name:
assert sdata.coordinate_systems == ["global"]
else:
assert sdata.coordinate_systems == ["global-blob1", "global-blob2"]
Loading