From a78d460a81d38ce0928e73871ae6058b3800a84d Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Wed, 2 Apr 2025 11:00:40 +0200 Subject: [PATCH 1/7] added parameter to optionally merge cs --- src/spatialdata/_core/concatenate.py | 99 +++++++++++++++++++++------- tests/core/test_concatenate.py | 46 +++++++++++++ 2 files changed, 123 insertions(+), 22 deletions(-) create mode 100644 tests/core/test_concatenate.py diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index da1865d89..497a642ae 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -14,6 +14,11 @@ from spatialdata._core._utils import _find_common_table_keys from spatialdata._core.spatialdata import SpatialData from spatialdata.models import SpatialElement, TableModel, get_table_keys +from spatialdata.transformations import ( + get_transformation, + remove_transformation, + set_transformation, +) __all__ = [ "concatenate", @@ -30,29 +35,43 @@ def _concatenate_tables( if not all(TableModel.ATTRS_KEY in table.uns for table in tables): raise ValueError("Not all tables are annotating a spatial element") - region_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] for table in tables] - instance_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] for table in tables] - regions = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] for table in tables] + region_keys = [ + table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] for table in tables + ] + instance_keys = [ + table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] for table in tables + ] + regions = [ + table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] for table in tables + ] if len(set(region_keys)) == 1: region_key = list(region_keys)[0] else: if region_key is None: - raise ValueError("`region_key` must be specified if tables have different region keys") + raise ValueError( + "`region_key` must be specified if tables have different region keys" + ) # get unique regions from list of lists or str regions_unique = list(chain(*[[i] if isinstance(i, str) else i for i in regions])) if len(set(regions_unique)) != len(regions_unique): - raise ValueError(f"Two or more tables seems to annotate regions with the same name: {regions_unique}") + raise ValueError( + f"Two or more tables seems to annotate regions with the same name: {regions_unique}" + ) if len(set(instance_keys)) == 1: instance_key = list(instance_keys)[0] else: if instance_key is None: - raise ValueError("`instance_key` must be specified if tables have different instance keys") + raise ValueError( + "`instance_key` must be specified if tables have different instance keys" + ) tables_l = [] - for table_region_key, table_instance_key, table in zip(region_keys, instance_keys, tables, strict=True): + for table_region_key, table_instance_key, table in zip( + region_keys, instance_keys, tables, strict=True + ): rename_dict = {} if table_region_key != region_key: rename_dict[table_region_key] = region_key @@ -81,7 +100,10 @@ def concatenate( concatenate_tables: bool = False, obs_names_make_unique: bool = True, modify_tables_inplace: bool = False, - attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None, + merge_coordinate_systems_on_name: bool = False, + attrs_merge: ( + StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None + ) = None, **kwargs: Any, ) -> SpatialData: """ @@ -141,6 +163,7 @@ def concatenate( rename_tables=not concatenate_tables, rename_obs_names=obs_names_make_unique and concatenate_tables, modify_tables_inplace=modify_tables_inplace, + merge_coordinate_systems_on_name=merge_coordinate_systems_on_name, ) ERROR_STR = ( @@ -188,7 +211,9 @@ def concatenate( for sdata in sdatas: for k, v in sdata.tables.items(): if k in common_keys and merged_tables.get(k) is not None: - merged_tables[k] = _concatenate_tables([merged_tables[k], v], region_key, instance_key, **kwargs) + merged_tables[k] = _concatenate_tables( + [merged_tables[k], v], region_key, instance_key, **kwargs + ) else: merged_tables[k] = v @@ -209,11 +234,15 @@ def concatenate( return sdata -def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list[str]) -> AnnData: +def _filter_table_in_coordinate_systems( + table: AnnData, coordinate_systems: list[str] +) -> AnnData: table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] new_table = table[table.obs[region_key].isin(coordinate_systems)].copy() - new_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_table.obs[region_key].unique().tolist() + new_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = ( + new_table.obs[region_key].unique().tolist() + ) return new_table @@ -222,12 +251,16 @@ def _fix_ensure_unique_element_names( rename_tables: bool, rename_obs_names: bool, modify_tables_inplace: bool, + merge_coordinate_systems_on_name: bool, ) -> list[SpatialData]: - elements_by_sdata: list[dict[str, SpatialElement]] = [] - tables_by_sdata: list[dict[str, AnnData]] = [] + sdatas_fixed = [] for suffix, sdata in sdatas.items(): - elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()} - elements_by_sdata.append(elements) + # Create new elements dictionary with suffixed names + elements = { + f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements() + } + + # Handle tables with suffix tables = {} for name, table in sdata.tables.items(): if not modify_tables_inplace: @@ -235,7 +268,9 @@ def _fix_ensure_unique_element_names( # fix the region_key column region, region_key, _ = get_table_keys(table) - table.obs[region_key] = (table.obs[region_key].astype("str") + f"-{suffix}").astype("category") + table.obs[region_key] = ( + table.obs[region_key].astype("str") + f"-{suffix}" + ).astype("category") new_region: str | list[str] if isinstance(region, str): new_region = f"{region}-{suffix}" @@ -246,14 +281,34 @@ def _fix_ensure_unique_element_names( # fix the obs names if rename_obs_names: - table.obs.index = table.obs.index.to_series().apply(lambda x, suffix=suffix: f"{x}-{suffix}") + table.obs.index = table.obs.index.to_series().apply( + lambda x, suffix=suffix: f"{x}-{suffix}" + ) # fix the table name new_name = f"{name}-{suffix}" if rename_tables else name tables[new_name] = table - tables_by_sdata.append(tables) - sdatas_fixed = [] - for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True): - sdata = SpatialData.init_from_elements(elements, tables=tables) - sdatas_fixed.append(sdata) + + # Create new SpatialData object with suffixed elements and tables + sdata_fixed = SpatialData.init_from_elements(elements, tables=tables) + + # Handle coordinate systems and transformations + for element_name, element in elements.items(): + # Get the original element from the input sdata + original_name = element_name.replace(f"-{suffix}", "") + original_element = sdata.get(original_name) + + # Get transformations from original element + transformations = get_transformation(original_element, get_all=True) + assert isinstance(transformations, dict) + + # Remove any existing transformations from the new element + remove_transformation(element, remove_all=True) + + # Set new transformations with suffixed coordinate system names + for cs, t in transformations.items(): + new_cs = cs if merge_coordinate_systems_on_name else f"{cs}-{suffix}" + set_transformation(element, t, to_coordinate_system=new_cs) + + sdatas_fixed.append(sdata_fixed) return sdatas_fixed diff --git a/tests/core/test_concatenate.py b/tests/core/test_concatenate.py new file mode 100644 index 000000000..ccd294f8c --- /dev/null +++ b/tests/core/test_concatenate.py @@ -0,0 +1,46 @@ +import pandas as pd +import pytest + +import spatialdata as sd +from spatialdata.datasets import blobs + + +@pytest.mark.parametrize("merge_coordinate_systems_on_name", [True, False]) +def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_on_name): + blob1 = blobs() + blob2 = blobs() + + sdata_keys = ["blob1", "blob2"] + sdata = sd.concatenate( + dict(zip(sdata_keys, [blob1, blob2])), + merge_coordinate_systems_on_name=merge_coordinate_systems_on_name, + ) + + expected_images = ["blobs_image", "blobs_multiscale_image"] + expected_labels = ["blobs_labels", "blobs_multiscale_labels"] + expected_points = ["blobs_points"] + expected_shapes = ["blobs_circles", "blobs_polygons", "blobs_multipolygons"] + + expected_suffixed_images = [ + f"{name}-{key}" for key in sdata_keys for name in expected_images + ] + expected_suffixed_labels = [ + f"{name}-{key}" for key in sdata_keys for name in expected_labels + ] + expected_suffixed_points = [ + f"{name}-{key}" for key in sdata_keys for name in expected_points + ] + expected_suffixed_shapes = [ + f"{name}-{key}" for key in sdata_keys for name in expected_shapes + ] + + assert set(sdata.images.keys()) == set(expected_suffixed_images) + assert set(sdata.labels.keys()) == set(expected_suffixed_labels) + assert set(sdata.points.keys()) == set(expected_suffixed_points) + assert set(sdata.shapes.keys()) == set(expected_suffixed_shapes) + + + if merge_coordinate_systems_on_name: + assert sdata.coordinate_systems == ["global"] + else: + assert sdata.coordinate_systems == ["global-blob1", "global-blob2"] From 19ef82ac767bdc94020817664cc93b2531f3bba7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Apr 2025 09:02:29 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/_core/concatenate.py | 58 +++++++--------------------- tests/core/test_concatenate.py | 18 ++------- 2 files changed, 19 insertions(+), 57 deletions(-) diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 497a642ae..7ae3e1a77 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -13,7 +13,7 @@ from spatialdata._core._utils import _find_common_table_keys from spatialdata._core.spatialdata import SpatialData -from spatialdata.models import SpatialElement, TableModel, get_table_keys +from spatialdata.models import TableModel, get_table_keys from spatialdata.transformations import ( get_transformation, remove_transformation, @@ -35,43 +35,29 @@ def _concatenate_tables( if not all(TableModel.ATTRS_KEY in table.uns for table in tables): raise ValueError("Not all tables are annotating a spatial element") - region_keys = [ - table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] for table in tables - ] - instance_keys = [ - table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] for table in tables - ] - regions = [ - table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] for table in tables - ] + region_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] for table in tables] + instance_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] for table in tables] + regions = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] for table in tables] if len(set(region_keys)) == 1: region_key = list(region_keys)[0] else: if region_key is None: - raise ValueError( - "`region_key` must be specified if tables have different region keys" - ) + raise ValueError("`region_key` must be specified if tables have different region keys") # get unique regions from list of lists or str regions_unique = list(chain(*[[i] if isinstance(i, str) else i for i in regions])) if len(set(regions_unique)) != len(regions_unique): - raise ValueError( - f"Two or more tables seems to annotate regions with the same name: {regions_unique}" - ) + raise ValueError(f"Two or more tables seems to annotate regions with the same name: {regions_unique}") if len(set(instance_keys)) == 1: instance_key = list(instance_keys)[0] else: if instance_key is None: - raise ValueError( - "`instance_key` must be specified if tables have different instance keys" - ) + raise ValueError("`instance_key` must be specified if tables have different instance keys") tables_l = [] - for table_region_key, table_instance_key, table in zip( - region_keys, instance_keys, tables, strict=True - ): + for table_region_key, table_instance_key, table in zip(region_keys, instance_keys, tables, strict=True): rename_dict = {} if table_region_key != region_key: rename_dict[table_region_key] = region_key @@ -101,9 +87,7 @@ def concatenate( obs_names_make_unique: bool = True, modify_tables_inplace: bool = False, merge_coordinate_systems_on_name: bool = False, - attrs_merge: ( - StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None - ) = None, + attrs_merge: (StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None) = None, **kwargs: Any, ) -> SpatialData: """ @@ -211,9 +195,7 @@ def concatenate( for sdata in sdatas: for k, v in sdata.tables.items(): if k in common_keys and merged_tables.get(k) is not None: - merged_tables[k] = _concatenate_tables( - [merged_tables[k], v], region_key, instance_key, **kwargs - ) + merged_tables[k] = _concatenate_tables([merged_tables[k], v], region_key, instance_key, **kwargs) else: merged_tables[k] = v @@ -234,15 +216,11 @@ def concatenate( return sdata -def _filter_table_in_coordinate_systems( - table: AnnData, coordinate_systems: list[str] -) -> AnnData: +def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list[str]) -> AnnData: table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] new_table = table[table.obs[region_key].isin(coordinate_systems)].copy() - new_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = ( - new_table.obs[region_key].unique().tolist() - ) + new_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_table.obs[region_key].unique().tolist() return new_table @@ -256,9 +234,7 @@ def _fix_ensure_unique_element_names( sdatas_fixed = [] for suffix, sdata in sdatas.items(): # Create new elements dictionary with suffixed names - elements = { - f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements() - } + elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()} # Handle tables with suffix tables = {} @@ -268,9 +244,7 @@ def _fix_ensure_unique_element_names( # fix the region_key column region, region_key, _ = get_table_keys(table) - table.obs[region_key] = ( - table.obs[region_key].astype("str") + f"-{suffix}" - ).astype("category") + table.obs[region_key] = (table.obs[region_key].astype("str") + f"-{suffix}").astype("category") new_region: str | list[str] if isinstance(region, str): new_region = f"{region}-{suffix}" @@ -281,9 +255,7 @@ def _fix_ensure_unique_element_names( # fix the obs names if rename_obs_names: - table.obs.index = table.obs.index.to_series().apply( - lambda x, suffix=suffix: f"{x}-{suffix}" - ) + table.obs.index = table.obs.index.to_series().apply(lambda x, suffix=suffix: f"{x}-{suffix}") # fix the table name new_name = f"{name}-{suffix}" if rename_tables else name diff --git a/tests/core/test_concatenate.py b/tests/core/test_concatenate.py index ccd294f8c..be2d22105 100644 --- a/tests/core/test_concatenate.py +++ b/tests/core/test_concatenate.py @@ -1,4 +1,3 @@ -import pandas as pd import pytest import spatialdata as sd @@ -21,25 +20,16 @@ def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_o expected_points = ["blobs_points"] expected_shapes = ["blobs_circles", "blobs_polygons", "blobs_multipolygons"] - expected_suffixed_images = [ - f"{name}-{key}" for key in sdata_keys for name in expected_images - ] - expected_suffixed_labels = [ - f"{name}-{key}" for key in sdata_keys for name in expected_labels - ] - expected_suffixed_points = [ - f"{name}-{key}" for key in sdata_keys for name in expected_points - ] - expected_suffixed_shapes = [ - f"{name}-{key}" for key in sdata_keys for name in expected_shapes - ] + expected_suffixed_images = [f"{name}-{key}" for key in sdata_keys for name in expected_images] + expected_suffixed_labels = [f"{name}-{key}" for key in sdata_keys for name in expected_labels] + expected_suffixed_points = [f"{name}-{key}" for key in sdata_keys for name in expected_points] + expected_suffixed_shapes = [f"{name}-{key}" for key in sdata_keys for name in expected_shapes] assert set(sdata.images.keys()) == set(expected_suffixed_images) assert set(sdata.labels.keys()) == set(expected_suffixed_labels) assert set(sdata.points.keys()) == set(expected_suffixed_points) assert set(sdata.shapes.keys()) == set(expected_suffixed_shapes) - if merge_coordinate_systems_on_name: assert sdata.coordinate_systems == ["global"] else: From 0cc33d981df9f3f922e729388154cf1a289e6741 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Wed, 2 Apr 2025 11:28:24 +0200 Subject: [PATCH 3/7] fixed ruff --- tests/core/test_concatenate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_concatenate.py b/tests/core/test_concatenate.py index be2d22105..2bc68a9b9 100644 --- a/tests/core/test_concatenate.py +++ b/tests/core/test_concatenate.py @@ -11,7 +11,7 @@ def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_o sdata_keys = ["blob1", "blob2"] sdata = sd.concatenate( - dict(zip(sdata_keys, [blob1, blob2])), + dict(zip(sdata_keys, [blob1, blob2], strict=True)), merge_coordinate_systems_on_name=merge_coordinate_systems_on_name, ) From 604511eb9eaa7151979871c81f6702436e5d4e1e Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sun, 18 May 2025 16:36:04 +0100 Subject: [PATCH 4/7] added helper function to sanitize table objects --- src/spatialdata/_utils.py | 138 +++++++++++++++++++++++ tests/utils/test_sanitize.py | 212 +++++++++++++++++++++++++++++++++++ 2 files changed, 350 insertions(+) create mode 100644 tests/utils/test_sanitize.py diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 61f5a52c7..1d9ccee82 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -346,3 +346,141 @@ def _check_match_length_channels_c_dim( f" with length {c_length}." ) return c_coords + + +def sanitize_name(name: str, is_dataframe_column: bool = False) -> str: + """ + Sanitize a name to comply with SpatialData naming rules. + + This function converts invalid names into valid ones by: + 1. Converting to string if not already + 2. Removing invalid characters + 3. Handling special cases like "__" prefix + 4. Ensuring the name is not empty + 5. Handling special cases for dataframe columns + + Parameters + ---------- + name + The name to sanitize + is_dataframe_column + Whether this name is for a dataframe column (additional restrictions apply) + + Returns + ------- + A sanitized version of the name that complies with SpatialData naming rules. + + Examples + -------- + >>> sanitize_name("my@invalid#name") + 'my_invalid_name' + >>> sanitize_name("__private") + 'private' + >>> sanitize_name("_index", is_dataframe_column=True) + 'index' + """ + # Convert to string if not already + name = str(name) + + # Handle empty string case + if not name: + return "unnamed" + + # Handle special cases + if name == "." or name == "..": + return "unnamed" + + # Remove "__" prefix if present + if name.startswith("__"): + name = name[2:] + + # Replace invalid characters with underscore + # Keep only alphanumeric, underscore, dot, and hyphen + sanitized = "" + for char in name: + if char.isalnum() or char in "_-.": + sanitized += char + else: + sanitized += "_" + + # Remove leading underscores but keep trailing ones + sanitized = sanitized.lstrip("_") + + # Ensure we don't end up with an empty string after sanitization + if not sanitized: + return "unnamed" + + return sanitized + + +def sanitize_table(data: AnnData, inplace: bool = True) -> AnnData | None: + """ + Sanitize all keys in an AnnData table to comply with SpatialData naming rules. + + This function sanitizes all keys in obs, var, obsm, obsp, varm, varp, uns, and layers + while maintaining case-insensitive uniqueness. It can either modify the table in-place + or return a new sanitized copy. + + Parameters + ---------- + data + The AnnData table to sanitize + inplace + Whether to modify the table in-place or return a new copy + + Returns + ------- + If inplace is False, returns a new AnnData object with sanitized keys. + If inplace is True, returns None as the original object is modified. + + Examples + -------- + >>> import anndata as ad + >>> adata = ad.AnnData(obs=pd.DataFrame({"@invalid#": [1, 2]})) + >>> # Create a new sanitized copy + >>> sanitized = sanitize_table(adata) + >>> print(sanitized.obs.columns) + Index(['invalid_'], dtype='object') + >>> # Or modify in-place + >>> sanitize_table(adata, inplace=True) + >>> print(adata.obs.columns) + Index(['invalid_'], dtype='object') + """ + import copy + from collections import defaultdict + + # Create a deep copy if not modifying in-place + sanitized = data if inplace else copy.deepcopy(data) + + # Track used names to maintain case-insensitive uniqueness + used_names: dict[str, set[str]] = defaultdict(set) + + def get_unique_name(name: str, attr: str, is_dataframe_column: bool = False) -> str: + base_name = sanitize_name(name, is_dataframe_column) + normalized_base = base_name.lower() + + # If this exact name is already used, add a number + if normalized_base in {n.lower() for n in used_names[attr]}: + counter = 1 + while f"{base_name}_{counter}".lower() in {n.lower() for n in used_names[attr]}: + counter += 1 + base_name = f"{base_name}_{counter}" + + used_names[attr].add(base_name) + return base_name + + # Handle obs and var (dataframe columns) + for attr in ("obs", "var"): + df = getattr(sanitized, attr) + new_columns = {old: get_unique_name(old, attr, is_dataframe_column=True) for old in df.columns} + df.rename(columns=new_columns, inplace=True) + + # Handle other attributes + for attr in ("obsm", "obsp", "varm", "varp", "uns", "layers"): + d = getattr(sanitized, attr) + new_keys = {old: get_unique_name(old, attr) for old in d} + # Create new dictionary with sanitized keys + new_dict = {new_keys[old]: value for old, value in d.items()} + setattr(sanitized, attr, new_dict) + + return None if inplace else sanitized diff --git a/tests/utils/test_sanitize.py b/tests/utils/test_sanitize.py new file mode 100644 index 000000000..b567cc53d --- /dev/null +++ b/tests/utils/test_sanitize.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData + +from spatialdata import SpatialData +from spatialdata._utils import sanitize_name, sanitize_table + + +@pytest.fixture +def invalid_table() -> AnnData: + """AnnData with invalid obs column names to test basic sanitization.""" + return AnnData( + obs=pd.DataFrame( + { + "@invalid#": [1, 2], + "valid_name": [3, 4], + "__private": [5, 6], + } + ) + ) + + +@pytest.fixture +def invalid_table_with_index() -> AnnData: + """AnnData with a name requiring whitespace→underscore and a dataframe index column.""" + return AnnData( + obs=pd.DataFrame( + { + "invalid name": [1, 2], + "_index": [3, 4], + } + ) + ) + + +@pytest.fixture +def sdata_sanitized_tables(invalid_table, invalid_table_with_index) -> SpatialData: + """SpatialData built from sanitized copies of the invalid tables.""" + table1 = invalid_table.copy() + table2 = invalid_table_with_index.copy() + sanitize_table(table1) + sanitize_table(table2) + return SpatialData(tables={"table1": table1, "table2": table2}) + + +# ----------------------------------------------------------------------------- +# sanitize_name tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "raw,expected", + [ + ("valid_name", "valid_name"), + ("valid-name", "valid-name"), + ("valid.name", "valid.name"), + ("invalid@name", "invalid_name"), + ("invalid#name", "invalid_name"), + ("invalid name", "invalid_name"), + ("", "unnamed"), + (".", "unnamed"), + ("..", "unnamed"), + ("__private", "private"), + ], +) +def test_sanitize_name_strips_special_chars(raw, expected): + assert sanitize_name(raw) == expected + + +@pytest.mark.parametrize( + "raw,is_df_col,expected", + [ + ("_index", True, "index"), + ("_index", False, "index"), + ("valid@column", True, "valid_column"), + ("__private", True, "private"), + ], +) +def test_sanitize_name_dataframe_column(raw, is_df_col, expected): + assert sanitize_name(raw, is_dataframe_column=is_df_col) == expected + + +# ----------------------------------------------------------------------------- +# sanitize_table basic behaviors +# ----------------------------------------------------------------------------- + + +def test_sanitize_table_basic_columns(invalid_table, invalid_table_with_index): + ad1 = sanitize_table(invalid_table, inplace=False) + assert isinstance(ad1, AnnData) + assert list(ad1.obs.columns) == ["invalid_", "valid_name", "private"] + + ad2 = sanitize_table(invalid_table_with_index, inplace=False) + assert list(ad2.obs.columns) == ["invalid_name", "index"] + + # original fixture remains unchanged + assert list(invalid_table.obs.columns) == ["@invalid#", "valid_name", "__private"] + + +def test_sanitize_table_inplace_copy(invalid_table): + ad = invalid_table.copy() + sanitize_table(ad) # inplace=True is now default + assert list(ad.obs.columns) == ["invalid_", "valid_name", "private"] + + +def test_sanitize_table_case_insensitive_collisions(): + obs = pd.DataFrame( + { + "Column1": [1, 2], + "column1": [3, 4], + "COLUMN1": [5, 6], + } + ) + ad = AnnData(obs=obs) + sanitized = sanitize_table(ad, inplace=False) + cols = list(sanitized.obs.columns) + assert sorted(cols) == sorted(["Column1", "column1_1", "COLUMN1_2"]) + + +def test_sanitize_table_whitespace_collision(): + """Ensure 'a b' → 'a_b' doesn't collide silently with existing 'a_b'.""" + obs = pd.DataFrame({"a b": [1], "a_b": [2]}) + ad = AnnData(obs=obs) + sanitized = sanitize_table(ad, inplace=False) + cols = list(sanitized.obs.columns) + assert "a_b" in cols + assert "a_b_1" in cols + + +# ----------------------------------------------------------------------------- +# sanitize_table attribute‐specific tests +# ----------------------------------------------------------------------------- + + +def test_sanitize_table_obs_and_obs_columns(): + ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]})) + sanitized = sanitize_table(ad, inplace=False) + assert list(sanitized.obs.columns) == ["col"] + + +def test_sanitize_table_obsm_and_obsp(): + ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]})) + ad.obsm["@col"] = np.array([[1, 2], [3, 4]]) + ad.obsp["bad name"] = np.array([[1, 2], [3, 4]]) + sanitized = sanitize_table(ad, inplace=False) + assert list(sanitized.obsm.keys()) == ["col"] + assert list(sanitized.obsp.keys()) == ["bad_name"] + + +def test_sanitize_table_varm_and_varp(): + ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"])) + ad.varm["__priv"] = np.array([[1, 2], [3, 4]]) + ad.varp["_index"] = np.array([[1, 2], [3, 4]]) + sanitized = sanitize_table(ad, inplace=False) + assert list(sanitized.varm.keys()) == ["priv"] + assert list(sanitized.varp.keys()) == ["index"] + + +def test_sanitize_table_uns_and_layers(): + ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"])) + ad.uns["bad@key"] = "val" + ad.layers["bad#layer"] = np.array([[0, 1], [1, 0]]) + sanitized = sanitize_table(ad, inplace=False) + assert list(sanitized.uns.keys()) == ["bad_key"] + assert list(sanitized.layers.keys()) == ["bad_layer"] + + +def test_sanitize_table_empty_returns_empty(): + ad = AnnData() + sanitized = sanitize_table(ad, inplace=False) + assert isinstance(sanitized, AnnData) + assert sanitized.obs.empty + assert sanitized.var.empty + + +def test_sanitize_table_preserves_underlying_data(): + ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]})) + ad.obsm["@invalid#"] = np.array([[1, 2], [3, 4]]) + ad.uns["invalid@key"] = "value" + sanitized = sanitize_table(ad, inplace=False) + assert sanitized.obs["invalid_"].tolist() == [1, 2] + assert sanitized.obs["valid"].tolist() == [3, 4] + assert np.array_equal(sanitized.obsm["invalid_"], np.array([[1, 2], [3, 4]])) + assert sanitized.uns["invalid_key"] == "value" + + +# ----------------------------------------------------------------------------- +# SpatialData integration +# ----------------------------------------------------------------------------- + + +def test_sanitize_table_in_spatialdata_sanitized_fixture(sdata_sanitized_tables): + t1 = sdata_sanitized_tables.tables["table1"] + t2 = sdata_sanitized_tables.tables["table2"] + assert list(t1.obs.columns) == ["invalid_", "valid_name", "private"] + assert list(t2.obs.columns) == ["invalid_name", "index"] + + +def test_spatialdata_retains_other_elements(full_sdata, sdata_sanitized_tables): + # Add another sanitized table into an existing full_sdata + tbl = AnnData(obs=pd.DataFrame({"@foo#": [1, 2], "bar": [3, 4]})) + sanitize_table(tbl) + full_sdata.tables["new_table"] = tbl + + # Verify columns and presence of other SpatialData attributes + assert list(full_sdata.tables["new_table"].obs.columns) == ["foo_", "bar"] + assert "image2d" in full_sdata.images + assert "labels2d" in full_sdata.labels + assert "points_0" in full_sdata.points From d5425646096636b806c9ed9cbcdc958578ee94a8 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sun, 18 May 2025 16:42:47 +0100 Subject: [PATCH 5/7] Revert "added helper function to sanitize table objects" This reverts commit 604511eb9eaa7151979871c81f6702436e5d4e1e. --- src/spatialdata/_utils.py | 138 ----------------------- tests/utils/test_sanitize.py | 212 ----------------------------------- 2 files changed, 350 deletions(-) delete mode 100644 tests/utils/test_sanitize.py diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 1d9ccee82..61f5a52c7 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -346,141 +346,3 @@ def _check_match_length_channels_c_dim( f" with length {c_length}." ) return c_coords - - -def sanitize_name(name: str, is_dataframe_column: bool = False) -> str: - """ - Sanitize a name to comply with SpatialData naming rules. - - This function converts invalid names into valid ones by: - 1. Converting to string if not already - 2. Removing invalid characters - 3. Handling special cases like "__" prefix - 4. Ensuring the name is not empty - 5. Handling special cases for dataframe columns - - Parameters - ---------- - name - The name to sanitize - is_dataframe_column - Whether this name is for a dataframe column (additional restrictions apply) - - Returns - ------- - A sanitized version of the name that complies with SpatialData naming rules. - - Examples - -------- - >>> sanitize_name("my@invalid#name") - 'my_invalid_name' - >>> sanitize_name("__private") - 'private' - >>> sanitize_name("_index", is_dataframe_column=True) - 'index' - """ - # Convert to string if not already - name = str(name) - - # Handle empty string case - if not name: - return "unnamed" - - # Handle special cases - if name == "." or name == "..": - return "unnamed" - - # Remove "__" prefix if present - if name.startswith("__"): - name = name[2:] - - # Replace invalid characters with underscore - # Keep only alphanumeric, underscore, dot, and hyphen - sanitized = "" - for char in name: - if char.isalnum() or char in "_-.": - sanitized += char - else: - sanitized += "_" - - # Remove leading underscores but keep trailing ones - sanitized = sanitized.lstrip("_") - - # Ensure we don't end up with an empty string after sanitization - if not sanitized: - return "unnamed" - - return sanitized - - -def sanitize_table(data: AnnData, inplace: bool = True) -> AnnData | None: - """ - Sanitize all keys in an AnnData table to comply with SpatialData naming rules. - - This function sanitizes all keys in obs, var, obsm, obsp, varm, varp, uns, and layers - while maintaining case-insensitive uniqueness. It can either modify the table in-place - or return a new sanitized copy. - - Parameters - ---------- - data - The AnnData table to sanitize - inplace - Whether to modify the table in-place or return a new copy - - Returns - ------- - If inplace is False, returns a new AnnData object with sanitized keys. - If inplace is True, returns None as the original object is modified. - - Examples - -------- - >>> import anndata as ad - >>> adata = ad.AnnData(obs=pd.DataFrame({"@invalid#": [1, 2]})) - >>> # Create a new sanitized copy - >>> sanitized = sanitize_table(adata) - >>> print(sanitized.obs.columns) - Index(['invalid_'], dtype='object') - >>> # Or modify in-place - >>> sanitize_table(adata, inplace=True) - >>> print(adata.obs.columns) - Index(['invalid_'], dtype='object') - """ - import copy - from collections import defaultdict - - # Create a deep copy if not modifying in-place - sanitized = data if inplace else copy.deepcopy(data) - - # Track used names to maintain case-insensitive uniqueness - used_names: dict[str, set[str]] = defaultdict(set) - - def get_unique_name(name: str, attr: str, is_dataframe_column: bool = False) -> str: - base_name = sanitize_name(name, is_dataframe_column) - normalized_base = base_name.lower() - - # If this exact name is already used, add a number - if normalized_base in {n.lower() for n in used_names[attr]}: - counter = 1 - while f"{base_name}_{counter}".lower() in {n.lower() for n in used_names[attr]}: - counter += 1 - base_name = f"{base_name}_{counter}" - - used_names[attr].add(base_name) - return base_name - - # Handle obs and var (dataframe columns) - for attr in ("obs", "var"): - df = getattr(sanitized, attr) - new_columns = {old: get_unique_name(old, attr, is_dataframe_column=True) for old in df.columns} - df.rename(columns=new_columns, inplace=True) - - # Handle other attributes - for attr in ("obsm", "obsp", "varm", "varp", "uns", "layers"): - d = getattr(sanitized, attr) - new_keys = {old: get_unique_name(old, attr) for old in d} - # Create new dictionary with sanitized keys - new_dict = {new_keys[old]: value for old, value in d.items()} - setattr(sanitized, attr, new_dict) - - return None if inplace else sanitized diff --git a/tests/utils/test_sanitize.py b/tests/utils/test_sanitize.py deleted file mode 100644 index b567cc53d..000000000 --- a/tests/utils/test_sanitize.py +++ /dev/null @@ -1,212 +0,0 @@ -from __future__ import annotations - -import numpy as np -import pandas as pd -import pytest -from anndata import AnnData - -from spatialdata import SpatialData -from spatialdata._utils import sanitize_name, sanitize_table - - -@pytest.fixture -def invalid_table() -> AnnData: - """AnnData with invalid obs column names to test basic sanitization.""" - return AnnData( - obs=pd.DataFrame( - { - "@invalid#": [1, 2], - "valid_name": [3, 4], - "__private": [5, 6], - } - ) - ) - - -@pytest.fixture -def invalid_table_with_index() -> AnnData: - """AnnData with a name requiring whitespace→underscore and a dataframe index column.""" - return AnnData( - obs=pd.DataFrame( - { - "invalid name": [1, 2], - "_index": [3, 4], - } - ) - ) - - -@pytest.fixture -def sdata_sanitized_tables(invalid_table, invalid_table_with_index) -> SpatialData: - """SpatialData built from sanitized copies of the invalid tables.""" - table1 = invalid_table.copy() - table2 = invalid_table_with_index.copy() - sanitize_table(table1) - sanitize_table(table2) - return SpatialData(tables={"table1": table1, "table2": table2}) - - -# ----------------------------------------------------------------------------- -# sanitize_name tests -# ----------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "raw,expected", - [ - ("valid_name", "valid_name"), - ("valid-name", "valid-name"), - ("valid.name", "valid.name"), - ("invalid@name", "invalid_name"), - ("invalid#name", "invalid_name"), - ("invalid name", "invalid_name"), - ("", "unnamed"), - (".", "unnamed"), - ("..", "unnamed"), - ("__private", "private"), - ], -) -def test_sanitize_name_strips_special_chars(raw, expected): - assert sanitize_name(raw) == expected - - -@pytest.mark.parametrize( - "raw,is_df_col,expected", - [ - ("_index", True, "index"), - ("_index", False, "index"), - ("valid@column", True, "valid_column"), - ("__private", True, "private"), - ], -) -def test_sanitize_name_dataframe_column(raw, is_df_col, expected): - assert sanitize_name(raw, is_dataframe_column=is_df_col) == expected - - -# ----------------------------------------------------------------------------- -# sanitize_table basic behaviors -# ----------------------------------------------------------------------------- - - -def test_sanitize_table_basic_columns(invalid_table, invalid_table_with_index): - ad1 = sanitize_table(invalid_table, inplace=False) - assert isinstance(ad1, AnnData) - assert list(ad1.obs.columns) == ["invalid_", "valid_name", "private"] - - ad2 = sanitize_table(invalid_table_with_index, inplace=False) - assert list(ad2.obs.columns) == ["invalid_name", "index"] - - # original fixture remains unchanged - assert list(invalid_table.obs.columns) == ["@invalid#", "valid_name", "__private"] - - -def test_sanitize_table_inplace_copy(invalid_table): - ad = invalid_table.copy() - sanitize_table(ad) # inplace=True is now default - assert list(ad.obs.columns) == ["invalid_", "valid_name", "private"] - - -def test_sanitize_table_case_insensitive_collisions(): - obs = pd.DataFrame( - { - "Column1": [1, 2], - "column1": [3, 4], - "COLUMN1": [5, 6], - } - ) - ad = AnnData(obs=obs) - sanitized = sanitize_table(ad, inplace=False) - cols = list(sanitized.obs.columns) - assert sorted(cols) == sorted(["Column1", "column1_1", "COLUMN1_2"]) - - -def test_sanitize_table_whitespace_collision(): - """Ensure 'a b' → 'a_b' doesn't collide silently with existing 'a_b'.""" - obs = pd.DataFrame({"a b": [1], "a_b": [2]}) - ad = AnnData(obs=obs) - sanitized = sanitize_table(ad, inplace=False) - cols = list(sanitized.obs.columns) - assert "a_b" in cols - assert "a_b_1" in cols - - -# ----------------------------------------------------------------------------- -# sanitize_table attribute‐specific tests -# ----------------------------------------------------------------------------- - - -def test_sanitize_table_obs_and_obs_columns(): - ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]})) - sanitized = sanitize_table(ad, inplace=False) - assert list(sanitized.obs.columns) == ["col"] - - -def test_sanitize_table_obsm_and_obsp(): - ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]})) - ad.obsm["@col"] = np.array([[1, 2], [3, 4]]) - ad.obsp["bad name"] = np.array([[1, 2], [3, 4]]) - sanitized = sanitize_table(ad, inplace=False) - assert list(sanitized.obsm.keys()) == ["col"] - assert list(sanitized.obsp.keys()) == ["bad_name"] - - -def test_sanitize_table_varm_and_varp(): - ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"])) - ad.varm["__priv"] = np.array([[1, 2], [3, 4]]) - ad.varp["_index"] = np.array([[1, 2], [3, 4]]) - sanitized = sanitize_table(ad, inplace=False) - assert list(sanitized.varm.keys()) == ["priv"] - assert list(sanitized.varp.keys()) == ["index"] - - -def test_sanitize_table_uns_and_layers(): - ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"])) - ad.uns["bad@key"] = "val" - ad.layers["bad#layer"] = np.array([[0, 1], [1, 0]]) - sanitized = sanitize_table(ad, inplace=False) - assert list(sanitized.uns.keys()) == ["bad_key"] - assert list(sanitized.layers.keys()) == ["bad_layer"] - - -def test_sanitize_table_empty_returns_empty(): - ad = AnnData() - sanitized = sanitize_table(ad, inplace=False) - assert isinstance(sanitized, AnnData) - assert sanitized.obs.empty - assert sanitized.var.empty - - -def test_sanitize_table_preserves_underlying_data(): - ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]})) - ad.obsm["@invalid#"] = np.array([[1, 2], [3, 4]]) - ad.uns["invalid@key"] = "value" - sanitized = sanitize_table(ad, inplace=False) - assert sanitized.obs["invalid_"].tolist() == [1, 2] - assert sanitized.obs["valid"].tolist() == [3, 4] - assert np.array_equal(sanitized.obsm["invalid_"], np.array([[1, 2], [3, 4]])) - assert sanitized.uns["invalid_key"] == "value" - - -# ----------------------------------------------------------------------------- -# SpatialData integration -# ----------------------------------------------------------------------------- - - -def test_sanitize_table_in_spatialdata_sanitized_fixture(sdata_sanitized_tables): - t1 = sdata_sanitized_tables.tables["table1"] - t2 = sdata_sanitized_tables.tables["table2"] - assert list(t1.obs.columns) == ["invalid_", "valid_name", "private"] - assert list(t2.obs.columns) == ["invalid_name", "index"] - - -def test_spatialdata_retains_other_elements(full_sdata, sdata_sanitized_tables): - # Add another sanitized table into an existing full_sdata - tbl = AnnData(obs=pd.DataFrame({"@foo#": [1, 2], "bar": [3, 4]})) - sanitize_table(tbl) - full_sdata.tables["new_table"] = tbl - - # Verify columns and presence of other SpatialData attributes - assert list(full_sdata.tables["new_table"].obs.columns) == ["foo_", "bar"] - assert "image2d" in full_sdata.images - assert "labels2d" in full_sdata.labels - assert "points_0" in full_sdata.points From 1a800c7fa39cfd692c322413ea422cf7f8918ac1 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sun, 18 May 2025 17:15:11 +0100 Subject: [PATCH 6/7] copilot feedback --- src/spatialdata/_core/concatenate.py | 5 ++++- tests/core/test_concatenate.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 7ae3e1a77..526049199 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -116,6 +116,8 @@ def concatenate( modify_tables_inplace Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables will be copied before modification. Copying is enabled by default but can be disabled for performance reasons. + merge_coordinate_systems_on_name + Whether to keep coordinate system names unchanged (True) or add suffixes (False). attrs_merge How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html) kwargs @@ -272,7 +274,8 @@ def _fix_ensure_unique_element_names( # Get transformations from original element transformations = get_transformation(original_element, get_all=True) - assert isinstance(transformations, dict) + if not isinstance(transformations, dict): + raise TypeError(f"Expected 'transformations' to be a dict, but got {type(transformations).__name__}.") # Remove any existing transformations from the new element remove_transformation(element, remove_all=True) diff --git a/tests/core/test_concatenate.py b/tests/core/test_concatenate.py index 2bc68a9b9..0eeb06feb 100644 --- a/tests/core/test_concatenate.py +++ b/tests/core/test_concatenate.py @@ -31,6 +31,6 @@ def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_o assert set(sdata.shapes.keys()) == set(expected_suffixed_shapes) if merge_coordinate_systems_on_name: - assert sdata.coordinate_systems == ["global"] + assert set(sdata.coordinate_systems) == {"global"} else: - assert sdata.coordinate_systems == ["global-blob1", "global-blob2"] + assert set(sdata.coordinate_systems) == {"global-blob1", "global-blob2"} From c99e30493f697ae1809680eab46b75dee86e60df Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 27 May 2025 12:13:25 -0400 Subject: [PATCH 7/7] code review with changes --- src/spatialdata/_core/concatenate.py | 35 ++++++++---------- .../operations/test_spatialdata_operations.py | 34 ++++++++++++++++++ tests/core/test_concatenate.py | 36 ------------------- 3 files changed, 48 insertions(+), 57 deletions(-) delete mode 100644 tests/core/test_concatenate.py diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 526049199..87a9a0324 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -236,7 +236,20 @@ def _fix_ensure_unique_element_names( sdatas_fixed = [] for suffix, sdata in sdatas.items(): # Create new elements dictionary with suffixed names - elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()} + elements = {} + for _, name, el in sdata.gen_spatial_elements(): + new_element_name = f"{name}-{suffix}" + if not merge_coordinate_systems_on_name: + # Set new transformations with suffixed coordinate system names + transformations = get_transformation(el, get_all=True) + assert isinstance(transformations, dict) + + remove_transformation(el, remove_all=True) + for cs, t in transformations.items(): + new_cs = f"{cs}-{suffix}" + set_transformation(el, t, to_coordinate_system=new_cs) + + elements[new_element_name] = el # Handle tables with suffix tables = {} @@ -265,25 +278,5 @@ def _fix_ensure_unique_element_names( # Create new SpatialData object with suffixed elements and tables sdata_fixed = SpatialData.init_from_elements(elements, tables=tables) - - # Handle coordinate systems and transformations - for element_name, element in elements.items(): - # Get the original element from the input sdata - original_name = element_name.replace(f"-{suffix}", "") - original_element = sdata.get(original_name) - - # Get transformations from original element - transformations = get_transformation(original_element, get_all=True) - if not isinstance(transformations, dict): - raise TypeError(f"Expected 'transformations' to be a dict, but got {type(transformations).__name__}.") - - # Remove any existing transformations from the new element - remove_transformation(element, remove_all=True) - - # Set new transformations with suffixed coordinate system names - for cs, t in transformations.items(): - new_cs = cs if merge_coordinate_systems_on_name else f"{cs}-{suffix}" - set_transformation(element, t, to_coordinate_system=new_cs) - sdatas_fixed.append(sdata_fixed) return sdatas_fixed diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index dd765029d..9124c60ec 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -471,6 +471,40 @@ def _n_elements(sdata: SpatialData) -> int: assert "blobs_image-sample" in c.images +@pytest.mark.parametrize("merge_coordinate_systems_on_name", [True, False]) +def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_on_name): + blob1 = blobs() + blob2 = blobs() + + sdata_keys = ["blob1", "blob2"] + sdata = concatenate( + dict(zip(sdata_keys, [blob1, blob2], strict=True)), + merge_coordinate_systems_on_name=merge_coordinate_systems_on_name, + ) + + if merge_coordinate_systems_on_name: + assert set(sdata.coordinate_systems) == {"global"} + else: + assert set(sdata.coordinate_systems) == {"global-blob1", "global-blob2"} + + # extra checks not specific to this test, we could remove them or leave them just + # in case + expected_images = ["blobs_image", "blobs_multiscale_image"] + expected_labels = ["blobs_labels", "blobs_multiscale_labels"] + expected_points = ["blobs_points"] + expected_shapes = ["blobs_circles", "blobs_polygons", "blobs_multipolygons"] + + expected_suffixed_images = [f"{name}-{key}" for key in sdata_keys for name in expected_images] + expected_suffixed_labels = [f"{name}-{key}" for key in sdata_keys for name in expected_labels] + expected_suffixed_points = [f"{name}-{key}" for key in sdata_keys for name in expected_points] + expected_suffixed_shapes = [f"{name}-{key}" for key in sdata_keys for name in expected_shapes] + + assert set(sdata.images.keys()) == set(expected_suffixed_images) + assert set(sdata.labels.keys()) == set(expected_suffixed_labels) + assert set(sdata.points.keys()) == set(expected_suffixed_points) + assert set(sdata.shapes.keys()) == set(expected_suffixed_shapes) + + def test_locate_spatial_element(full_sdata: SpatialData) -> None: assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d" im = full_sdata.images["image2d"] diff --git a/tests/core/test_concatenate.py b/tests/core/test_concatenate.py deleted file mode 100644 index 0eeb06feb..000000000 --- a/tests/core/test_concatenate.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -import spatialdata as sd -from spatialdata.datasets import blobs - - -@pytest.mark.parametrize("merge_coordinate_systems_on_name", [True, False]) -def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_on_name): - blob1 = blobs() - blob2 = blobs() - - sdata_keys = ["blob1", "blob2"] - sdata = sd.concatenate( - dict(zip(sdata_keys, [blob1, blob2], strict=True)), - merge_coordinate_systems_on_name=merge_coordinate_systems_on_name, - ) - - expected_images = ["blobs_image", "blobs_multiscale_image"] - expected_labels = ["blobs_labels", "blobs_multiscale_labels"] - expected_points = ["blobs_points"] - expected_shapes = ["blobs_circles", "blobs_polygons", "blobs_multipolygons"] - - expected_suffixed_images = [f"{name}-{key}" for key in sdata_keys for name in expected_images] - expected_suffixed_labels = [f"{name}-{key}" for key in sdata_keys for name in expected_labels] - expected_suffixed_points = [f"{name}-{key}" for key in sdata_keys for name in expected_points] - expected_suffixed_shapes = [f"{name}-{key}" for key in sdata_keys for name in expected_shapes] - - assert set(sdata.images.keys()) == set(expected_suffixed_images) - assert set(sdata.labels.keys()) == set(expected_suffixed_labels) - assert set(sdata.points.keys()) == set(expected_suffixed_points) - assert set(sdata.shapes.keys()) == set(expected_suffixed_shapes) - - if merge_coordinate_systems_on_name: - assert set(sdata.coordinate_systems) == {"global"} - else: - assert set(sdata.coordinate_systems) == {"global-blob1", "global-blob2"}