Skip to content

Commit bcf69d2

Browse files
Replacing fixed size tuples as return types (#1043)
* change return types and update tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix mypy problems * fix typing --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6621650 commit bcf69d2

File tree

6 files changed

+86
-33
lines changed

6 files changed

+86
-33
lines changed

src/squidpy/gr/__init__.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,31 @@
22

33
from __future__ import annotations
44

5-
from squidpy.gr._build import mask_graph, spatial_neighbors
5+
from squidpy.gr._build import SpatialNeighborsResult, mask_graph, spatial_neighbors
66
from squidpy.gr._ligrec import ligrec
7-
from squidpy.gr._nhood import centrality_scores, interaction_matrix, nhood_enrichment
7+
from squidpy.gr._nhood import (
8+
NhoodEnrichmentResult,
9+
centrality_scores,
10+
interaction_matrix,
11+
nhood_enrichment,
12+
)
813
from squidpy.gr._niche import calculate_niche
914
from squidpy.gr._ppatterns import co_occurrence, spatial_autocorr
1015
from squidpy.gr._ripley import ripley
1116
from squidpy.gr._sepal import sepal
17+
18+
__all__ = [
19+
"SpatialNeighborsResult",
20+
"NhoodEnrichmentResult",
21+
"mask_graph",
22+
"spatial_neighbors",
23+
"ligrec",
24+
"centrality_scores",
25+
"interaction_matrix",
26+
"nhood_enrichment",
27+
"calculate_niche",
28+
"co_occurrence",
29+
"spatial_autocorr",
30+
"ripley",
31+
"sepal",
32+
]

src/squidpy/gr/_build.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Iterable # noqa: F401
77
from functools import partial
88
from itertools import chain
9-
from typing import Any, cast
9+
from typing import Any, NamedTuple, cast
1010

1111
import geopandas as gpd
1212
import numpy as np
@@ -61,6 +61,21 @@
6161
__all__ = ["spatial_neighbors"]
6262

6363

64+
class SpatialNeighborsResult(NamedTuple):
65+
"""Result of spatial_neighbors function.
66+
67+
Attributes
68+
----------
69+
connectivities
70+
Spatial connectivities matrix.
71+
distances
72+
Spatial distances matrix.
73+
"""
74+
75+
connectivities: csr_matrix
76+
distances: csr_matrix
77+
78+
6479
@d.dedent
6580
@inject_docs(t=Transform, c=CoordType)
6681
def spatial_neighbors(
@@ -79,7 +94,7 @@ def spatial_neighbors(
7994
set_diag: bool = False,
8095
key_added: str = "spatial",
8196
copy: bool = False,
82-
) -> tuple[csr_matrix, csr_matrix] | None:
97+
) -> SpatialNeighborsResult | None:
8398
"""
8499
Create a graph from spatial coordinates.
85100
@@ -136,7 +151,7 @@ def spatial_neighbors(
136151
137152
Returns
138153
-------
139-
If ``copy = True``, returns a :class:`tuple` with the spatial connectivities and distances matrices.
154+
If ``copy = True``, returns a :class:`~squidpy.gr.SpatialNeighborsResult` with the spatial connectivities and distances matrices.
140155
141156
Otherwise, modifies the ``adata`` with the following keys:
142157
@@ -259,7 +274,7 @@ def spatial_neighbors(
259274
}
260275

261276
if copy:
262-
return Adj, Dst
277+
return SpatialNeighborsResult(connectivities=Adj, distances=Dst)
263278

264279
_save_data(adata, attr="obsp", key=conns_key, data=Adj)
265280
_save_data(adata, attr="obsp", key=dists_key, data=Dst, prefix=False)

src/squidpy/gr/_nhood.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
from collections.abc import Callable, Iterable, Sequence
66
from functools import partial
7-
from typing import Any
7+
from typing import Any, NamedTuple
88

99
import networkx as nx
1010
import numba.types as nt
1111
import numpy as np
1212
import pandas as pd
1313
from anndata import AnnData
1414
from numba import njit, prange # noqa: F401
15+
from numpy.typing import NDArray
1516
from pandas import CategoricalDtype
1617
from scanpy import logging as logg
1718
from spatialdata import SpatialData
@@ -30,7 +31,24 @@
3031

3132
__all__ = ["nhood_enrichment", "centrality_scores", "interaction_matrix"]
3233

33-
dt = nt.uint32 # data type aliases (both for numpy and numba should match)
34+
35+
class NhoodEnrichmentResult(NamedTuple):
36+
"""Result of nhood_enrichment function.
37+
38+
Attributes
39+
----------
40+
zscore : NDArray[np.number]
41+
Z-score values of enrichment statistic.
42+
count : NDArray[np.number]
43+
Enrichment count.
44+
"""
45+
46+
zscore: NDArray[np.number]
47+
counts: NDArray[np.number] # NamedTuple inherits from tuple so cannot use 'count' as attribute name
48+
49+
50+
# data type aliases (both for numpy and numba should match)
51+
dt = nt.uint32
3452
ndt = np.uint32
3553
_template = """
3654
@njit(dt[:, :](dt[:], dt[:], dt[:]), parallel={parallel}, fastmath=True)
@@ -131,7 +149,7 @@ def nhood_enrichment(
131149
n_jobs: int | None = None,
132150
backend: str = "loky",
133151
show_progress_bar: bool = True,
134-
) -> tuple[NDArrayA, NDArrayA] | None:
152+
) -> NhoodEnrichmentResult | None:
135153
"""
136154
Compute neighborhood enrichment by permutation test.
137155
@@ -149,7 +167,7 @@ def nhood_enrichment(
149167
150168
Returns
151169
-------
152-
If ``copy = True``, returns a :class:`tuple` with the z-score and the enrichment count.
170+
If ``copy = True``, returns a :class:`~squidpy.gr.NhoodEnrichmentResult` with the z-score and the enrichment count.
153171
154172
Otherwise, modifies the ``adata`` with the following keys:
155173
@@ -202,7 +220,7 @@ def nhood_enrichment(
202220
zscore = (count - perms.mean(axis=0)) / perms.std(axis=0)
203221

204222
if copy:
205-
return zscore, count
223+
return NhoodEnrichmentResult(zscore=zscore, counts=count)
206224

207225
_save_data(
208226
adata,

src/squidpy/im/_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _assert_dims_present(dims: tuple[str, ...], include_z: bool = True) -> None:
2525

2626
# modification of `skimage`'s `pil_to_ndarray`:
2727
# https://github.com/scikit-image/scikit-image/blob/main/skimage/io/_plugins/pil_plugin.py#L55
28-
def _infer_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype]:
28+
def _infer_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype[np.generic]]:
2929
def _palette_is_grayscale(pil_image: Image.Image) -> bool:
3030
# get palette as an array with R, G, B columns
3131
palette = np.asarray(pil_image.getpalette()).reshape((256, 3))
@@ -81,7 +81,7 @@ def _palette_is_grayscale(pil_image: Image.Image) -> bool:
8181
raise ValueError(f"Unable to infer image dtype for image mode `{image.mode}`.")
8282

8383

84-
def _get_image_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype]:
84+
def _get_image_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype[np.generic]]:
8585
try:
8686
return _infer_shape_dtype(fname)
8787
except Image.UnidentifiedImageError as e:
@@ -101,7 +101,7 @@ def _get_image_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype]:
101101
def _infer_dimensions(
102102
obj: NDArrayA | xr.DataArray | str,
103103
infer_dimensions: InferDimensions | tuple[str, ...] = InferDimensions.DEFAULT,
104-
) -> tuple[tuple[int, ...], tuple[str, ...], np.dtype, tuple[int, ...]]:
104+
) -> tuple[tuple[int, ...], tuple[str, ...], np.dtype[np.generic], tuple[int, ...]]:
105105
"""
106106
Infer dimension names of an array.
107107

