Skip to content
45 changes: 34 additions & 11 deletions src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import SpatialElement, TableModel, get_table_keys
from spatialdata.models import TableModel, get_table_keys
from spatialdata.transformations import (
get_transformation,
remove_transformation,
set_transformation,
)

__all__ = [
"concatenate",
Expand Down Expand Up @@ -81,7 +86,8 @@ def concatenate(
concatenate_tables: bool = False,
obs_names_make_unique: bool = True,
modify_tables_inplace: bool = False,
attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None,
merge_coordinate_systems_on_name: bool = False,
attrs_merge: (StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None) = None,
**kwargs: Any,
) -> SpatialData:
"""
Expand Down Expand Up @@ -110,6 +116,8 @@ def concatenate(
modify_tables_inplace
Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables
will be copied before modification. Copying is enabled by default but can be disabled for performance reasons.
merge_coordinate_systems_on_name
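If `True`, coordinate system names are kept unchanged, so coordinate systems that share a name across the input objects end up as a single coordinate system. If `False` (the default), every coordinate system name is suffixed with the same per-object suffix that is appended to the element names (e.g. `global` becomes `global-<suffix>`).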
attrs_merge
How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html)
kwargs
Expand Down Expand Up @@ -141,6 +149,7 @@ def concatenate(
rename_tables=not concatenate_tables,
rename_obs_names=obs_names_make_unique and concatenate_tables,
modify_tables_inplace=modify_tables_inplace,
merge_coordinate_systems_on_name=merge_coordinate_systems_on_name,
)

ERROR_STR = (
Expand Down Expand Up @@ -222,12 +231,27 @@ def _fix_ensure_unique_element_names(
rename_tables: bool,
rename_obs_names: bool,
modify_tables_inplace: bool,
merge_coordinate_systems_on_name: bool,
) -> list[SpatialData]:
elements_by_sdata: list[dict[str, SpatialElement]] = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this refactoring is fine

tables_by_sdata: list[dict[str, AnnData]] = []
sdatas_fixed = []
for suffix, sdata in sdatas.items():
elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()}
elements_by_sdata.append(elements)
# Create new elements dictionary with suffixed names
elements = {}
for _, name, el in sdata.gen_spatial_elements():
new_element_name = f"{name}-{suffix}"
if not merge_coordinate_systems_on_name:
# Set new transformations with suffixed coordinate system names
transformations = get_transformation(el, get_all=True)
assert isinstance(transformations, dict)

remove_transformation(el, remove_all=True)
for cs, t in transformations.items():
new_cs = f"{cs}-{suffix}"
set_transformation(el, t, to_coordinate_system=new_cs)

elements[new_element_name] = el

# Handle tables with suffix
tables = {}
for name, table in sdata.tables.items():
if not modify_tables_inplace:
Expand All @@ -251,9 +275,8 @@ def _fix_ensure_unique_element_names(
# fix the table name
new_name = f"{name}-{suffix}" if rename_tables else name
tables[new_name] = table
tables_by_sdata.append(tables)
sdatas_fixed = []
for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True):
sdata = SpatialData.init_from_elements(elements, tables=tables)
sdatas_fixed.append(sdata)

# Create new SpatialData object with suffixed elements and tables
sdata_fixed = SpatialData.init_from_elements(elements, tables=tables)
sdatas_fixed.append(sdata_fixed)
return sdatas_fixed
34 changes: 34 additions & 0 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,40 @@ def _n_elements(sdata: SpatialData) -> int:
assert "blobs_image-sample" in c.images


@pytest.mark.parametrize("merge_coordinate_systems_on_name", [True, False])
def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_on_name):
    """Concatenate two blobs datasets and check the resulting coordinate system names.

    With the flag enabled the coordinate system name is left untouched; with it
    disabled each coordinate system name carries the key of its source object.
    """
    sdata_keys = ["blob1", "blob2"]
    # one fresh blobs dataset per key, in key order
    inputs = {key: blobs() for key in sdata_keys}

    sdata = concatenate(
        inputs,
        merge_coordinate_systems_on_name=merge_coordinate_systems_on_name,
    )

    observed_cs = set(sdata.coordinate_systems)
    if not merge_coordinate_systems_on_name:
        assert observed_cs == {"global-blob1", "global-blob2"}
    else:
        assert observed_cs == {"global"}

    # Sanity checks that are not specific to the flag under test: element names
    # are always suffixed with the key of the object they come from.
    containers_and_base_names = [
        (sdata.images, ["blobs_image", "blobs_multiscale_image"]),
        (sdata.labels, ["blobs_labels", "blobs_multiscale_labels"]),
        (sdata.points, ["blobs_points"]),
        (sdata.shapes, ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]),
    ]
    for container, base_names in containers_and_base_names:
        expected_suffixed = {f"{name}-{key}" for key in sdata_keys for name in base_names}
        assert set(container.keys()) == expected_suffixed


def test_locate_spatial_element(full_sdata: SpatialData) -> None:
assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d"
im = full_sdata.images["image2d"]
Expand Down
Loading