Skip to content

Commit aee1ca1

Browse files
committed
Allow including self in pre-computed distance matrices
1 parent 7b5c474 commit aee1ca1

File tree

6 files changed

+156
-43
lines changed

6 files changed

+156
-43
lines changed

src/cellmapper/cellmapper.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -539,34 +539,43 @@ def fit(
539539

540540
return self
541541

542-
def load_precomputed_distances(self, distances_key: str = "distances") -> None:
542+
def load_precomputed_distances(self, distances_key: str = "distances", include_self: bool | None = None) -> None:
543543
"""
544-
Load a pre-computed distance matrix from AnnData.obsp.
544+
Load precomputed distances from the AnnData object.
545+
546+
This method is only available in self-mapping mode.
545547
546548
Parameters
547549
----------
548550
distances_key
549-
Key in adata.obsp where the distance matrix is stored
551+
Key in adata.obsp where the precomputed distances are stored.
552+
include_self
553+
If True, include self as a neighbor (even if not present in the distance matrix).
554+
If False, exclude self connections (even if present in the distance matrix).
555+
If None (default), preserve the original behavior of the distance matrix.
550556
551557
Returns
552558
-------
553559
None
554560
555561
Notes
556562
-----
557-
This method can only be used in self-mapping mode (when CellMapper was
558-
initialized with query=None).
563+
Updates the following attributes:
564+
565+
- ``knn``: Neighbors object constructed from the precomputed distances.
559566
"""
560567
if not self._is_self_mapping:
561-
raise ValueError("Pre-computed distances can only be used in self-mapping mode")
568+
raise ValueError("load_precomputed_distances is only available in self-mapping mode.")
562569

563-
if distances_key not in self.query.obsp:
564-
raise KeyError(f"Distance matrix '{distances_key}' not found in query.obsp")
570+
# Access the precomputed distances
571+
distances_matrix = self.query.obsp[distances_key]
565572

566-
self.knn = Neighbors.from_distances(self.query.obsp[distances_key])
573+
# Create a neighbors object using the factory method
574+
self.knn = Neighbors.from_distances(distances_matrix, include_self=include_self)
567575

568576
logger.info(
569-
"Loaded pre-computed distance matrix from query.obsp['%s'] with %d cells",
577+
"Loaded precomputed distances from '%s' with %d cells and %d neighbors per cell.",
570578
distances_key,
571-
self.query.n_obs,
579+
distances_matrix.shape[0],
580+
self.knn.xx.n_neighbors,
572581
)

src/cellmapper/knn.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,22 +292,26 @@ def __init__(self, xrep: np.ndarray, yrep: np.ndarray | None = None):
292292
self._is_self_mapping = yrep is None
293293

294294
@classmethod
295-
def from_distances(cls, distances_matrix: "csr_matrix") -> "Neighbors":
295+
def from_distances(cls, distances_matrix: "csr_matrix", include_self: bool | None = None) -> "Neighbors":
296296
"""
297297
Create a Neighbors object from a pre-computed distances matrix.
298298
299299
Parameters
300300
----------
301301
distances_matrix
302302
Sparse distance matrix, typically from adata.obsp['distances']
303+
include_self
304+
If True, include self as a neighbor (cells are their own neighbors).
305+
If False, exclude self connections, even if present in the distance matrix.
306+
If None (default), preserve the original behavior of the distance matrix.
303307
304308
Returns
305309
-------
306310
Neighbors
307311
A new Neighbors object with pre-computed neighbor information
308312
"""
309313
# Extract indices and distances from the sparse matrix
310-
indices, distances = extract_neighbors_from_distances(distances_matrix)
314+
indices, distances = extract_neighbors_from_distances(distances_matrix, include_self=include_self)
311315

312316
# Create a minimal Neighbors object for self-mapping
313317
n_cells = distances_matrix.shape[0]
@@ -323,6 +327,9 @@ def from_distances(cls, distances_matrix: "csr_matrix") -> "Neighbors":
323327
neighbors.xy = neighbors_result
324328
neighbors.yx = neighbors_result
325329

330+
# Mark as self-mapping
331+
neighbors._is_self_mapping = True
332+
326333
logger.info("Created Neighbors object from distances matrix with %d cells", n_cells)
327334

328335
return neighbors

src/cellmapper/utils.py

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,20 @@ def create_imputed_anndata(
123123
return imputed_adata
124124

125125

126-
def extract_neighbors_from_distances(distances_matrix: "csr_matrix") -> tuple[np.ndarray, np.ndarray]:
126+
def extract_neighbors_from_distances(
127+
distances_matrix: "csr_matrix", include_self: bool | None = None
128+
) -> tuple[np.ndarray, np.ndarray]:
127129
"""
128130
Extract neighbor indices and distances from a sparse distance matrix.
129131
130132
Parameters
131133
----------
132134
distances_matrix
133135
Sparse matrix of distances, typically from adata.obsp['distances']
136+
include_self
137+
If True, include self as a neighbor (even if not present in the distance matrix).
138+
If False, exclude self connections (even if present in the distance matrix).
139+
If None (default), preserve the original behavior of the distance matrix.
134140
135141
Returns
136142
-------
@@ -139,44 +145,64 @@ def extract_neighbors_from_distances(distances_matrix: "csr_matrix") -> tuple[np
139145
"""
140146
# Check that the input is a sparse matrix
141147
if not issparse(distances_matrix):
142-
raise TypeError("Distances matrix must be sparse")
148+
raise TypeError("Distances matrix must be a sparse matrix")
149+
150+
# Verify that the matrix is square
151+
if distances_matrix.shape[0] != distances_matrix.shape[1]:
152+
raise ValueError(f"Square distance matrix required (got {distances_matrix.shape})")
143153

144154
n_cells = distances_matrix.shape[0]
145155

146-
# Get the number of neighbors per cell
147-
n_neighbors_per_cell = np.diff(distances_matrix.indptr)
148-
max_n_neighbors = n_neighbors_per_cell.max()
149-
min_n_neighbors = n_neighbors_per_cell.min()
150-
151-
# Check if all cells have the same number of neighbors
152-
if max_n_neighbors != min_n_neighbors:
153-
logger.warning(
154-
"Variable neighborhood sizes detected: min=%d, max=%d neighbors per cell. "
155-
"Some cells may have fewer neighbors than others, which could affect results.",
156-
min_n_neighbors,
157-
max_n_neighbors,
158-
)
156+
# Ensure the matrix is CSR format for efficient row-based operations
157+
distances_matrix = distances_matrix.tocsr()
158+
159+
# First pass: determine the max number of neighbors after including/excluding self
160+
max_n_neighbors = 0
161+
for i in range(n_cells):
162+
start, end = distances_matrix.indptr[i], distances_matrix.indptr[i + 1]
163+
cell_indices = distances_matrix.indices[start:end]
159164

160-
# Pre-allocate arrays for indices and distances
161-
# Use -1 as a sentinel value for missing neighbors (better than 0 which is a valid index)
165+
# Calculate how many neighbors this cell will have after applying include_self
166+
n_neighbors = len(cell_indices)
167+
if include_self is True and i not in cell_indices:
168+
n_neighbors += 1 # Will add self
169+
elif include_self is False and i in cell_indices:
170+
n_neighbors -= 1 # Will remove self
171+
172+
max_n_neighbors = max(max_n_neighbors, n_neighbors)
173+
174+
# Pre-allocate arrays for indices and distances with the correct size
162175
indices = np.full((n_cells, max_n_neighbors), -1, dtype=np.int64)
163176
distances = np.full((n_cells, max_n_neighbors), np.inf, dtype=np.float64)
164177

165-
# Extract indices and distances for each cell
178+
# Second pass: extract and process neighbor data
166179
for i in range(n_cells):
167180
# Get start and end indices for this cell in the sparse matrix
168181
start, end = distances_matrix.indptr[i], distances_matrix.indptr[i + 1]
169182

170-
# Number of neighbors for this cell
171-
n_neighbors = end - start
183+
# Get neighbor indices and distances
184+
cell_indices = distances_matrix.indices[start:end]
185+
cell_distances = distances_matrix.data[start:end]
186+
187+
# Filter self-connection if requested and present
188+
if include_self is False and i in cell_indices:
189+
# Find the index of self in the neighbors
190+
self_idx = np.where(cell_indices == i)[0]
191+
if len(self_idx) > 0:
192+
# Remove self from indices and distances
193+
mask = cell_indices != i
194+
cell_indices = cell_indices[mask]
195+
cell_distances = cell_distances[mask]
196+
# If include_self is True and self is not in the neighbors, add it
197+
elif include_self is True and i not in cell_indices:
198+
# Add self with distance 0
199+
cell_indices = np.append(cell_indices, i)
200+
cell_distances = np.append(cell_distances, 0.0)
201+
202+
# Number of neighbors after potential filtering
203+
n_neighbors = len(cell_indices)
172204

173205
if n_neighbors > 0:
174-
# Get neighbor indices
175-
cell_indices = distances_matrix.indices[start:end]
176-
177-
# Get distances
178-
cell_distances = distances_matrix.data[start:end]
179-
180206
# Sort by distance if they aren't already sorted
181207
if not np.all(np.diff(cell_distances) >= 0):
182208
sort_idx = np.argsort(cell_distances)

tests/test_cellmapper.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,45 @@ def test_load_squidpy_distances(self, adata_spatial, squidpy_params):
415415

416416
assert "leiden_pred" in cm.query.obs
417417
assert "leiden_conf" in cm.query.obs
418+
419+
@pytest.mark.parametrize("include_self", [True, False])
420+
def test_load_distances_with_include_self(self, adata_spatial):
421+
"""Test loading precomputed distances with and without self-connections."""
422+
423+
# Compute neighbors with scanpy
424+
sc.pp.neighbors(adata_spatial, n_neighbors=10, use_rep="X_pca")
425+
426+
# Initialize CellMapper in self-mapping mode
427+
cm_with_self = CellMapper(adata_spatial)
428+
cm_without_self = CellMapper(adata_spatial)
429+
430+
# Load precomputed distances with different include_self settings
431+
cm_with_self.load_precomputed_distances(distances_key="distances", include_self=True)
432+
cm_without_self.load_precomputed_distances(distances_key="distances", include_self=False)
433+
434+
# Verify that neighbors were loaded with or without self
435+
assert cm_with_self.knn is not None
436+
assert cm_without_self.knn is not None
437+
438+
# Check that with include_self=True, each cell has itself as a neighbor
439+
for i in range(min(10, cm_with_self.knn.xx.n_samples)): # Check first 10 cells
440+
assert i in cm_with_self.knn.xx.indices[i]
441+
442+
# Check that with include_self=False, no cell has itself as a neighbor
443+
for i in range(min(10, cm_without_self.knn.xx.n_samples)): # Check first 10 cells
444+
assert i not in cm_without_self.knn.xx.indices[i]
445+
446+
# Both should work with the rest of the pipeline
447+
cm_with_self.compute_mappping_matrix(method="gaussian")
448+
cm_without_self.compute_mappping_matrix(method="gaussian")
449+
450+
# Compute label transfer for both
451+
cm_with_self.transfer_labels(obs_keys="leiden", prediction_postfix="with_self")
452+
cm_without_self.transfer_labels(obs_keys="leiden", prediction_postfix="without_self")
453+
454+
# Both should have created prediction columns
455+
assert "leiden_with_self" in adata_spatial.obs
456+
assert "leiden_without_self" in adata_spatial.obs
457+
458+
# The results should be different (excluding self changes the neighborhood)
459+
assert not adata_spatial.obs["leiden_with_self"].equals(adata_spatial.obs["leiden_without_self"])

tests/test_neighbors.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def test_from_distances_factory_method(self, small_data):
7676
# Test with the neighbors results
7777
nn_results = neighbors.xx
7878
assert nn_results.n_samples == n_samples
79-
assert nn_results.n_neighbors == n_samples
79+
assert (
80+
nn_results.n_neighbors + 1 == n_samples
81+
) # the distance to self is exactly 0, so it won't be included as a neighbor
8082

8183
# Get adjacency matrix and verify it reflects the original distances
8284
adj_matrix = nn_results.knn_graph_distances
@@ -154,9 +156,10 @@ def test_from_distances_different_kernels(self, kernel):
154156
distances_data[i, j] = abs(i - j) # Simple metric: difference in indices
155157

156158
distances = csr_matrix(distances_data)
159+
print(distances)
157160

158161
# Create Neighbors object
159-
neighbors = Neighbors.from_distances(distances)
162+
neighbors = Neighbors.from_distances(distances, include_self=True)
160163

161164
# Compute connectivities with different kernels
162165
connectivities = neighbors.xx.knn_graph_connectivities(kernel=kernel)
@@ -167,8 +170,7 @@ def test_from_distances_different_kernels(self, kernel):
167170
assert np.all(connectivities.data > 0)
168171
# Diagonal should be large values (except for random kernel)
169172
if kernel != "random":
170-
diag_indices = list(range(n_samples))
171-
diag_values = connectivities[diag_indices, diag_indices].toarray().flatten()
173+
diag_values = connectivities.diagonal()
172174
# Diagonal elements should typically be the largest for each row
173175
for i in range(n_samples):
174176
row = connectivities.getrow(i).toarray().flatten()

tests/test_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,30 @@ def test_infinite_values(self):
8686
# Check that infinite values are excluded
8787
for i in range(4):
8888
assert np.all(np.isfinite(nbr_distances[i]))
89+
90+
def test_include_self_parameter(self):
91+
"""Test the include_self parameter to control self-connections."""
92+
# Create a distance matrix with self-connections (diagonal = 0)
93+
distances_data = np.array(
94+
[[0.0, 1.0, 2.0, 3.0], [1.0, 0.0, 4.0, 5.0], [2.0, 4.0, 0.0, 6.0], [3.0, 5.0, 6.0, 0.0]]
95+
)
96+
distances = csr_matrix(distances_data)
97+
98+
# Test with include_self=True (default)
99+
indices_with_self, distances_with_self = extract_neighbors_from_distances(distances, include_self=True)
100+
101+
# Test with include_self=False
102+
indices_without_self, distances_without_self = extract_neighbors_from_distances(distances, include_self=False)
103+
104+
# With include_self=True, diagonal elements should be included
105+
# (self should be the first neighbor because distance is 0)
106+
for i in range(4):
107+
assert indices_with_self[i, 0] == i
108+
assert distances_with_self[i, 0] == 0.0
109+
110+
# With include_self=False, diagonal elements should be excluded
111+
for i in range(4):
112+
# The self-index should not be in the neighbors
113+
assert i not in indices_without_self[i]
114+
# Check that no zero distances are present
115+
assert np.all(distances_without_self[i] > 0)

0 commit comments

Comments
 (0)