tests/graph/test_nhood.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ def test_reproducibility(self, adata: AnnData, n_jobs: int):
4949
assert len(res1) == len(res2)
5050
assert len(res2) == len(res3)
5151

52-
for key in range(len(res1)):
53-
np.testing.assert_array_equal(res2[key], res1[key])
54-
if key == 0: # z-score
55-
with pytest.raises(AssertionError):
56-
np.testing.assert_array_equal(res3[key], res2[key])
57-
else: # counts
58-
np.testing.assert_array_equal(res3[key], res2[key])
52+
# Test that the same seed produces the same results
53+
np.testing.assert_array_equal(res2.zscore, res1.zscore)
54+
np.testing.assert_array_equal(res2.counts, res1.counts)
55+
56+
# Test that different seeds produce different z-scores but same counts
57+
with pytest.raises(AssertionError):
58+
np.testing.assert_array_equal(res3.zscore, res2.zscore)
59+
np.testing.assert_array_equal(res3.counts, res2.counts)
5960

6061

6162
def test_centrality_scores(nhood_data: AnnData):

tests/graph/test_spatial_neighbors.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,25 +198,23 @@ def test_radius_min_max(self, non_visium_adata: AnnData, radius: tuple[float, fl
198198
np.testing.assert_allclose(spatial_dist, gt_ddist)
199199

200200
def test_copy(self, non_visium_adata: AnnData):
201-
conn, dist = spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, copy=True)
201+
result = spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, copy=True)
202202

