Skip to content

Commit f797ef9

Browse files
committed
Add cellcharter aggregation step
1 parent d6df534 commit f797ef9

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

src/squidpy/gr/_niche.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
from collections.abc import Iterator
55
from typing import Any, Optional
66

7+
import anndata as ad
78
import numpy as np
89
import pandas as pd
910
import scanpy as sc
1011
from anndata import AnnData
12+
from scipy.sparse import csr_matrix, vstack
1113
from scipy.stats import ranksums
1214
from sklearn import metrics
1315
from sklearn.metrics import adjusted_rand_score, fowlkes_mallows_score, normalized_mutual_info_score
@@ -26,7 +28,10 @@ def calculate_niche(
2628
library_key: str | None = None,
2729
table_key: str | None = None,
2830
spatial_key: str = "spatial",
31+
adj_subsets: list[int] | None = None,
32+
aggregation: str = "mean",
2933
spatial_connectivities_key: str = "spatial_connectivities",
34+
spatial_distances_key: str = "spatial_distances",
3035
copy: bool = False,
3136
) -> AnnData | pd.DataFrame:
3237
"""Calculate niches (spatial clusters) based on a user-defined method in 'flavor'.
@@ -112,6 +117,38 @@ def calculate_niche(
112117
else:
113118
adata.layers["utag"] = new_feature_matrix
114119

120+
elif flavor == "cellcharter":
121+
adj_matrix_subsets = []
122+
if isinstance(adj_subsets, list):
123+
for k in adj_subsets:
124+
adj_matrix_subsets.append(
125+
_get_adj_matrix_subsets(
126+
adata.obsp[spatial_connectivities_key], adata.obsp[spatial_distances_key], k
127+
)
128+
)
129+
if aggregation == "mean":
130+
inner_products = [adata.X.dot(adj_subset) for adj_subset in adj_matrix_subsets]
131+
elif aggregation == "variance":
132+
inner_products = [
133+
_aggregate_var(matrix, adata.obsp[spatial_connectivities_key], adata) for matrix in inner_products
134+
]
135+
else:
136+
raise ValueError(
137+
f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'."
138+
)
139+
concatenated_matrix = vstack(inner_products)
140+
if copy:
141+
return concatenated_matrix
142+
else:
143+
if is_sdata:
144+
sdata.tables[f"{flavor}_niche"] = ad.AnnData(concatenated_matrix)
145+
else:
146+
adata.obsm[f"{flavor}_niche"] = concatenated_matrix
147+
else:
148+
raise ValueError(
149+
"Flavor 'cellcharter' requires list of neighbors to build adjacency matrices. Please provide a list of k_neighbors for 'adj_subsets'."
150+
)
151+
115152

116153
def _calculate_neighborhood_profile(
117154
adata: AnnData | SpatialData,
@@ -164,13 +201,43 @@ def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) ->
164201
return adjacency_matrix @ adata.X
165202

166203

204+
def _get_adj_matrix_subsets(connectivities: csr_matrix, distances: csr_matrix, k_neighbors: int) -> csr_matrix:
    """Restrict an adjacency matrix to the ``k_neighbors`` closest neighbors of every row.

    Parameters
    ----------
    connectivities
        Adjacency matrix whose values are copied into the result.
    distances
        Distance matrix with the same shape as ``connectivities``. Only entries that are
        explicitly stored in the sparse structure count as neighbors; absent entries mean
        "not connected" and are never selected, no matter how few neighbors a row has.
    k_neighbors
        Maximum number of neighbors kept per row. Rows with fewer stored neighbors keep
        all of them; values ``<= 0`` yield an empty matrix.

    Returns
    -------
    CSR matrix of shape ``connectivities.shape`` holding ``connectivities[i, j]`` for each
    selected pair ``(i, j)`` and nothing elsewhere.
    """
    # Normalise both inputs to CSR without mutating the caller's matrices.
    dist = csr_matrix(distances)
    conn = csr_matrix(connectivities)

    # Ranking has to happen on the stored entries only: in a densified distance matrix
    # every missing edge becomes 0 and would sort ahead of all real neighbors.
    entries_per_row = np.diff(dist.indptr)
    row_of_entry = np.repeat(np.arange(dist.shape[0]), entries_per_row)

    # Primary key: row, secondary key: distance. Because CSR already groups entries by
    # row, the sorted entries of row ``r`` still occupy ``indptr[r]:indptr[r + 1]``.
    order = np.lexsort((dist.data, row_of_entry))
    rank_in_row = np.arange(order.size) - np.repeat(dist.indptr[:-1], entries_per_row)
    keep = rank_in_row < k_neighbors

    rows = row_of_entry[order][keep]
    cols = dist.indices[order][keep]
    if rows.size == 0:
        return csr_matrix(conn.shape, dtype=conn.dtype)

    # Fancy indexing returns a (1, n) matrix for sparse matrices; flatten to 1-D.
    data = np.asarray(conn[rows, cols]).ravel()

    new_adj_matrix = csr_matrix((data, (rows, cols)), shape=conn.shape)
    # Drop pairs that have a distance but no connectivity weight.
    new_adj_matrix.eliminate_zeros()
    return new_adj_matrix
227+
228+
167229
def _df_to_adata(df: pd.DataFrame) -> AnnData:
    """Wrap ``df`` in an :class:`AnnData` whose observation index is the stringified frame index.

    Note that ``df`` itself is modified: its index is replaced by the string-mapped index
    before the object is built.
    """
    string_index = df.index.map(str)
    df.index = string_index

    wrapped = AnnData(X=df)
    # Assign the frame's (now string) index to ``obs`` explicitly.
    wrapped.obs.index = df.index
    return wrapped
172234

173235

236+
def _aggregate_var(product: csr_matrix, connectivities: csr_matrix, adata: AnnData) -> csr_matrix:
    """Return ``connectivities @ (X * X) - product * product`` with element-wise squares.

    Parameters
    ----------
    product
        Previously aggregated feature matrix; squared element-wise and subtracted.
    connectivities
        Matrix applied to the element-wise squared ``adata.X``.
    adata
        Object providing the feature matrix ``X``; ``X`` must offer a ``multiply`` method.

    Notes
    -----
    NOTE(review): this equals a neighborhood variance (E[X^2] - E[X]^2) only if ``product``
    is the neighborhood mean computed with the same, row-normalised ``connectivities`` -
    confirm against the caller.
    """
    squared_features = adata.X.multiply(adata.X)
    second_moment = connectivities.dot(squared_features)
    squared_product = product.multiply(product)
    return second_moment - squared_product
239+
240+
174241
def pairwise_niche_comparison(
175242
adata: AnnData,
176243
library_key: str,

0 commit comments

Comments
 (0)