diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index da1865d8..87a9a032 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -13,7 +13,12 @@ from spatialdata._core._utils import _find_common_table_keys from spatialdata._core.spatialdata import SpatialData -from spatialdata.models import SpatialElement, TableModel, get_table_keys +from spatialdata.models import TableModel, get_table_keys +from spatialdata.transformations import ( + get_transformation, + remove_transformation, + set_transformation, +) __all__ = [ "concatenate", @@ -81,7 +86,8 @@ def concatenate( concatenate_tables: bool = False, obs_names_make_unique: bool = True, modify_tables_inplace: bool = False, - attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None, + merge_coordinate_systems_on_name: bool = False, + attrs_merge: (StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None) = None, **kwargs: Any, ) -> SpatialData: """ @@ -110,6 +116,8 @@ def concatenate( modify_tables_inplace Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables will be copied before modification. Copying is enabled by default but can be disabled for performance reasons. + merge_coordinate_systems_on_name + Whether to keep coordinate system names unchanged (True) or add suffixes (False). attrs_merge How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html) kwargs @@ -141,6 +149,7 @@ def concatenate( rename_tables=not concatenate_tables, rename_obs_names=obs_names_make_unique and concatenate_tables, modify_tables_inplace=modify_tables_inplace, + merge_coordinate_systems_on_name=merge_coordinate_systems_on_name, ) ERROR_STR = ( @@ -222,12 +231,27 @@ def _fix_ensure_unique_element_names( rename_tables: bool, rename_obs_names: bool, modify_tables_inplace: bool, + merge_coordinate_systems_on_name: bool, ) -> list[SpatialData]: - elements_by_sdata: list[dict[str, SpatialElement]] = [] - tables_by_sdata: list[dict[str, AnnData]] = [] + sdatas_fixed = [] for suffix, sdata in sdatas.items(): - elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()} - elements_by_sdata.append(elements) + # Create new elements dictionary with suffixed names + elements = {} + for _, name, el in sdata.gen_spatial_elements(): + new_element_name = f"{name}-{suffix}" + if not merge_coordinate_systems_on_name: + # Set new transformations with suffixed coordinate system names + transformations = get_transformation(el, get_all=True) + assert isinstance(transformations, dict) + + remove_transformation(el, remove_all=True) + for cs, t in transformations.items(): + new_cs = f"{cs}-{suffix}" + set_transformation(el, t, to_coordinate_system=new_cs) + + elements[new_element_name] = el + + # Handle tables with suffix tables = {} for name, table in sdata.tables.items(): if not modify_tables_inplace: @@ -251,9 +275,8 @@ def _fix_ensure_unique_element_names( # fix the table name new_name = f"{name}-{suffix}" if rename_tables else name tables[new_name] = table - tables_by_sdata.append(tables) - sdatas_fixed = [] - for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True): - sdata = SpatialData.init_from_elements(elements, tables=tables) - sdatas_fixed.append(sdata) + + # Create new SpatialData object with suffixed elements and tables + sdata_fixed = SpatialData.init_from_elements(elements, tables=tables) + sdatas_fixed.append(sdata_fixed) return sdatas_fixed diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index dd765029..9124c60e 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -471,6 +471,40 @@ def _n_elements(sdata: SpatialData) -> int: assert "blobs_image-sample" in c.images +@pytest.mark.parametrize("merge_coordinate_systems_on_name", [True, False]) +def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_on_name): + blob1 = blobs() + blob2 = blobs() + + sdata_keys = ["blob1", "blob2"] + sdata = concatenate( + dict(zip(sdata_keys, [blob1, blob2], strict=True)), + merge_coordinate_systems_on_name=merge_coordinate_systems_on_name, + ) + + if merge_coordinate_systems_on_name: + assert set(sdata.coordinate_systems) == {"global"} + else: + assert set(sdata.coordinate_systems) == {"global-blob1", "global-blob2"} + + # extra checks not specific to this test, we could remove them or leave them just + # in case + expected_images = ["blobs_image", "blobs_multiscale_image"] + expected_labels = ["blobs_labels", "blobs_multiscale_labels"] + expected_points = ["blobs_points"] + expected_shapes = ["blobs_circles", "blobs_polygons", "blobs_multipolygons"] + + expected_suffixed_images = [f"{name}-{key}" for key in sdata_keys for name in expected_images] + expected_suffixed_labels = [f"{name}-{key}" for key in sdata_keys for name in expected_labels] + expected_suffixed_points = [f"{name}-{key}" for key in sdata_keys for name in expected_points] + expected_suffixed_shapes = [f"{name}-{key}" for key in sdata_keys for name in expected_shapes] + + assert set(sdata.images.keys()) == set(expected_suffixed_images) + assert set(sdata.labels.keys()) == set(expected_suffixed_labels) + assert set(sdata.points.keys()) == set(expected_suffixed_points) + assert set(sdata.shapes.keys()) == set(expected_suffixed_shapes) + + def test_locate_spatial_element(full_sdata: SpatialData) -> None: assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d" im = full_sdata.images["image2d"]