203-
assert isspmatrix_csr(conn)
204-
assert isspmatrix_csr(dist)
203+
assert isspmatrix_csr(result.connectivities)
204+
assert isspmatrix_csr(result.distances)
205205
assert Key.obsp.spatial_conn() not in non_visium_adata.obsp
206206
assert Key.obsp.spatial_dist() not in non_visium_adata.obsp
207-
np.testing.assert_allclose(dist.toarray(), self._gt_ddist)
208-
np.testing.assert_allclose(conn.toarray(), self._gt_dgraph)
207+
np.testing.assert_allclose(result.distances.toarray(), self._gt_ddist)
208+
np.testing.assert_allclose(result.connectivities.toarray(), self._gt_dgraph)
209209

210210
@pytest.mark.parametrize("percentile", [99.0, 95.0])
211211
def test_percentile_filtering(self, adata_hne: AnnData, percentile: float, coord_type="generic"):
212-
conn, dist = spatial_neighbors(adata_hne, coord_type=coord_type, copy=True)
213-
conn_filtered, dist_filtered = spatial_neighbors(
214-
adata_hne, coord_type=coord_type, percentile=percentile, copy=True
215-
)
212+
result = spatial_neighbors(adata_hne, coord_type=coord_type, copy=True)
213+
result_filtered = spatial_neighbors(adata_hne, coord_type=coord_type, percentile=percentile, copy=True)
216214

217215
# check whether there are less connectivities in the filtered graph and whether the max distance is smaller
218-
assert not ((conn != conn_filtered).nnz == 0)
219-
assert dist.max() > dist_filtered.max()
216+
assert not ((result.connectivities != result_filtered.connectivities).nnz == 0)
217+
assert result.distances.max() > result_filtered.distances.max()
220218

221219
Adj, Dst = _build_connectivity(adata_hne.obsm["spatial"], n_neighs=6, return_distance=True, set_diag=False)
222220
threshold = np.percentile(Dst.data, percentile)
@@ -225,7 +223,7 @@ def test_percentile_filtering(self, adata_hne: AnnData, percentile: float, coord
225223
Adj.eliminate_zeros()
226224
Dst.eliminate_zeros()
227225

228-
assert dist_filtered.max() == Dst.max()
226+
assert result_filtered.distances.max() == Dst.max()
229227

230228
@pytest.mark.parametrize("n_neighs", [5, 10, 20])
231229
def test_spatial_neighbors_generic(self, n_neighs: int):

0 commit comments

Comments
 (0)