Skip to content

Commit 28eacc9

Browse files
Merge pull request #365 from scverse/performance/xenium-shapes-parsing
Faster xenium polygon parsing via ragged arrays
2 parents 1431911 + dc8e297 commit 28eacc9

File tree

2 files changed

+33
-25
lines changed

2 files changed

+33
-25
lines changed

src/spatialdata_io/__main__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,6 @@ def visium_hd_wrapper(
480480
default=True,
481481
help="Whether to read cells annotations in the AnnData table. [default: True]",
482482
)
483-
@click.option("--n-jobs", type=int, default=1, help="Number of jobs. [default: 1]")
484483
def xenium_wrapper(
485484
input: str,
486485
output: str,
@@ -495,7 +494,6 @@ def xenium_wrapper(
495494
morphology_focus: bool = True,
496495
aligned_images: bool = True,
497496
cells_table: bool = True,
498-
n_jobs: int = 1,
499497
) -> None:
500498
"""Xenium conversion to SpatialData."""
501499
sdata = xenium( # type: ignore[name-defined] # noqa: F821
@@ -510,7 +508,6 @@ def xenium_wrapper(
510508
morphology_focus=morphology_focus,
511509
aligned_images=aligned_images,
512510
cells_table=cells_table,
513-
n_jobs=n_jobs,
514511
)
515512
sdata.write(output)
516513

src/spatialdata_io/readers/xenium.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from geopandas import GeoDataFrame
2323
from joblib import Parallel, delayed
2424
from pyarrow import Table
25-
from shapely import Polygon
25+
from shapely import GeometryType, Polygon, from_ragged_array
2626
from spatialdata import SpatialData
2727
from spatialdata._core.query.relational_query import get_element_instances
2828
from spatialdata.models import (
@@ -69,7 +69,6 @@ def xenium(
6969
morphology_focus: bool = True,
7070
aligned_images: bool = True,
7171
cells_table: bool = True,
72-
n_jobs: int = 1,
7372
gex_only: bool = True,
7473
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
7574
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
@@ -122,8 +121,6 @@ def xenium(
122121
`False` and use the `xenium_aligned_image` function directly.
123122
cells_table
124123
Whether to read the cell annotations in the `AnnData` table.
125-
n_jobs
126-
Number of jobs to use for parallel processing.
127124
gex_only
128125
Whether to load only the "Gene Expression" feature type.
129126
imread_kwargs
@@ -261,7 +258,6 @@ def xenium(
261258
path,
262259
XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
263260
specs,
264-
n_jobs,
265261
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
266262
)
267263

@@ -270,7 +266,6 @@ def xenium(
270266
path,
271267
XeniumKeys.CELL_BOUNDARIES_FILE,
272268
specs,
273-
n_jobs,
274269
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
275270
)
276271

@@ -406,47 +401,63 @@ def filter(self, record: logging.LogRecord) -> bool:
406401

407402
def _decode_cell_id_column(cell_id_column: pd.Series) -> pd.Series:
408403
if isinstance(cell_id_column.iloc[0], bytes):
409-
return cell_id_column.apply(lambda x: x.decode("utf-8"))
404+
return cell_id_column.str.decode("utf-8")
410405
return cell_id_column
411406

412407

413408
def _get_polygons(
    path: Path,
    file: str,
    specs: dict[str, Any],
    idx: ArrayLike | None = None,
) -> GeoDataFrame:
    """Read boundary vertices from a Xenium parquet file and build one polygon per cell.

    Parameters
    ----------
    path
        Path to the Xenium output directory.
    file
        Name of the boundaries parquet file (cell or nucleus boundaries).
    specs
        Experiment specs; ``specs["pixel_size"]`` converts microns to pixel space.
    idx
        Cell ids from the AnnData table; required for data produced by a Xenium
        Analyzer older than 2.0.0, where it is asserted to match and used as index.

    Returns
    -------
    Shapes element (GeoDataFrame) parsed by :class:`ShapesModel`, with a "global"
    scale transformation of ``1 / pixel_size`` on both axes.

    Raises
    ------
    ValueError
        If rows belonging to the same cell id are not contiguous in the file.
    """
    # seems to be faster than pd.read_parquet
    df = pq.read_table(path / file).to_pandas()
    cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
    x = df[XeniumKeys.BOUNDARIES_VERTEX_X].to_numpy()
    y = df[XeniumKeys.BOUNDARIES_VERTEX_Y].to_numpy()
    coords = np.column_stack([x, y])

    # run-length group boundaries: each run of identical consecutive cell ids
    # is the vertex list of one polygon
    change_mask = np.concatenate([[True], cell_ids[1:] != cell_ids[:-1]])
    group_starts = np.where(change_mask)[0]
    group_ends = np.concatenate([group_starts[1:], [len(cell_ids)]])

    # sanity check: if rows of one cell were scattered, the run-length grouping
    # above would produce more groups than unique ids
    n_unique_ids = len(df[XeniumKeys.CELL_ID].drop_duplicates())
    if len(group_starts) != n_unique_ids:
        raise ValueError(
            f"In {file}, rows belonging to the same polygon must be contiguous. "
            f"Expected {n_unique_ids} group starts, but found {len(group_starts)}. "
            f"This indicates non-consecutive polygon rows."
        )

    unique_ids = cell_ids[group_starts]

    # offsets for ragged array:
    # offsets[0] (ring_offsets): describing to which rings the vertex positions belong to
    # offsets[1] (geom_offsets): describing to which polygons the rings belong to
    ring_offsets = np.concatenate([[0], group_ends])  # vertex positions
    geom_offsets = np.arange(len(group_starts) + 1)  # [0, 1, 2, ..., n_polygons]

    # build all shapely polygons in one vectorized call (no per-polygon Python loop)
    geoms = from_ragged_array(GeometryType.POLYGON, coords, offsets=(ring_offsets, geom_offsets))

    # decode ids to str since they are compared with an AnnData index (str)
    index = _decode_cell_id_column(pd.Series(unique_ids))
    geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)

    version = _parse_version_of_xenium_analyzer(specs)
    if version is not None and version < packaging.version.parse("2.0.0"):
        # old Analyzer output: the parquet-derived index must match the table's ids
        assert idx is not None
        assert len(idx) == len(geo_df)
        assert index.equals(idx)
        geo_df.index = idx
    else:
        if np.unique(geo_df.index).size != len(geo_df):
            warnings.warn(
                "Found non-unique polygon indices, this will be addressed in a future version of the reader. For the "
                "time being please consider merging polygons with non-unique indices into single multi-polygons.",
                UserWarning,
                stacklevel=2,
            )

    scale = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
    return ShapesModel.parse(geo_df, transformations={"global": scale})

0 commit comments

Comments (0)