Skip to content

Commit fd6222d

Browse files
Fix bug custom region_key concatenate (#871)
fix bug custom region_key concatenate
1 parent c206323 commit fd6222d

File tree

2 files changed

+114
-22
lines changed

2 files changed

+114
-22
lines changed

src/spatialdata/_core/concatenate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _concatenate_tables(
6565

6666
merged_table = ad.concat(tables_l, **kwargs)
6767
attrs = {
68-
TableModel.REGION_KEY: merged_table.obs[TableModel.REGION_KEY].unique().tolist(),
68+
TableModel.REGION_KEY: merged_table.obs[region_key].unique().tolist(),
6969
TableModel.REGION_KEY_KEY: region_key,
7070
TableModel.INSTANCE_KEY: instance_key,
7171
}

tests/core/operations/test_spatialdata_operations.py

Lines changed: 113 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,14 @@
2020
TableModel,
2121
get_table_keys,
2222
)
23-
from spatialdata.testing import assert_elements_dict_are_identical, assert_spatial_data_objects_are_identical
24-
from spatialdata.transformations.operations import get_transformation, set_transformation
23+
from spatialdata.testing import (
24+
assert_elements_dict_are_identical,
25+
assert_spatial_data_objects_are_identical,
26+
)
27+
from spatialdata.transformations.operations import (
28+
get_transformation,
29+
set_transformation,
30+
)
2531
from spatialdata.transformations.transformations import (
2632
Affine,
2733
BaseTransformation,
@@ -30,7 +36,7 @@
3036
Sequence,
3137
Translation,
3238
)
33-
from tests.conftest import _get_table
39+
from tests.conftest import _get_shapes, _get_table
3440

3541

3642
def test_element_names_unique() -> None:
@@ -183,7 +189,10 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None
183189
del adata.uns[TableModel.ATTRS_KEY]
184190
del full_sdata.tables["table"]
185191
full_sdata.table = TableModel.parse(
186-
adata, region=["circles", "poly"], region_key="annotated_shapes", instance_key="instance_id"
192+
adata,
193+
region=["circles", "poly"],
194+
region_key="annotated_shapes",
195+
instance_key="instance_id",
187196
)
188197

189198
scale = Scale([2.0], axes=("x",))
@@ -201,11 +210,19 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None
201210
def test_rename_coordinate_systems(full_sdata: SpatialData) -> None:
202211
# all the elements point to global, add new coordinate systems
203212
set_transformation(
204-
element=full_sdata.shapes["circles"], transformation=Identity(), to_coordinate_system="my_space0"
213+
element=full_sdata.shapes["circles"],
214+
transformation=Identity(),
215+
to_coordinate_system="my_space0",
216+
)
217+
set_transformation(
218+
element=full_sdata.shapes["poly"],
219+
transformation=Identity(),
220+
to_coordinate_system="my_space1",
205221
)
206-
set_transformation(element=full_sdata.shapes["poly"], transformation=Identity(), to_coordinate_system="my_space1")
207222
set_transformation(
208-
element=full_sdata.shapes["multipoly"], transformation=Identity(), to_coordinate_system="my_space2"
223+
element=full_sdata.shapes["multipoly"],
224+
transformation=Identity(),
225+
to_coordinate_system="my_space2",
209226
)
210227

211228
elements_in_global_before = {
@@ -234,7 +251,11 @@ def test_rename_coordinate_systems(full_sdata: SpatialData) -> None:
234251
# renaming, as it doesn't exist at the time of the function call)
235252
with pytest.raises(ValueError):
236253
full_sdata.rename_coordinate_systems(
237-
{"my_space00": "my_space3", "my_space11": "my_space3", "my_space3": "my_space4"}
254+
{
255+
"my_space00": "my_space3",
256+
"my_space11": "my_space3",
257+
"my_space3": "my_space4",
258+
}
238259
)
239260

240261
# valid renaming with collisions
@@ -276,33 +297,79 @@ def test_concatenate_tables() -> None:
276297
"instance_key": "instance_id",
277298
}
278299

279-
table3 = _get_table(region="shapes/circles", region_key="annotated_shapes_other", instance_key="instance_id")
300+
table3 = _get_table(
301+
region="shapes/circles",
302+
region_key="annotated_shapes_other",
303+
instance_key="instance_id",
304+
)
280305
with pytest.raises(ValueError):
281306
_concatenate_tables([table0, table3], region_key="region")
282307

283308
table4 = _get_table(
284-
region=["shapes/circles1", "shapes/poly1"], region_key="annotated_shape0", instance_key="instance_id"
309+
region=["shapes/circles1", "shapes/poly1"],
310+
region_key="annotated_shape0",
311+
instance_key="instance_id",
285312
)
286313
table5 = _get_table(
287-
region=["shapes/circles2", "shapes/poly2"], region_key="annotated_shape0", instance_key="instance_id"
314+
region=["shapes/circles2", "shapes/poly2"],
315+
region_key="annotated_shape0",
316+
instance_key="instance_id",
288317
)
289318
table6 = _get_table(
290-
region=["shapes/circles3", "shapes/poly3"], region_key="annotated_shape1", instance_key="instance_id"
319+
region=["shapes/circles3", "shapes/poly3"],
320+
region_key="annotated_shape1",
321+
instance_key="instance_id",
291322
)
292-
with pytest.raises(ValueError, match="`region_key` must be specified if tables have different region keys"):
323+
with pytest.raises(
324+
ValueError,
325+
match="`region_key` must be specified if tables have different region keys",
326+
):
293327
_concatenate_tables([table4, table5, table6])
294328
assert len(_concatenate_tables([table4, table5, table6], region_key="region")) == len(table4) + len(table5) + len(
295329
table6
296330
)
297331

298332

333+
def test_concatenate_custom_table_metadata() -> None:
334+
# test for https://github.com/scverse/spatialdata/issues/349
335+
shapes0 = _get_shapes()
336+
shapes1 = _get_shapes()
337+
n = len(shapes0["poly"])
338+
table0 = TableModel.parse(
339+
AnnData(obs={"my_region": ["poly0"] * n, "my_instance_id": list(range(n))}),
340+
region="poly0",
341+
region_key="my_region",
342+
instance_key="my_instance_id",
343+
)
344+
table1 = TableModel.parse(
345+
AnnData(obs={"my_region": ["poly1"] * n, "my_instance_id": list(range(n))}),
346+
region="poly1",
347+
region_key="my_region",
348+
instance_key="my_instance_id",
349+
)
350+
sdata0 = SpatialData.init_from_elements({"poly0": shapes0["poly"], "table": table0})
351+
sdata1 = SpatialData.init_from_elements({"poly1": shapes1["poly"], "table": table1})
352+
sdata = concatenate([sdata0, sdata1], concatenate_tables=True)
353+
assert len(sdata["table"]) == 2 * n
354+
355+
299356
def test_concatenate_sdatas(full_sdata: SpatialData) -> None:
300357
with pytest.raises(KeyError):
301358
concatenate([full_sdata, SpatialData(images={"image2d": full_sdata.images["image2d"]})])
302359
with pytest.raises(KeyError):
303-
concatenate([full_sdata, SpatialData(labels={"labels2d": full_sdata.labels["labels2d"]})])
360+
concatenate(
361+
[
362+
full_sdata,
363+
SpatialData(labels={"labels2d": full_sdata.labels["labels2d"]}),
364+
]
365+
)
304366
with pytest.raises(KeyError):
305-
concatenate([full_sdata, SpatialData(points={"points_0": full_sdata.points["points_0"]})])
367+
concatenate(
368+
[
369+
full_sdata,
370+
SpatialData(points={"points_0": full_sdata.points["points_0"]}),
371+
]
372+
)
306373
with pytest.raises(KeyError):
307374
concatenate([full_sdata, SpatialData(shapes={"circles": full_sdata.shapes["circles"]})])
308375

@@ -335,9 +402,15 @@ def test_concatenate_sdatas_from_iterable(concatenate_tables: bool, obs_names_ma
335402
sdatas = {"sample0": sdata0, "sample1": sdata1}
336403
with pytest.raises(KeyError, match="Images must have unique names across the SpatialData objects"):
337404
_ = concatenate(
338-
sdatas.values(), concatenate_tables=concatenate_tables, obs_names_make_unique=obs_names_make_unique
405+
sdatas.values(),
406+
concatenate_tables=concatenate_tables,
407+
obs_names_make_unique=obs_names_make_unique,
339408
)
340-
merged = concatenate(sdatas, obs_names_make_unique=obs_names_make_unique, concatenate_tables=concatenate_tables)
409+
merged = concatenate(
410+
sdatas,
411+
obs_names_make_unique=obs_names_make_unique,
412+
concatenate_tables=concatenate_tables,
413+
)
341414

342415
if concatenate_tables:
343416
assert len(merged.tables) == 1
@@ -449,10 +522,18 @@ def test_subset(full_sdata: SpatialData) -> None:
449522

450523
adata = AnnData(
451524
shape=(10, 0),
452-
obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]},
525+
obs={
526+
"region": ["circles"] * 5 + ["poly"] * 5,
527+
"instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
528+
},
453529
)
454530
del full_sdata.tables["table"]
455-
sdata_table = TableModel.parse(adata, region=["circles", "poly"], region_key="region", instance_key="instance_id")
531+
sdata_table = TableModel.parse(
532+
adata,
533+
region=["circles", "poly"],
534+
region_key="region",
535+
instance_key="instance_id",
536+
)
456537
full_sdata["table"] = sdata_table
457538
full_sdata.tables["second_table"] = sdata_table
458539
subset1 = full_sdata.subset(["poly", "second_table"])
@@ -494,7 +575,12 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning:
494575
points["z"] = points["x"]
495576
points = PointsModel.parse(points)
496577
full_sdata["points_0_3d"] = points
497-
sdata = transform_to_data_extent(full_sdata, "global", target_width=1000, maintain_positioning=maintain_positioning)
578+
sdata = transform_to_data_extent(
579+
full_sdata,
580+
"global",
581+
target_width=1000,
582+
maintain_positioning=maintain_positioning,
583+
)
498584

499585
first_a: ArrayLike | None = None
500586
for _, name, el in sdata.gen_spatial_elements():
@@ -505,7 +591,13 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning:
505591
first_a = a
506592
else:
507593
# we are not pixel perfect because of this bug: https://github.com/scverse/spatialdata/issues/165
508-
if maintain_positioning and name in ["points_0_3d", "points_0", "poly", "circles", "multipoly"]:
594+
if maintain_positioning and name in [
595+
"points_0_3d",
596+
"points_0",
597+
"poly",
598+
"circles",
599+
"multipoly",
600+
]:
509601
# Again, due to the "pixel perfect" bug, the 0.5 translation forth and back in the z axis that is added
510602
# by rasterize() (like the one in the example belows), amplifies the error also for x and y beyond the
511603
# rtol threshold below. So, let's skip that check and to an absolute check up to 0.5 (due to the

0 commit comments

Comments
 (0)