Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7e67cc1
fix
selmanozleyen Aug 24, 2025
dbdda0c
attempt to add tests and add better documentation to the functions to…
selmanozleyen Aug 29, 2025
a3366e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2025
b3ddcb7
add vibe coded tests prune later
selmanozleyen Aug 29, 2025
2fffcfd
add tests for spectral transform
selmanozleyen Sep 1, 2025
eeab937
Merge branch 'fix/spectral-transform' of https://github.com/scverse/s…
selmanozleyen Sep 4, 2025
2ce6e85
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2025
651505a
complete the merge conflict issue
selmanozleyen Sep 4, 2025
303d5c4
add fau and clean up code
selmanozleyen Sep 4, 2025
8c881be
make loops parallel
selmanozleyen Sep 4, 2025
bcb5989
Merge branch 'fix/spectral-transform' of https://github.com/scverse/s…
selmanozleyen Sep 4, 2025
172080d
specify fast arrayutils dep
selmanozleyen Sep 4, 2025
806273a
Merge branch 'main' into fix/spectral-transform
selmanozleyen Sep 4, 2025
a3db2d0
Merge branch 'main' into fix/spectral-transform
flying-sheep Sep 4, 2025
ea457c3
Merge branch 'main' into fix/spectral-transform
selmanozleyen Oct 1, 2025
6e60c6c
Merge branch 'main' into fix/spectral-transform
selmanozleyen Nov 3, 2025
22acdad
Add fast_array_utils to project dependencies
selmanozleyen Nov 3, 2025
017a763
Merge branch 'main' into fix/spectral-transform
selmanozleyen Nov 7, 2025
126fd38
cache kernel
selmanozleyen Nov 7, 2025
a24419a
forgot to save file bf commit
selmanozleyen Nov 7, 2025
4eae80c
Merge branch 'main' into fix/spectral-transform
selmanozleyen Nov 7, 2025
4f6010b
Apply suggestion from @flying-sheep
flying-sheep Nov 7, 2025
8f178cc
remove unused imports (idk why this didn't fail on linter)
selmanozleyen Nov 7, 2025
c7c8f83
remove more unused imports
selmanozleyen Nov 7, 2025
3d82ed5
Merge branch 'main' into fix/spectral-transform
selmanozleyen Nov 10, 2025
6680ec2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2025
0086dac
move to float32
selmanozleyen Nov 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ dependencies = [
"xarray>=2024.10.0",
"zarr>=2.6.1,<3.0.0",
"spatialdata>=0.2.5",
"fast_array_utils",
]

[project.optional-dependencies]
Expand Down
66 changes: 52 additions & 14 deletions src/squidpy/gr/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import pandas as pd
from anndata import AnnData
from anndata.utils import make_index_unique
from geopandas import GeoDataFrame
from numba import njit
from fast_array_utils import stats as fau_stats
from numba import njit, prange
from scanpy import logging as logg
from scipy.sparse import (
SparseEfficiencyWarning,
Expand Down Expand Up @@ -453,29 +453,67 @@ def _build_connectivity(
return Adj


@njit
def outer(indices: NDArrayA, indptr: NDArrayA, degrees: NDArrayA) -> NDArrayA:
res = np.empty_like(indices, dtype=np.float64)
start = 0
for i in range(len(indptr) - 1):
@njit(parallel=True)
def _csr_bilateral_diag_scale_helper(
data: NDArrayA, indices: NDArrayA, indptr: NDArrayA, degrees: NDArrayA
) -> NDArrayA:
"""
Return an array F aligned with CSR non-zeros such that
F[k] = d[i] * data[k] * d[j] for the k-th non-zero (i, j) in CSR order.

Parameters
----------

data : array of float
CSR `data` (non-zero values).
indices : array of int
CSR `indices` (column indices).
indptr : array of int
CSR `indptr` (row pointer).
degrees : array of float, shape (n,)
Diagonal scaling vector.

Returns
-------
array of float
Length equals len(data). Entry-wise factors d_i * d_j * data[k]
"""

res = np.empty_like(data, dtype=np.float64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why np.float64?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because original function used np.float64. Didn't want to change old behaviour

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we check if this is needed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well I can't tell tbh because there was no tests for the spectral transform lol. We can just use float32 I guess. But the better way would be to ask ths users of spatial_neighbours

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Intron7 ok the results seem similar when I checked them. I also added some numerical tests to ensure stability. Also when I recalled that this was just a normalization of weights in a graph so it should not be very sensitive to precision in theory.

for i in prange(len(indptr) - 1):
ixs = indices[indptr[i] : indptr[i + 1]]
res[start : start + len(ixs)] = degrees[i] * degrees[ixs]
start += len(ixs)
res[indptr[i] : indptr[i + 1]] = degrees[i] * degrees[ixs] * data[indptr[i] : indptr[i + 1]]

return res


def symmetric_normalize_csr(adj: spmatrix) -> csr_matrix:
"""
Return D^{-1/2} * A * D^{-1/2}, where D = diag(degrees(A)) and A = adj.


Parameters
----------
adj : scipy.sparse.csr_matrix

Returns
-------
scipy.sparse.csr_matrix
"""
degrees = np.squeeze(np.array(np.sqrt(1.0 / fau_stats.sum(adj, axis=0))))
if adj.shape[0] != len(degrees):
raise ValueError("len(degrees) must equal number of rows of adj")
res_data = _csr_bilateral_diag_scale_helper(adj.data, adj.indices, adj.indptr, degrees)
return csr_matrix((res_data, adj.indices, adj.indptr), shape=adj.shape)


def _transform_a_spectral(a: spmatrix) -> spmatrix:
if not isspmatrix_csr(a):
a = a.tocsr()
if not a.nnz:
return a

degrees = np.squeeze(np.array(np.sqrt(1.0 / a.sum(axis=0))))
a = a.multiply(outer(a.indices, a.indptr, degrees))
a.eliminate_zeros()

return a
return symmetric_normalize_csr(a)


def _transform_a_cosine(a: spmatrix) -> spmatrix:
Expand Down
45 changes: 45 additions & 0 deletions tests/graph/test_spatial_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,48 @@ def test_mask_graph(
negative_mask=True,
key_added=key_added,
)

def test_spatial_neighbors_transform_mathematical_properties(self, non_visium_adata: AnnData):
"""
Test mathematical properties of each transform.
"""
# Test spectral transform properties
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform="spectral")
adj_spectral = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()

# Spectral transform should be symmetric
np.testing.assert_allclose(adj_spectral, adj_spectral.T, atol=1e-10)

# Spectral transform should have normalized rows (L2 norm <= 1)
row_norms = np.sqrt(np.sum(adj_spectral**2, axis=1))
np.testing.assert_array_less(row_norms, 1.0 + 1e-10)

# Test cosine transform properties
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform="cosine")
adj_cosine = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()

# Cosine transform should be symmetric
np.testing.assert_allclose(adj_cosine, adj_cosine.T, atol=1e-10)

# Cosine transform should have values in [-1, 1]
np.testing.assert_array_less(-1.0 - 1e-10, adj_cosine)
np.testing.assert_array_less(adj_cosine, 1.0 + 1e-10)

# Diagonal of cosine transform should be 1 (self-similarity)
np.testing.assert_allclose(np.diag(adj_cosine), 1.0, atol=1e-10)

def test_spatial_neighbors_transform_edge_cases(self, non_visium_adata: AnnData):
"""
Test transforms with edge cases (empty graph, single node, etc.).
"""
# Test with a very small dataset
small_adata = non_visium_adata[:5].copy() # Only 5 points

# Test all transforms with small dataset
for transform in [None, "spectral", "cosine"]:
spatial_neighbors(small_adata, delaunay=True, coord_type=None, transform=transform)
assert Key.obsp.spatial_conn() in small_adata.obsp
assert Key.obsp.spatial_dist() in small_adata.obsp

# Verify transform parameter is saved
assert small_adata.uns[Key.uns.spatial_neighs()]["params"]["transform"] == transform